1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <memory>
6 #include <utility>
7 
8 #include "base/bind.h"
9 #include "base/bind_helpers.h"
10 #include "base/macros.h"
11 #include "base/memory/ptr_util.h"
12 #include "base/test/task_environment.h"
13 #include "base/threading/thread.h"
14 #include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
15 #include "mojo/public/cpp/bindings/receiver.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17 
18 namespace media {
19 namespace learning {
20 
21 class MojoLearningTaskControllerTest : public ::testing::Test {
22  public:
23   // Impl of a mojom::LearningTaskController that remembers call arguments.
24   class FakeMojoLearningTaskController : public mojom::LearningTaskController {
25    public:
BeginObservation(const base::UnguessableToken & id,const FeatureVector & features,const base::Optional<TargetValue> & default_target)26     void BeginObservation(
27         const base::UnguessableToken& id,
28         const FeatureVector& features,
29         const base::Optional<TargetValue>& default_target) override {
30       begin_args_.id_ = id;
31       begin_args_.features_ = features;
32       begin_args_.default_target_ = default_target;
33     }
34 
CompleteObservation(const base::UnguessableToken & id,const ObservationCompletion & completion)35     void CompleteObservation(const base::UnguessableToken& id,
36                              const ObservationCompletion& completion) override {
37       complete_args_.id_ = id;
38       complete_args_.completion_ = completion;
39     }
40 
CancelObservation(const base::UnguessableToken & id)41     void CancelObservation(const base::UnguessableToken& id) override {
42       cancel_args_.id_ = id;
43     }
44 
UpdateDefaultTarget(const base::UnguessableToken & id,const base::Optional<TargetValue> & default_target)45     void UpdateDefaultTarget(
46         const base::UnguessableToken& id,
47         const base::Optional<TargetValue>& default_target) override {
48       update_default_args_.id_ = id;
49       update_default_args_.default_target_ = default_target;
50     }
51 
PredictDistribution(const FeatureVector & features,PredictDistributionCallback callback)52     void PredictDistribution(const FeatureVector& features,
53                              PredictDistributionCallback callback) override {
54       predict_args_.features_ = features;
55       predict_args_.callback_ = std::move(callback);
56     }
57 
58     struct {
59       base::UnguessableToken id_;
60       FeatureVector features_;
61       base::Optional<TargetValue> default_target_;
62     } begin_args_;
63 
64     struct {
65       base::UnguessableToken id_;
66       ObservationCompletion completion_;
67     } complete_args_;
68 
69     struct {
70       base::UnguessableToken id_;
71     } cancel_args_;
72 
73     struct {
74       base::UnguessableToken id_;
75       base::Optional<TargetValue> default_target_;
76     } update_default_args_;
77 
78     struct {
79       FeatureVector features_;
80       PredictDistributionCallback callback_;
81     } predict_args_;
82   };
83 
84  public:
MojoLearningTaskControllerTest()85   MojoLearningTaskControllerTest()
86       : learning_controller_receiver_(&fake_learning_controller_) {}
87   ~MojoLearningTaskControllerTest() override = default;
88 
SetUp()89   void SetUp() override {
90     // Create a LearningTask.
91     task_.name = "MyLearningTask";
92 
93     // Tell |learning_controller_| to forward to the fake learner impl.
94     mojo::Remote<media::learning::mojom::LearningTaskController> remote(
95         learning_controller_receiver_.BindNewPipeAndPassRemote());
96     learning_controller_ =
97         std::make_unique<MojoLearningTaskController>(task_, std::move(remote));
98   }
99 
100   // Mojo stuff.
101   base::test::TaskEnvironment task_environment_;
102 
103   LearningTask task_;
104   FakeMojoLearningTaskController fake_learning_controller_;
105   mojo::Receiver<mojom::LearningTaskController> learning_controller_receiver_;
106 
107   // The learner under test.
108   std::unique_ptr<MojoLearningTaskController> learning_controller_;
109 };
110 
TEST_F(MojoLearningTaskControllerTest, GetLearningTask) {
  // The controller should report the task it was constructed with in SetUp().
  const auto& reported_task = learning_controller_->GetLearningTask();
  EXPECT_EQ(reported_task.name, task_.name);
}
114 
TEST_F(MojoLearningTaskControllerTest, BeginWithoutDefaultTarget) {
  // Begin an observation with no default target. The fake should see the same
  // id and features, and an empty default target.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const FeatureVector feature_vector = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(token, feature_vector, base::nullopt,
                                         base::nullopt);
  // Let the queued mojo message reach |fake_learning_controller_|.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(token, received.id_);
  EXPECT_EQ(feature_vector, received.features_);
  EXPECT_FALSE(received.default_target_);
}
125 
TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) {
  // Begin an observation carrying a default target. All three values should
  // arrive unchanged on the fake.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const TargetValue fallback_target(987);
  const FeatureVector feature_vector = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(token, feature_vector,
                                         fallback_target, base::nullopt);
  // Let the queued mojo message reach |fake_learning_controller_|.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.begin_args_;
  EXPECT_EQ(token, received.id_);
  EXPECT_EQ(feature_vector, received.features_);
  EXPECT_EQ(fallback_target, received.default_target_);
}
138 
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToValue) {
  // Start with no default target, then set one. The update must reach the fake
  // with the new, non-null value.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const FeatureVector feature_vector = {FeatureValue(123), FeatureValue(456)};
  learning_controller_->BeginObservation(token, feature_vector, base::nullopt,
                                         base::nullopt);
  const TargetValue new_default(987);
  learning_controller_->UpdateDefaultTarget(token, new_default);
  // Deliver both queued calls before inspecting the fake.
  task_environment_.RunUntilIdle();

  const auto& update = fake_learning_controller_.update_default_args_;
  EXPECT_EQ(token, update.id_);
  EXPECT_EQ(feature_vector, fake_learning_controller_.begin_args_.features_);
  EXPECT_EQ(new_default, update.default_target_);
}
153 
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToNoValue) {
  // Start with a default target, then clear it. The update must reach the fake
  // with an empty value.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const FeatureVector feature_vector = {FeatureValue(123), FeatureValue(456)};
  const TargetValue initial_default(987);
  learning_controller_->BeginObservation(token, feature_vector,
                                         initial_default, base::nullopt);
  learning_controller_->UpdateDefaultTarget(token, base::nullopt);
  // Deliver both queued calls before inspecting the fake.
  task_environment_.RunUntilIdle();

  const auto& update = fake_learning_controller_.update_default_args_;
  EXPECT_EQ(token, update.id_);
  EXPECT_EQ(feature_vector, fake_learning_controller_.begin_args_.features_);
  EXPECT_FALSE(update.default_target_);
}
168 
TEST_F(MojoLearningTaskControllerTest, Complete) {
  // Completing an observation should forward its id and target value.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  const ObservationCompletion completion(TargetValue(1234));
  learning_controller_->CompleteObservation(token, completion);
  // Let the queued mojo message reach |fake_learning_controller_|.
  task_environment_.RunUntilIdle();

  const auto& received = fake_learning_controller_.complete_args_;
  EXPECT_EQ(token, received.id_);
  EXPECT_EQ(completion.target_value, received.completion_.target_value);
}
178 
TEST_F(MojoLearningTaskControllerTest, Cancel) {
  // Cancelling an observation should forward its id.
  const base::UnguessableToken token = base::UnguessableToken::Create();
  learning_controller_->CancelObservation(token);
  // Let the queued mojo message reach |fake_learning_controller_|.
  task_environment_.RunUntilIdle();

  EXPECT_EQ(token, fake_learning_controller_.cancel_args_.id_);
}
185 
TEST_F(MojoLearningTaskControllerTest, PredictDistribution) {
  FeatureVector features = {FeatureValue(123), FeatureValue(456)};

  // Record the Optional itself rather than dereferencing it inside the
  // callback. A reply without a value then fails the assertions below instead
  // of dereferencing an empty Optional.
  base::Optional<TargetHistogram> observed_prediction;
  learning_controller_->PredictDistribution(
      features,
      base::BindOnce(
          [](base::Optional<TargetHistogram>* test_storage,
             const base::Optional<TargetHistogram>& predicted) {
            *test_storage = predicted;
          },
          &observed_prediction));
  task_environment_.RunUntilIdle();
  EXPECT_EQ(features, fake_learning_controller_.predict_args_.features_);
  // The callback is Run() below; running a null OnceCallback is invalid, so
  // stop the test here if the fake never received one.
  ASSERT_FALSE(fake_learning_controller_.predict_args_.callback_.is_null());
  // The fake has not replied yet, so nothing should have been observed.
  EXPECT_FALSE(observed_prediction.has_value());

  TargetHistogram expected_prediction;
  expected_prediction[TargetValue(1)] = 1.0;
  expected_prediction[TargetValue(2)] = 2.0;
  expected_prediction[TargetValue(3)] = 3.0;
  std::move(fake_learning_controller_.predict_args_.callback_)
      .Run(expected_prediction);
  // Deliver the reply back to the caller's callback.
  task_environment_.RunUntilIdle();
  ASSERT_TRUE(observed_prediction.has_value());
  EXPECT_EQ(*observed_prediction, expected_prediction);
}
210 
211 }  // namespace learning
212 }  // namespace media
213