1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
6 
7 #include <utility>
8 #include <vector>
9 
10 #include "base/bind.h"
11 #include "base/macros.h"
12 #include "base/message_loop/message_pump_type.h"
13 #include "base/run_loop.h"
14 #include "base/test/task_environment.h"
15 #include "base/threading/thread.h"
16 #include "chromeos/dbus/machine_learning/machine_learning_client.h"
17 #include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
18 #include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
19 #include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
20 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
21 #include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
22 #include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
23 #include "mojo/core/embedder/embedder.h"
24 #include "mojo/core/embedder/scoped_ipc_support.h"
25 #include "mojo/public/cpp/bindings/remote.h"
26 #include "testing/gtest/include/gtest/gtest.h"
27 
28 namespace chromeos {
29 namespace machine_learning {
30 namespace {
31 
32 class ServiceConnectionTest : public testing::Test {
33  public:
34   ServiceConnectionTest() = default;
35 
SetUp()36   void SetUp() override { MachineLearningClient::InitializeFake(); }
37 
TearDown()38   void TearDown() override { MachineLearningClient::Shutdown(); }
39 
40  protected:
SetUpTestCase()41   static void SetUpTestCase() {
42     static base::Thread ipc_thread("ipc");
43     ipc_thread.StartWithOptions(
44         base::Thread::Options(base::MessagePumpType::IO, 0));
45     static mojo::core::ScopedIPCSupport ipc_support(
46         ipc_thread.task_runner(),
47         mojo::core::ScopedIPCSupport::ShutdownPolicy::CLEAN);
48   }
49 
50  private:
51   base::test::TaskEnvironment task_environment_;
52 
53   DISALLOW_COPY_AND_ASSIGN(ServiceConnectionTest);
54 };
55 
56 // Tests that LoadBuiltinModel runs OK (no crash) in a basic Mojo
57 // environment.
TEST_F(ServiceConnectionTest,LoadBuiltinModel)58 TEST_F(ServiceConnectionTest, LoadBuiltinModel) {
59   mojo::Remote<mojom::Model> model;
60   mojom::BuiltinModelSpecPtr spec =
61       mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL);
62   ServiceConnection::GetInstance()->LoadBuiltinModel(
63       std::move(spec), model.BindNewPipeAndPassReceiver(),
64       base::BindOnce([](mojom::LoadModelResult result) {}));
65 }
66 
67 // Tests that LoadFlatBufferModel runs OK (no crash) in a basic Mojo
68 // environment.
TEST_F(ServiceConnectionTest,LoadFlatBufferModel)69 TEST_F(ServiceConnectionTest, LoadFlatBufferModel) {
70   mojo::Remote<mojom::Model> model;
71   mojom::FlatBufferModelSpecPtr spec = mojom::FlatBufferModelSpec::New();
72   ServiceConnection::GetInstance()->LoadFlatBufferModel(
73       std::move(spec), model.BindNewPipeAndPassReceiver(),
74       base::BindOnce([](mojom::LoadModelResult result) {}));
75 }
76 
77 // Tests that LoadTextClassifier runs OK (no crash) in a basic Mojo
78 // environment.
TEST_F(ServiceConnectionTest,LoadTextClassifier)79 TEST_F(ServiceConnectionTest, LoadTextClassifier) {
80   mojo::Remote<mojom::TextClassifier> text_classifier;
81   ServiceConnection::GetInstance()->LoadTextClassifier(
82       text_classifier.BindNewPipeAndPassReceiver(),
83       base::BindOnce([](mojom::LoadModelResult result) {}));
84 }
85 
86 // Tests that LoadHandwritingModelWithSpec runs OK (no crash) in a basic Mojo
87 // environment.
TEST_F(ServiceConnectionTest,LoadHandwritingModelWithSpec)88 TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) {
89   mojo::Remote<mojom::HandwritingRecognizer> handwriting_recognizer;
90   ServiceConnection::GetInstance()->LoadHandwritingModelWithSpec(
91       mojom::HandwritingRecognizerSpec::New("en"),
92       handwriting_recognizer.BindNewPipeAndPassReceiver(),
93       base::BindOnce([](mojom::LoadModelResult result) {}));
94 }
95 
96 // Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment.
TEST_F(ServiceConnectionTest,LoadGrammarModel)97 TEST_F(ServiceConnectionTest, LoadGrammarModel) {
98   mojo::Remote<mojom::GrammarChecker> grammar_checker;
99   ServiceConnection::GetInstance()->LoadGrammarChecker(
100       grammar_checker.BindNewPipeAndPassReceiver(),
101       base::BindOnce([](mojom::LoadModelResult result) {}));
102 }
103 
104 // Tests the fake ML service for builtin model.
TEST_F(ServiceConnectionTest,FakeServiceConnectionForBuiltinModel)105 TEST_F(ServiceConnectionTest, FakeServiceConnectionForBuiltinModel) {
106   mojo::Remote<mojom::Model> model;
107   bool callback_done = false;
108   FakeServiceConnectionImpl fake_service_connection;
109   ServiceConnection::UseFakeServiceConnectionForTesting(
110       &fake_service_connection);
111 
112   const double expected_value = 200.002;
113   fake_service_connection.SetOutputValue(std::vector<int64_t>{1L},
114                                          std::vector<double>{expected_value});
115   ServiceConnection::GetInstance()->LoadBuiltinModel(
116       mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL),
117       model.BindNewPipeAndPassReceiver(),
118       base::BindOnce(
119           [](bool* callback_done, mojom::LoadModelResult result) {
120             EXPECT_EQ(result, mojom::LoadModelResult::OK);
121             *callback_done = true;
122           },
123           &callback_done));
124   base::RunLoop().RunUntilIdle();
125   ASSERT_TRUE(callback_done);
126   ASSERT_TRUE(model.is_bound());
127 
128   callback_done = false;
129   mojo::Remote<mojom::GraphExecutor> graph;
130   model->CreateGraphExecutor(
131       graph.BindNewPipeAndPassReceiver(),
132       base::BindOnce(
133           [](bool* callback_done, mojom::CreateGraphExecutorResult result) {
134             EXPECT_EQ(result, mojom::CreateGraphExecutorResult::OK);
135             *callback_done = true;
136           },
137           &callback_done));
138   base::RunLoop().RunUntilIdle();
139   ASSERT_TRUE(callback_done);
140   ASSERT_TRUE(graph.is_bound());
141 
142   callback_done = false;
143   base::flat_map<std::string, mojom::TensorPtr> inputs;
144   std::vector<std::string> outputs;
145   graph->Execute(std::move(inputs), std::move(outputs),
146                  base::BindOnce(
147                      [](bool* callback_done, double expected_value,
148                         const mojom::ExecuteResult result,
149                         base::Optional<std::vector<mojom::TensorPtr>> outputs) {
150                        EXPECT_EQ(result, mojom::ExecuteResult::OK);
151                        ASSERT_TRUE(outputs.has_value());
152                        ASSERT_EQ(outputs->size(), 1LU);
153                        mojom::TensorPtr& tensor = (*outputs)[0];
154                        EXPECT_EQ(tensor->data->get_float_list()->value[0],
155                                  expected_value);
156 
157                        *callback_done = true;
158                      },
159                      &callback_done, expected_value));
160 
161   base::RunLoop().RunUntilIdle();
162   ASSERT_TRUE(callback_done);
163 }
164 
165 // Tests the fake ML service for flatbuffer model.
TEST_F(ServiceConnectionTest,FakeServiceConnectionForFlatBufferModel)166 TEST_F(ServiceConnectionTest, FakeServiceConnectionForFlatBufferModel) {
167   mojo::Remote<mojom::Model> model;
168   bool callback_done = false;
169   FakeServiceConnectionImpl fake_service_connection;
170   ServiceConnection::UseFakeServiceConnectionForTesting(
171       &fake_service_connection);
172 
173   const double expected_value = 200.002;
174   fake_service_connection.SetOutputValue(std::vector<int64_t>{1L},
175                                          std::vector<double>{expected_value});
176 
177   ServiceConnection::GetInstance()->LoadFlatBufferModel(
178       mojom::FlatBufferModelSpec::New(), model.BindNewPipeAndPassReceiver(),
179       base::BindOnce(
180           [](bool* callback_done, mojom::LoadModelResult result) {
181             EXPECT_EQ(result, mojom::LoadModelResult::OK);
182             *callback_done = true;
183           },
184           &callback_done));
185   base::RunLoop().RunUntilIdle();
186   ASSERT_TRUE(callback_done);
187   ASSERT_TRUE(model.is_bound());
188 
189   callback_done = false;
190   mojo::Remote<mojom::GraphExecutor> graph;
191   model->CreateGraphExecutor(
192       graph.BindNewPipeAndPassReceiver(),
193       base::BindOnce(
194           [](bool* callback_done, mojom::CreateGraphExecutorResult result) {
195             EXPECT_EQ(result, mojom::CreateGraphExecutorResult::OK);
196             *callback_done = true;
197           },
198           &callback_done));
199   base::RunLoop().RunUntilIdle();
200   ASSERT_TRUE(callback_done);
201   ASSERT_TRUE(graph.is_bound());
202 
203   callback_done = false;
204   base::flat_map<std::string, mojom::TensorPtr> inputs;
205   std::vector<std::string> outputs;
206   graph->Execute(std::move(inputs), std::move(outputs),
207                  base::BindOnce(
208                      [](bool* callback_done, double expected_value,
209                         const mojom::ExecuteResult result,
210                         base::Optional<std::vector<mojom::TensorPtr>> outputs) {
211                        EXPECT_EQ(result, mojom::ExecuteResult::OK);
212                        ASSERT_TRUE(outputs.has_value());
213                        ASSERT_EQ(outputs->size(), 1LU);
214                        mojom::TensorPtr& tensor = (*outputs)[0];
215                        EXPECT_EQ(tensor->data->get_float_list()->value[0],
216                                  expected_value);
217 
218                        *callback_done = true;
219                      },
220                      &callback_done, expected_value));
221 
222   base::RunLoop().RunUntilIdle();
223   ASSERT_TRUE(callback_done);
224 }
225 
226 // Tests the fake ML service for text classifier annotation.
TEST_F(ServiceConnectionTest,FakeServiceConnectionForTextClassifierAnnotation)227 TEST_F(ServiceConnectionTest,
228        FakeServiceConnectionForTextClassifierAnnotation) {
229   mojo::Remote<mojom::TextClassifier> text_classifier;
230   bool callback_done = false;
231   FakeServiceConnectionImpl fake_service_connection;
232   ServiceConnection::UseFakeServiceConnectionForTesting(
233       &fake_service_connection);
234 
235   auto dummy_data = mojom::TextEntityData::New();
236   dummy_data->set_numeric_value(123456789.);
237   std::vector<mojom::TextEntityPtr> entities;
238   entities.emplace_back(
239       mojom::TextEntity::New("dummy",                      // Entity name.
240                              1.0,                          // Confidence score.
241                              std::move(dummy_data)));      // Data extracted.
242   auto dummy_annotation = mojom::TextAnnotation::New(123,  // Start offset.
243                                                      321,  // End offset.
244                                                      std::move(entities));
245   std::vector<mojom::TextAnnotationPtr> annotations;
246   annotations.emplace_back(std::move(dummy_annotation));
247   fake_service_connection.SetOutputAnnotation(annotations);
248 
249   ServiceConnection::GetInstance()->LoadTextClassifier(
250       text_classifier.BindNewPipeAndPassReceiver(),
251       base::BindOnce(
252           [](bool* callback_done, mojom::LoadModelResult result) {
253             EXPECT_EQ(result, mojom::LoadModelResult::OK);
254             *callback_done = true;
255           },
256           &callback_done));
257   base::RunLoop().RunUntilIdle();
258   ASSERT_TRUE(callback_done);
259   ASSERT_TRUE(text_classifier.is_bound());
260 
261   auto request = mojom::TextAnnotationRequest::New();
262   bool infer_callback_done = false;
263   text_classifier->Annotate(
264       std::move(request),
265       base::BindOnce(
266           [](bool* infer_callback_done,
267              std::vector<mojom::TextAnnotationPtr> annotations) {
268             *infer_callback_done = true;
269             // Check if the annotation is correct.
270             EXPECT_EQ(annotations[0]->start_offset, 123u);
271             EXPECT_EQ(annotations[0]->end_offset, 321u);
272             EXPECT_EQ(annotations[0]->entities[0]->name, "dummy");
273             EXPECT_EQ(annotations[0]->entities[0]->confidence_score, 1.0);
274             EXPECT_EQ(annotations[0]->entities[0]->data->get_numeric_value(),
275                       123456789.);
276           },
277           &infer_callback_done));
278   base::RunLoop().RunUntilIdle();
279   ASSERT_TRUE(infer_callback_done);
280 }
281 
282 // Tests the fake ML service for text classifier suggest selection.
TEST_F(ServiceConnectionTest,FakeServiceConnectionForTextClassifierSuggestSelection)283 TEST_F(ServiceConnectionTest,
284        FakeServiceConnectionForTextClassifierSuggestSelection) {
285   mojo::Remote<mojom::TextClassifier> text_classifier;
286   bool callback_done = false;
287   FakeServiceConnectionImpl fake_service_connection;
288   ServiceConnection::UseFakeServiceConnectionForTesting(
289       &fake_service_connection);
290 
291   auto span = mojom::CodepointSpan::New();
292   span->start_offset = 1;
293   span->end_offset = 2;
294   fake_service_connection.SetOutputSelection(span);
295 
296   ServiceConnection::GetInstance()->LoadTextClassifier(
297       text_classifier.BindNewPipeAndPassReceiver(),
298       base::BindOnce(
299           [](bool* callback_done, mojom::LoadModelResult result) {
300             EXPECT_EQ(result, mojom::LoadModelResult::OK);
301             *callback_done = true;
302           },
303           &callback_done));
304   base::RunLoop().RunUntilIdle();
305   ASSERT_TRUE(callback_done);
306   ASSERT_TRUE(text_classifier.is_bound());
307 
308   auto request = mojom::TextSuggestSelectionRequest::New();
309   request->user_selection = mojom::CodepointSpan::New();
310   bool infer_callback_done = false;
311   text_classifier->SuggestSelection(
312       std::move(request), base::BindOnce(
313                               [](bool* infer_callback_done,
314                                  mojom::CodepointSpanPtr suggested_span) {
315                                 *infer_callback_done = true;
316                                 // Check if the suggestion is correct.
317                                 EXPECT_EQ(suggested_span->start_offset, 1u);
318                                 EXPECT_EQ(suggested_span->end_offset, 2u);
319                               },
320                               &infer_callback_done));
321   base::RunLoop().RunUntilIdle();
322   ASSERT_TRUE(infer_callback_done);
323 }
324 
325 // Tests the fake ML service for text classifier language identification.
TEST_F(ServiceConnectionTest,FakeServiceConnectionForTextClassifierFindLanguages)326 TEST_F(ServiceConnectionTest,
327        FakeServiceConnectionForTextClassifierFindLanguages) {
328   mojo::Remote<mojom::TextClassifier> text_classifier;
329   bool callback_done = false;
330   FakeServiceConnectionImpl fake_service_connection;
331   ServiceConnection::UseFakeServiceConnectionForTesting(
332       &fake_service_connection);
333 
334   std::vector<mojom::TextLanguagePtr> languages;
335   languages.emplace_back(mojom::TextLanguage::New("en", 0.9));
336   languages.emplace_back(mojom::TextLanguage::New("fr", 0.1));
337   fake_service_connection.SetOutputLanguages(languages);
338 
339   ServiceConnection::GetInstance()->LoadTextClassifier(
340       text_classifier.BindNewPipeAndPassReceiver(),
341       base::BindOnce(
342           [](bool* callback_done, mojom::LoadModelResult result) {
343             EXPECT_EQ(result, mojom::LoadModelResult::OK);
344             *callback_done = true;
345           },
346           &callback_done));
347   base::RunLoop().RunUntilIdle();
348   ASSERT_TRUE(callback_done);
349   ASSERT_TRUE(text_classifier.is_bound());
350 
351   std::string input_text = "dummy input text";
352   bool infer_callback_done = false;
353   text_classifier->FindLanguages(
354       input_text, base::BindOnce(
355                       [](bool* infer_callback_done,
356                          std::vector<mojom::TextLanguagePtr> languages) {
357                         *infer_callback_done = true;
358                         // Check if the suggestion is correct.
359                         ASSERT_EQ(languages.size(), 2ul);
360                         EXPECT_EQ(languages[0]->locale, "en");
361                         EXPECT_EQ(languages[0]->confidence, 0.9f);
362                         EXPECT_EQ(languages[1]->locale, "fr");
363                         EXPECT_EQ(languages[1]->confidence, 0.1f);
364                       },
365                       &infer_callback_done));
366   base::RunLoop().RunUntilIdle();
367   ASSERT_TRUE(infer_callback_done);
368 }
369 
370 // Tests the fake ML service for handwriting.
TEST_F(ServiceConnectionTest,FakeHandWritingRecognizer)371 TEST_F(ServiceConnectionTest, FakeHandWritingRecognizer) {
372   mojo::Remote<mojom::HandwritingRecognizer> recognizer;
373   bool callback_done = false;
374   FakeServiceConnectionImpl fake_service_connection;
375   ServiceConnection::UseFakeServiceConnectionForTesting(
376       &fake_service_connection);
377 
378   ServiceConnection::GetInstance()->LoadHandwritingModel(
379       mojom::HandwritingRecognizerSpec::New("en"),
380       recognizer.BindNewPipeAndPassReceiver(),
381       base::BindOnce(
382           [](bool* callback_done, mojom::LoadHandwritingModelResult result) {
383             EXPECT_EQ(result, mojom::LoadHandwritingModelResult::OK);
384             *callback_done = true;
385           },
386           &callback_done));
387   base::RunLoop().RunUntilIdle();
388   ASSERT_TRUE(callback_done);
389   ASSERT_TRUE(recognizer.is_bound());
390 
391   // Construct fake output.
392   mojom::HandwritingRecognizerResultPtr result =
393       mojom::HandwritingRecognizerResult::New();
394   result->status = mojom::HandwritingRecognizerResult::Status::OK;
395   mojom::HandwritingRecognizerCandidatePtr candidate =
396       mojom::HandwritingRecognizerCandidate::New();
397   candidate->text = "cat";
398   candidate->score = 0.5f;
399   result->candidates.emplace_back(std::move(candidate));
400   fake_service_connection.SetOutputHandwritingRecognizerResult(result);
401 
402   auto query = mojom::HandwritingRecognitionQuery::New();
403   bool infer_callback_done = false;
404   recognizer->Recognize(
405       std::move(query),
406       base::Bind(
407           [](bool* infer_callback_done,
408              mojom::HandwritingRecognizerResultPtr result) {
409             *infer_callback_done = true;
410             // Check if the annotation is correct.
411             ASSERT_EQ(result->status,
412                       mojom::HandwritingRecognizerResult::Status::OK);
413             EXPECT_EQ(result->candidates.at(0)->text, "cat");
414             EXPECT_EQ(result->candidates.at(0)->score, 0.5f);
415           },
416           &infer_callback_done));
417   base::RunLoop().RunUntilIdle();
418   ASSERT_TRUE(infer_callback_done);
419 }
420 
421 // Tests the deprecated fake ML service for handwriting.
422 // Deprecated API.
TEST_F(ServiceConnectionTest,FakeHandWritingRecognizerWithSpec)423 TEST_F(ServiceConnectionTest, FakeHandWritingRecognizerWithSpec) {
424   mojo::Remote<mojom::HandwritingRecognizer> recognizer;
425   bool callback_done = false;
426   FakeServiceConnectionImpl fake_service_connection;
427   ServiceConnection::UseFakeServiceConnectionForTesting(
428       &fake_service_connection);
429 
430   ServiceConnection::GetInstance()->LoadHandwritingModelWithSpec(
431       mojom::HandwritingRecognizerSpec::New("en"),
432       recognizer.BindNewPipeAndPassReceiver(),
433       base::BindOnce(
434           [](bool* callback_done, mojom::LoadModelResult result) {
435             EXPECT_EQ(result, mojom::LoadModelResult::OK);
436             *callback_done = true;
437           },
438           &callback_done));
439   base::RunLoop().RunUntilIdle();
440   ASSERT_TRUE(callback_done);
441   ASSERT_TRUE(recognizer.is_bound());
442 
443   // Construct fake output.
444   mojom::HandwritingRecognizerResultPtr result =
445       mojom::HandwritingRecognizerResult::New();
446   result->status = mojom::HandwritingRecognizerResult::Status::OK;
447   mojom::HandwritingRecognizerCandidatePtr candidate =
448       mojom::HandwritingRecognizerCandidate::New();
449   candidate->text = "cat";
450   candidate->score = 0.5f;
451   result->candidates.emplace_back(std::move(candidate));
452   fake_service_connection.SetOutputHandwritingRecognizerResult(result);
453 
454   auto query = mojom::HandwritingRecognitionQuery::New();
455   bool infer_callback_done = false;
456   recognizer->Recognize(
457       std::move(query),
458       base::BindOnce(
459           [](bool* infer_callback_done,
460              mojom::HandwritingRecognizerResultPtr result) {
461             *infer_callback_done = true;
462             // Check if the annotation is correct.
463             ASSERT_EQ(result->status,
464                       mojom::HandwritingRecognizerResult::Status::OK);
465             EXPECT_EQ(result->candidates.at(0)->text, "cat");
466             EXPECT_EQ(result->candidates.at(0)->score, 0.5f);
467           },
468           &infer_callback_done));
469   base::RunLoop().RunUntilIdle();
470   ASSERT_TRUE(infer_callback_done);
471 }
472 
// Tests the fake ML service for grammar checking: LoadGrammarChecker() must
// succeed and Check() must return the result configured with
// SetOutputGrammarCheckerResult().
TEST_F(ServiceConnectionTest, FakeGrammarChecker) {
  mojo::Remote<mojom::GrammarChecker> checker;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  ServiceConnection::GetInstance()->LoadGrammarChecker(
      checker.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadModelResult result) {
            EXPECT_EQ(result, mojom::LoadModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(checker.is_bound());

  // Construct fake output.
  mojom::GrammarCheckerResultPtr result = mojom::GrammarCheckerResult::New();
  result->status = mojom::GrammarCheckerResult::Status::OK;
  mojom::GrammarCheckerCandidatePtr candidate =
      mojom::GrammarCheckerCandidate::New();
  candidate->text = "cat";
  candidate->score = 0.5f;
  result->candidates.emplace_back(std::move(candidate));
  fake_service_connection.SetOutputGrammarCheckerResult(result);

  auto query = mojom::GrammarCheckerQuery::New();
  bool infer_callback_done = false;
  checker->Check(
      std::move(query),
      base::BindOnce(
          [](bool* infer_callback_done, mojom::GrammarCheckerResultPtr result) {
            *infer_callback_done = true;
            // Check that the grammar result matches the fake output. The size
            // is asserted instead of using at(), which would abort the whole
            // binary in a no-exceptions build.
            ASSERT_EQ(result->status, mojom::GrammarCheckerResult::Status::OK);
            ASSERT_EQ(result->candidates.size(), 1u);
            EXPECT_EQ(result->candidates[0]->text, "cat");
            EXPECT_EQ(result->candidates[0]->score, 0.5f);
          },
          &infer_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(infer_callback_done);
}
518 
519 }  // namespace
520 }  // namespace machine_learning
521 }  // namespace chromeos
522