1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
6
7 #include <utility>
8 #include <vector>
9
10 #include "base/bind.h"
11 #include "base/macros.h"
12 #include "base/message_loop/message_pump_type.h"
13 #include "base/run_loop.h"
14 #include "base/test/task_environment.h"
15 #include "base/threading/thread.h"
16 #include "chromeos/dbus/machine_learning/machine_learning_client.h"
17 #include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
18 #include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
19 #include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
20 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
21 #include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
22 #include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
23 #include "mojo/core/embedder/embedder.h"
24 #include "mojo/core/embedder/scoped_ipc_support.h"
25 #include "mojo/public/cpp/bindings/remote.h"
26 #include "testing/gtest/include/gtest/gtest.h"
27
28 namespace chromeos {
29 namespace machine_learning {
30 namespace {
31
32 class ServiceConnectionTest : public testing::Test {
33 public:
34 ServiceConnectionTest() = default;
35
SetUp()36 void SetUp() override { MachineLearningClient::InitializeFake(); }
37
TearDown()38 void TearDown() override { MachineLearningClient::Shutdown(); }
39
40 protected:
SetUpTestCase()41 static void SetUpTestCase() {
42 static base::Thread ipc_thread("ipc");
43 ipc_thread.StartWithOptions(
44 base::Thread::Options(base::MessagePumpType::IO, 0));
45 static mojo::core::ScopedIPCSupport ipc_support(
46 ipc_thread.task_runner(),
47 mojo::core::ScopedIPCSupport::ShutdownPolicy::CLEAN);
48 }
49
50 private:
51 base::test::TaskEnvironment task_environment_;
52
53 DISALLOW_COPY_AND_ASSIGN(ServiceConnectionTest);
54 };
55
56 // Tests that LoadBuiltinModel runs OK (no crash) in a basic Mojo
57 // environment.
TEST_F(ServiceConnectionTest, LoadBuiltinModel) {
  // Smoke test: issuing the request must not crash. The message loop is never
  // pumped here, so nothing is asserted about the no-op result callback.
  mojo::Remote<mojom::Model> remote_model;
  auto builtin_spec =
      mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL);
  auto ignore_result = base::BindOnce([](mojom::LoadModelResult) {});
  auto* const connection = ServiceConnection::GetInstance();
  connection->LoadBuiltinModel(std::move(builtin_spec),
                               remote_model.BindNewPipeAndPassReceiver(),
                               std::move(ignore_result));
}
66
67 // Tests that LoadFlatBufferModel runs OK (no crash) in a basic Mojo
68 // environment.
TEST_F(ServiceConnectionTest, LoadFlatBufferModel) {
  // Smoke test: issuing the request with a default FlatBuffer spec must not
  // crash. The message loop is never pumped, so the callback is not checked.
  mojo::Remote<mojom::Model> remote_model;
  auto flatbuffer_spec = mojom::FlatBufferModelSpec::New();
  auto* const connection = ServiceConnection::GetInstance();
  connection->LoadFlatBufferModel(
      std::move(flatbuffer_spec), remote_model.BindNewPipeAndPassReceiver(),
      base::BindOnce([](mojom::LoadModelResult) {}));
}
76
77 // Tests that LoadTextClassifier runs OK (no crash) in a basic Mojo
78 // environment.
TEST_F(ServiceConnectionTest, LoadTextClassifier) {
  // Smoke test: requesting a text classifier must not crash. No run loop is
  // spun, so the no-op result callback is not inspected.
  mojo::Remote<mojom::TextClassifier> classifier_remote;
  auto* const connection = ServiceConnection::GetInstance();
  connection->LoadTextClassifier(
      classifier_remote.BindNewPipeAndPassReceiver(),
      base::BindOnce([](mojom::LoadModelResult) {}));
}
85
86 // Tests that LoadHandwritingModelWithSpec runs OK (no crash) in a basic Mojo
87 // environment.
TEST_F(ServiceConnectionTest, LoadHandwritingModelWithSpec) {
  // Smoke test: requesting an "en" handwriting recognizer must not crash. No
  // run loop is spun, so the no-op result callback is not inspected.
  mojo::Remote<mojom::HandwritingRecognizer> recognizer_remote;
  auto english_spec = mojom::HandwritingRecognizerSpec::New("en");
  ServiceConnection::GetInstance()->LoadHandwritingModelWithSpec(
      std::move(english_spec), recognizer_remote.BindNewPipeAndPassReceiver(),
      base::BindOnce([](mojom::LoadModelResult) {}));
}
95
96 // Tests that LoadGrammarChecker runs OK (no crash) in a basic Mojo environment.
TEST_F(ServiceConnectionTest, LoadGrammarModel) {
  // Smoke test: requesting a grammar checker must not crash. No run loop is
  // spun, so the no-op result callback is not inspected.
  mojo::Remote<mojom::GrammarChecker> checker_remote;
  auto* const connection = ServiceConnection::GetInstance();
  connection->LoadGrammarChecker(
      checker_remote.BindNewPipeAndPassReceiver(),
      base::BindOnce([](mojom::LoadModelResult) {}));
}
103
104 // Tests the fake ML service for builtin model.
TEST_F(ServiceConnectionTest, FakeServiceConnectionForBuiltinModel) {
  mojo::Remote<mojom::Model> model;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  // Configure the fake's Execute() output to a single value. The first
  // argument is presumably the output shape; confirm in
  // fake_service_connection.h.
  const double expected_value = 200.002;
  fake_service_connection.SetOutputValue(std::vector<int64_t>{1L},
                                         std::vector<double>{expected_value});

  // Step 1: load the builtin test model through the fake.
  ServiceConnection::GetInstance()->LoadBuiltinModel(
      mojom::BuiltinModelSpec::New(mojom::BuiltinModelId::TEST_MODEL),
      model.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadModelResult result) {
            EXPECT_EQ(result, mojom::LoadModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(model.is_bound());

  // Step 2: create a graph executor from the loaded model.
  callback_done = false;
  mojo::Remote<mojom::GraphExecutor> graph;
  model->CreateGraphExecutor(
      graph.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::CreateGraphExecutorResult result) {
            EXPECT_EQ(result, mojom::CreateGraphExecutorResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(graph.is_bound());

  // Step 3: execute with empty inputs/outputs and check the configured value.
  callback_done = false;
  base::flat_map<std::string, mojom::TensorPtr> inputs;
  std::vector<std::string> outputs;
  graph->Execute(
      std::move(inputs), std::move(outputs),
      base::BindOnce(
          [](bool* callback_done, double expected_value,
             const mojom::ExecuteResult result,
             base::Optional<std::vector<mojom::TensorPtr>> outputs) {
            EXPECT_EQ(result, mojom::ExecuteResult::OK);
            ASSERT_TRUE(outputs.has_value());
            ASSERT_EQ(outputs->size(), 1LU);
            mojom::TensorPtr& tensor = (*outputs)[0];
            // Check the union tag and the list length before indexing, so a
            // malformed output fails the test instead of crashing it.
            ASSERT_TRUE(tensor->data->is_float_list());
            const auto& values = tensor->data->get_float_list()->value;
            ASSERT_EQ(values.size(), 1LU);
            EXPECT_EQ(values[0], expected_value);

            *callback_done = true;
          },
          &callback_done, expected_value));

  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
}
164
165 // Tests the fake ML service for flatbuffer model.
TEST_F(ServiceConnectionTest, FakeServiceConnectionForFlatBufferModel) {
  mojo::Remote<mojom::Model> model;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  // Configure the fake's Execute() output to a single value. The first
  // argument is presumably the output shape; confirm in
  // fake_service_connection.h.
  const double expected_value = 200.002;
  fake_service_connection.SetOutputValue(std::vector<int64_t>{1L},
                                         std::vector<double>{expected_value});

  // Step 1: load a (default-constructed) FlatBuffer model through the fake.
  ServiceConnection::GetInstance()->LoadFlatBufferModel(
      mojom::FlatBufferModelSpec::New(), model.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadModelResult result) {
            EXPECT_EQ(result, mojom::LoadModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(model.is_bound());

  // Step 2: create a graph executor from the loaded model.
  callback_done = false;
  mojo::Remote<mojom::GraphExecutor> graph;
  model->CreateGraphExecutor(
      graph.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::CreateGraphExecutorResult result) {
            EXPECT_EQ(result, mojom::CreateGraphExecutorResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(graph.is_bound());

  // Step 3: execute with empty inputs/outputs and check the configured value.
  callback_done = false;
  base::flat_map<std::string, mojom::TensorPtr> inputs;
  std::vector<std::string> outputs;
  graph->Execute(
      std::move(inputs), std::move(outputs),
      base::BindOnce(
          [](bool* callback_done, double expected_value,
             const mojom::ExecuteResult result,
             base::Optional<std::vector<mojom::TensorPtr>> outputs) {
            EXPECT_EQ(result, mojom::ExecuteResult::OK);
            ASSERT_TRUE(outputs.has_value());
            ASSERT_EQ(outputs->size(), 1LU);
            mojom::TensorPtr& tensor = (*outputs)[0];
            // Check the union tag and the list length before indexing, so a
            // malformed output fails the test instead of crashing it.
            ASSERT_TRUE(tensor->data->is_float_list());
            const auto& values = tensor->data->get_float_list()->value;
            ASSERT_EQ(values.size(), 1LU);
            EXPECT_EQ(values[0], expected_value);

            *callback_done = true;
          },
          &callback_done, expected_value));

  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
}
225
226 // Tests the fake ML service for text classifier annotation.
TEST_F(ServiceConnectionTest,
       FakeServiceConnectionForTextClassifierAnnotation) {
  mojo::Remote<mojom::TextClassifier> text_classifier;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  // Build one canned annotation holding one entity with a numeric payload,
  // and have the fake return it from Annotate().
  auto dummy_data = mojom::TextEntityData::New();
  dummy_data->set_numeric_value(123456789.);
  std::vector<mojom::TextEntityPtr> entities;
  entities.emplace_back(
      mojom::TextEntity::New("dummy",                  // Entity name.
                             1.0,                      // Confidence score.
                             std::move(dummy_data)));  // Data extracted.
  auto dummy_annotation = mojom::TextAnnotation::New(123,  // Start offset.
                                                     321,  // End offset.
                                                     std::move(entities));
  std::vector<mojom::TextAnnotationPtr> annotations;
  annotations.emplace_back(std::move(dummy_annotation));
  fake_service_connection.SetOutputAnnotation(annotations);

  ServiceConnection::GetInstance()->LoadTextClassifier(
      text_classifier.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadModelResult result) {
            EXPECT_EQ(result, mojom::LoadModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(text_classifier.is_bound());

  auto request = mojom::TextAnnotationRequest::New();
  bool infer_callback_done = false;
  text_classifier->Annotate(
      std::move(request),
      base::BindOnce(
          [](bool* infer_callback_done,
             std::vector<mojom::TextAnnotationPtr> annotations) {
            *infer_callback_done = true;
            // Check that the canned annotation came back intact. Sizes and
            // the union tag are asserted first so a mismatch fails cleanly
            // instead of indexing out of bounds.
            ASSERT_EQ(annotations.size(), 1u);
            mojom::TextAnnotationPtr& annotation = annotations[0];
            EXPECT_EQ(annotation->start_offset, 123u);
            EXPECT_EQ(annotation->end_offset, 321u);
            ASSERT_EQ(annotation->entities.size(), 1u);
            mojom::TextEntityPtr& entity = annotation->entities[0];
            EXPECT_EQ(entity->name, "dummy");
            EXPECT_EQ(entity->confidence_score, 1.0);
            ASSERT_TRUE(entity->data->is_numeric_value());
            EXPECT_EQ(entity->data->get_numeric_value(), 123456789.);
          },
          &infer_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(infer_callback_done);
}
281
282 // Tests the fake ML service for text classifier suggest selection.
TEST_F(ServiceConnectionTest,
       FakeServiceConnectionForTextClassifierSuggestSelection) {
  mojo::Remote<mojom::TextClassifier> classifier;
  FakeServiceConnectionImpl fake;
  ServiceConnection::UseFakeServiceConnectionForTesting(&fake);

  // Canned span (start 1, end 2) that the fake hands back from
  // SuggestSelection().
  mojom::CodepointSpanPtr canned_span = mojom::CodepointSpan::New();
  canned_span->start_offset = 1;
  canned_span->end_offset = 2;
  fake.SetOutputSelection(canned_span);

  bool loaded = false;
  ServiceConnection::GetInstance()->LoadTextClassifier(
      classifier.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* done, mojom::LoadModelResult load_result) {
            EXPECT_EQ(load_result, mojom::LoadModelResult::OK);
            *done = true;
          },
          &loaded));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(loaded);
  ASSERT_TRUE(classifier.is_bound());

  // The request carries a default-constructed user selection.
  auto selection_request = mojom::TextSuggestSelectionRequest::New();
  selection_request->user_selection = mojom::CodepointSpan::New();
  bool suggested = false;
  classifier->SuggestSelection(
      std::move(selection_request),
      base::BindOnce(
          [](bool* done, mojom::CodepointSpanPtr span) {
            *done = true;
            // The suggestion must match the canned span configured above.
            EXPECT_EQ(span->start_offset, 1u);
            EXPECT_EQ(span->end_offset, 2u);
          },
          &suggested));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(suggested);
}
324
325 // Tests the fake ML service for text classifier language identification.
TEST_F(ServiceConnectionTest,
       FakeServiceConnectionForTextClassifierFindLanguages) {
  mojo::Remote<mojom::TextClassifier> classifier;
  FakeServiceConnectionImpl fake;
  ServiceConnection::UseFakeServiceConnectionForTesting(&fake);

  // Canned language predictions the fake returns from FindLanguages().
  std::vector<mojom::TextLanguagePtr> canned_languages;
  canned_languages.push_back(mojom::TextLanguage::New("en", 0.9));
  canned_languages.push_back(mojom::TextLanguage::New("fr", 0.1));
  fake.SetOutputLanguages(canned_languages);

  bool loaded = false;
  ServiceConnection::GetInstance()->LoadTextClassifier(
      classifier.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* done, mojom::LoadModelResult load_result) {
            EXPECT_EQ(load_result, mojom::LoadModelResult::OK);
            *done = true;
          },
          &loaded));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(loaded);
  ASSERT_TRUE(classifier.is_bound());

  const std::string text = "dummy input text";
  bool languages_found = false;
  classifier->FindLanguages(
      text, base::BindOnce(
                [](bool* done, std::vector<mojom::TextLanguagePtr> found) {
                  *done = true;
                  // Predictions must come back in the configured order.
                  ASSERT_EQ(found.size(), 2ul);
                  EXPECT_EQ(found[0]->locale, "en");
                  EXPECT_EQ(found[0]->confidence, 0.9f);
                  EXPECT_EQ(found[1]->locale, "fr");
                  EXPECT_EQ(found[1]->confidence, 0.1f);
                },
                &languages_found));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(languages_found);
}
369
370 // Tests the fake ML service for handwriting.
TEST_F(ServiceConnectionTest, FakeHandWritingRecognizer) {
  mojo::Remote<mojom::HandwritingRecognizer> recognizer;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  ServiceConnection::GetInstance()->LoadHandwritingModel(
      mojom::HandwritingRecognizerSpec::New("en"),
      recognizer.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadHandwritingModelResult result) {
            EXPECT_EQ(result, mojom::LoadHandwritingModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(recognizer.is_bound());

  // Construct the fake output: one "cat" candidate with score 0.5.
  mojom::HandwritingRecognizerResultPtr result =
      mojom::HandwritingRecognizerResult::New();
  result->status = mojom::HandwritingRecognizerResult::Status::OK;
  mojom::HandwritingRecognizerCandidatePtr candidate =
      mojom::HandwritingRecognizerCandidate::New();
  candidate->text = "cat";
  candidate->score = 0.5f;
  result->candidates.emplace_back(std::move(candidate));
  fake_service_connection.SetOutputHandwritingRecognizerResult(result);

  auto query = mojom::HandwritingRecognitionQuery::New();
  bool infer_callback_done = false;
  // Recognize() replies exactly once, so bind a one-shot callback like every
  // other call in this file (base::Bind is deprecated).
  recognizer->Recognize(
      std::move(query),
      base::BindOnce(
          [](bool* infer_callback_done,
             mojom::HandwritingRecognizerResultPtr result) {
            *infer_callback_done = true;
            // Check that the recognition result matches the fake output.
            ASSERT_EQ(result->status,
                      mojom::HandwritingRecognizerResult::Status::OK);
            ASSERT_EQ(result->candidates.size(), 1u);
            EXPECT_EQ(result->candidates[0]->text, "cat");
            EXPECT_EQ(result->candidates[0]->score, 0.5f);
          },
          &infer_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(infer_callback_done);
}
420
// Tests the fake ML service for handwriting via the deprecated
// LoadHandwritingModelWithSpec API.
TEST_F(ServiceConnectionTest, FakeHandWritingRecognizerWithSpec) {
  mojo::Remote<mojom::HandwritingRecognizer> recognizer;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  // Load through the deprecated API, which reports mojom::LoadModelResult.
  ServiceConnection::GetInstance()->LoadHandwritingModelWithSpec(
      mojom::HandwritingRecognizerSpec::New("en"),
      recognizer.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadModelResult result) {
            EXPECT_EQ(result, mojom::LoadModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(recognizer.is_bound());

  // Construct the fake output: one "cat" candidate with score 0.5.
  mojom::HandwritingRecognizerResultPtr result =
      mojom::HandwritingRecognizerResult::New();
  result->status = mojom::HandwritingRecognizerResult::Status::OK;
  mojom::HandwritingRecognizerCandidatePtr candidate =
      mojom::HandwritingRecognizerCandidate::New();
  candidate->text = "cat";
  candidate->score = 0.5f;
  result->candidates.emplace_back(std::move(candidate));
  fake_service_connection.SetOutputHandwritingRecognizerResult(result);

  auto query = mojom::HandwritingRecognitionQuery::New();
  bool infer_callback_done = false;
  recognizer->Recognize(
      std::move(query),
      base::BindOnce(
          [](bool* infer_callback_done,
             mojom::HandwritingRecognizerResultPtr result) {
            *infer_callback_done = true;
            // Check that the recognition result matches the fake output. The
            // size check makes a missing candidate a test failure rather
            // than an out-of-range access.
            ASSERT_EQ(result->status,
                      mojom::HandwritingRecognizerResult::Status::OK);
            ASSERT_EQ(result->candidates.size(), 1u);
            EXPECT_EQ(result->candidates[0]->text, "cat");
            EXPECT_EQ(result->candidates[0]->score, 0.5f);
          },
          &infer_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(infer_callback_done);
}
472
// Tests the fake ML service for grammar checking.
TEST_F(ServiceConnectionTest, FakeGrammarChecker) {
  mojo::Remote<mojom::GrammarChecker> checker;
  bool callback_done = false;
  FakeServiceConnectionImpl fake_service_connection;
  ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);

  ServiceConnection::GetInstance()->LoadGrammarChecker(
      checker.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* callback_done, mojom::LoadModelResult result) {
            EXPECT_EQ(result, mojom::LoadModelResult::OK);
            *callback_done = true;
          },
          &callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
  ASSERT_TRUE(checker.is_bound());

  // Construct the fake output: one "cat" candidate with score 0.5.
  mojom::GrammarCheckerResultPtr result = mojom::GrammarCheckerResult::New();
  result->status = mojom::GrammarCheckerResult::Status::OK;
  mojom::GrammarCheckerCandidatePtr candidate =
      mojom::GrammarCheckerCandidate::New();
  candidate->text = "cat";
  candidate->score = 0.5f;
  result->candidates.emplace_back(std::move(candidate));
  fake_service_connection.SetOutputGrammarCheckerResult(result);

  auto query = mojom::GrammarCheckerQuery::New();
  bool infer_callback_done = false;
  checker->Check(
      std::move(query),
      base::BindOnce(
          [](bool* infer_callback_done, mojom::GrammarCheckerResultPtr result) {
            *infer_callback_done = true;
            // Check that the grammar-check result matches the fake output.
            // The size check makes a missing candidate a test failure rather
            // than an out-of-range access.
            ASSERT_EQ(result->status, mojom::GrammarCheckerResult::Status::OK);
            ASSERT_EQ(result->candidates.size(), 1u);
            EXPECT_EQ(result->candidates[0]->text, "cat");
            EXPECT_EQ(result->candidates[0]->score, 0.5f);
          },
          &infer_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(infer_callback_done);
}
518
519 } // namespace
520 } // namespace machine_learning
521 } // namespace chromeos
522