1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <memory>
6 #include <utility>
7
8 #include "base/bind.h"
9 #include "base/bind_helpers.h"
10 #include "base/macros.h"
11 #include "base/memory/ptr_util.h"
12 #include "base/test/task_environment.h"
13 #include "base/threading/thread.h"
14 #include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
15 #include "mojo/public/cpp/bindings/receiver.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17
18 namespace media {
19 namespace learning {
20
21 class MojoLearningTaskControllerTest : public ::testing::Test {
22 public:
23 // Impl of a mojom::LearningTaskController that remembers call arguments.
24 class FakeMojoLearningTaskController : public mojom::LearningTaskController {
25 public:
BeginObservation(const base::UnguessableToken & id,const FeatureVector & features,const base::Optional<TargetValue> & default_target)26 void BeginObservation(
27 const base::UnguessableToken& id,
28 const FeatureVector& features,
29 const base::Optional<TargetValue>& default_target) override {
30 begin_args_.id_ = id;
31 begin_args_.features_ = features;
32 begin_args_.default_target_ = default_target;
33 }
34
CompleteObservation(const base::UnguessableToken & id,const ObservationCompletion & completion)35 void CompleteObservation(const base::UnguessableToken& id,
36 const ObservationCompletion& completion) override {
37 complete_args_.id_ = id;
38 complete_args_.completion_ = completion;
39 }
40
CancelObservation(const base::UnguessableToken & id)41 void CancelObservation(const base::UnguessableToken& id) override {
42 cancel_args_.id_ = id;
43 }
44
UpdateDefaultTarget(const base::UnguessableToken & id,const base::Optional<TargetValue> & default_target)45 void UpdateDefaultTarget(
46 const base::UnguessableToken& id,
47 const base::Optional<TargetValue>& default_target) override {
48 update_default_args_.id_ = id;
49 update_default_args_.default_target_ = default_target;
50 }
51
PredictDistribution(const FeatureVector & features,PredictDistributionCallback callback)52 void PredictDistribution(const FeatureVector& features,
53 PredictDistributionCallback callback) override {
54 predict_args_.features_ = features;
55 predict_args_.callback_ = std::move(callback);
56 }
57
58 struct {
59 base::UnguessableToken id_;
60 FeatureVector features_;
61 base::Optional<TargetValue> default_target_;
62 } begin_args_;
63
64 struct {
65 base::UnguessableToken id_;
66 ObservationCompletion completion_;
67 } complete_args_;
68
69 struct {
70 base::UnguessableToken id_;
71 } cancel_args_;
72
73 struct {
74 base::UnguessableToken id_;
75 base::Optional<TargetValue> default_target_;
76 } update_default_args_;
77
78 struct {
79 FeatureVector features_;
80 PredictDistributionCallback callback_;
81 } predict_args_;
82 };
83
84 public:
MojoLearningTaskControllerTest()85 MojoLearningTaskControllerTest()
86 : learning_controller_receiver_(&fake_learning_controller_) {}
87 ~MojoLearningTaskControllerTest() override = default;
88
SetUp()89 void SetUp() override {
90 // Create a LearningTask.
91 task_.name = "MyLearningTask";
92
93 // Tell |learning_controller_| to forward to the fake learner impl.
94 mojo::Remote<media::learning::mojom::LearningTaskController> remote(
95 learning_controller_receiver_.BindNewPipeAndPassRemote());
96 learning_controller_ =
97 std::make_unique<MojoLearningTaskController>(task_, std::move(remote));
98 }
99
100 // Mojo stuff.
101 base::test::TaskEnvironment task_environment_;
102
103 LearningTask task_;
104 FakeMojoLearningTaskController fake_learning_controller_;
105 mojo::Receiver<mojom::LearningTaskController> learning_controller_receiver_;
106
107 // The learner under test.
108 std::unique_ptr<MojoLearningTaskController> learning_controller_;
109 };
110
// The controller must report the LearningTask it was constructed with.
TEST_F(MojoLearningTaskControllerTest, GetLearningTask) {
  const LearningTask& reported_task = learning_controller_->GetLearningTask();
  EXPECT_EQ(reported_task.name, task_.name);
}
114
// BeginObservation() with no default target must reach the fake with the same
// id and features, and with an empty default target.
TEST_F(MojoLearningTaskControllerTest, BeginWithoutDefaultTarget) {
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};

  learning_controller_->BeginObservation(id, features, base::nullopt,
                                         base::nullopt);
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(id, received.id_);
  EXPECT_EQ(features, received.features_);
  EXPECT_FALSE(received.default_target_);
}
125
// BeginObservation() with a default target must forward id, features and the
// default target unchanged.
TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) {
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const TargetValue default_target(987);
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};

  learning_controller_->BeginObservation(id, features, default_target,
                                         base::nullopt);
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(id, received.id_);
  EXPECT_EQ(features, received.features_);
  EXPECT_EQ(default_target, received.default_target_);
}
138
// Starts an observation without a default target, then updates the default
// target to a concrete (non-nullopt) value.
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToValue) {
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(id, features, base::nullopt,
                                         base::nullopt);

  const TargetValue default_target(987);
  learning_controller_->UpdateDefaultTarget(id, default_target);
  task_environment_.RunUntilIdle();

  const auto& fake = fake_learning_controller_;
  EXPECT_EQ(id, fake.update_default_args_.id_);
  EXPECT_EQ(features, fake.begin_args_.features_);
  EXPECT_EQ(default_target, fake.update_default_args_.default_target_);
}
153
// Starts an observation with a default target, then clears it by updating the
// default target to nullopt.
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToNoValue) {
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  const TargetValue default_target(987);
  learning_controller_->BeginObservation(id, features, default_target,
                                         base::nullopt);

  learning_controller_->UpdateDefaultTarget(id, base::nullopt);
  task_environment_.RunUntilIdle();

  const auto& fake = fake_learning_controller_;
  EXPECT_EQ(id, fake.update_default_args_.id_);
  EXPECT_EQ(features, fake.begin_args_.features_);
  EXPECT_FALSE(fake.update_default_args_.default_target_.has_value());
}
168
// CompleteObservation() must forward the id and the completion's target value.
TEST_F(MojoLearningTaskControllerTest, Complete) {
  const base::UnguessableToken id = base::UnguessableToken::Create();
  const ObservationCompletion completion(TargetValue(1234));

  learning_controller_->CompleteObservation(id, completion);
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.complete_args_;
  EXPECT_EQ(id, received.id_);
  EXPECT_EQ(completion.target_value, received.completion_.target_value);
}
178
// CancelObservation() must forward the observation id.
TEST_F(MojoLearningTaskControllerTest, Cancel) {
  const base::UnguessableToken id = base::UnguessableToken::Create();

  learning_controller_->CancelObservation(id);
  task_environment_.RunUntilIdle();

  EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_);
}
185
// PredictDistribution() must forward the features over mojo, and the reply
// sent by the receiving side must reach the caller's callback intact.
TEST_F(MojoLearningTaskControllerTest, PredictDistribution) {
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};

  // Keep the whole Optional so that a missing prediction shows up as a test
  // failure instead of a dereference of an empty Optional.
  base::Optional<TargetHistogram> observed_prediction;
  learning_controller_->PredictDistribution(
      features, base::BindOnce(
                    [](base::Optional<TargetHistogram>* test_storage,
                       const base::Optional<TargetHistogram>& predicted) {
                      *test_storage = predicted;
                    },
                    &observed_prediction));
  task_environment_.RunUntilIdle();
  EXPECT_EQ(features, fake_learning_controller_.predict_args_.features_);
  // Running a null callback below would crash, so stop here if it is missing.
  ASSERT_FALSE(fake_learning_controller_.predict_args_.callback_.is_null());

  TargetHistogram expected_prediction;
  expected_prediction[TargetValue(1)] = 1.0;
  expected_prediction[TargetValue(2)] = 2.0;
  expected_prediction[TargetValue(3)] = 3.0;
  std::move(fake_learning_controller_.predict_args_.callback_)
      .Run(expected_prediction);
  task_environment_.RunUntilIdle();

  ASSERT_TRUE(observed_prediction);
  EXPECT_EQ(*observed_prediction, expected_prediction);
}
210
211 } // namespace learning
212 } // namespace media
213