1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
6
7 #include "base/bind.h"
8 #include "base/macros.h"
9 #include "base/no_destructor.h"
10 #include "base/sequence_checker.h"
11 #include "chromeos/dbus/machine_learning/machine_learning_client.h"
12 #include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
13 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
14 #include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
15 #include "mojo/public/cpp/bindings/remote.h"
16 #include "mojo/public/cpp/platform/platform_channel.h"
17 #include "mojo/public/cpp/system/invitation.h"
18 #include "third_party/cros_system_api/dbus/service_constants.h"
19
20 namespace chromeos {
21 namespace machine_learning {
22
23 namespace {
24
25 // Real Impl of ServiceConnection
26 class ServiceConnectionImpl : public ServiceConnection {
27 public:
28 ServiceConnectionImpl();
29 ~ServiceConnectionImpl() override = default;
30
31 void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,
32 mojo::PendingReceiver<mojom::Model> receiver,
33 mojom::MachineLearningService::LoadBuiltinModelCallback
34 result_callback) override;
35
36 void LoadFlatBufferModel(
37 mojom::FlatBufferModelSpecPtr spec,
38 mojo::PendingReceiver<mojom::Model> receiver,
39 mojom::MachineLearningService::LoadFlatBufferModelCallback
40 result_callback) override;
41
42 void LoadTextClassifier(
43 mojo::PendingReceiver<mojom::TextClassifier> receiver,
44 mojom::MachineLearningService::LoadTextClassifierCallback
45 result_callback) override;
46
47 void LoadHandwritingModel(
48 mojom::HandwritingRecognizerSpecPtr spec,
49 mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
50 mojom::MachineLearningService::LoadHandwritingModelCallback
51 result_callback) override;
52
53 void LoadHandwritingModelWithSpec(
54 mojom::HandwritingRecognizerSpecPtr spec,
55 mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
56 mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
57 result_callback) override;
58
59 void LoadGrammarChecker(
60 mojo::PendingReceiver<mojom::GrammarChecker> receiver,
61 mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback)
62 override;
63
64 private:
65 // Binds the top level interface |machine_learning_service_| to an
66 // implementation in the ML Service daemon, if it is not already bound. The
67 // binding is accomplished via D-Bus bootstrap.
68 void BindMachineLearningServiceIfNeeded();
69
70 // Mojo disconnect handler. Resets |machine_learning_service_|, which
71 // will be reconnected upon next use.
72 void OnMojoDisconnect();
73
74 // Response callback for MlClient::BootstrapMojoConnection.
75 void OnBootstrapMojoConnectionResponse(bool success);
76
77 mojo::Remote<mojom::MachineLearningService> machine_learning_service_;
78
79 SEQUENCE_CHECKER(sequence_checker_);
80
81 DISALLOW_COPY_AND_ASSIGN(ServiceConnectionImpl);
82 };
83
LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,mojo::PendingReceiver<mojom::Model> receiver,mojom::MachineLearningService::LoadBuiltinModelCallback result_callback)84 void ServiceConnectionImpl::LoadBuiltinModel(
85 mojom::BuiltinModelSpecPtr spec,
86 mojo::PendingReceiver<mojom::Model> receiver,
87 mojom::MachineLearningService::LoadBuiltinModelCallback result_callback) {
88 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
89 BindMachineLearningServiceIfNeeded();
90 machine_learning_service_->LoadBuiltinModel(
91 std::move(spec), std::move(receiver), std::move(result_callback));
92 }
93
LoadFlatBufferModel(mojom::FlatBufferModelSpecPtr spec,mojo::PendingReceiver<mojom::Model> receiver,mojom::MachineLearningService::LoadFlatBufferModelCallback result_callback)94 void ServiceConnectionImpl::LoadFlatBufferModel(
95 mojom::FlatBufferModelSpecPtr spec,
96 mojo::PendingReceiver<mojom::Model> receiver,
97 mojom::MachineLearningService::LoadFlatBufferModelCallback
98 result_callback) {
99 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
100 BindMachineLearningServiceIfNeeded();
101 machine_learning_service_->LoadFlatBufferModel(
102 std::move(spec), std::move(receiver), std::move(result_callback));
103 }
104
LoadTextClassifier(mojo::PendingReceiver<mojom::TextClassifier> receiver,mojom::MachineLearningService::LoadTextClassifierCallback result_callback)105 void ServiceConnectionImpl::LoadTextClassifier(
106 mojo::PendingReceiver<mojom::TextClassifier> receiver,
107 mojom::MachineLearningService::LoadTextClassifierCallback result_callback) {
108 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
109 BindMachineLearningServiceIfNeeded();
110 machine_learning_service_->LoadTextClassifier(std::move(receiver),
111 std::move(result_callback));
112 }
113
LoadHandwritingModel(mojom::HandwritingRecognizerSpecPtr spec,mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,mojom::MachineLearningService::LoadHandwritingModelCallback result_callback)114 void ServiceConnectionImpl::LoadHandwritingModel(
115 mojom::HandwritingRecognizerSpecPtr spec,
116 mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
117 mojom::MachineLearningService::LoadHandwritingModelCallback
118 result_callback) {
119 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
120 BindMachineLearningServiceIfNeeded();
121 machine_learning_service_->LoadHandwritingModel(
122 std::move(spec), std::move(receiver), std::move(result_callback));
123 }
124
LoadHandwritingModelWithSpec(mojom::HandwritingRecognizerSpecPtr spec,mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback result_callback)125 void ServiceConnectionImpl::LoadHandwritingModelWithSpec(
126 mojom::HandwritingRecognizerSpecPtr spec,
127 mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
128 mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
129 result_callback) {
130 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
131 BindMachineLearningServiceIfNeeded();
132 machine_learning_service_->LoadHandwritingModelWithSpec(
133 std::move(spec), std::move(receiver), std::move(result_callback));
134 }
135
LoadGrammarChecker(mojo::PendingReceiver<mojom::GrammarChecker> receiver,mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback)136 void ServiceConnectionImpl::LoadGrammarChecker(
137 mojo::PendingReceiver<mojom::GrammarChecker> receiver,
138 mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback) {
139 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
140 BindMachineLearningServiceIfNeeded();
141 machine_learning_service_->LoadGrammarChecker(std::move(receiver),
142 std::move(result_callback));
143 }
144
BindMachineLearningServiceIfNeeded()145 void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() {
146 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
147 if (machine_learning_service_) {
148 return;
149 }
150
151 mojo::PlatformChannel platform_channel;
152
153 // Prepare a Mojo invitation to send through |platform_channel|.
154 mojo::OutgoingInvitation invitation;
155 // Include an initial Mojo pipe in the invitation.
156 mojo::ScopedMessagePipeHandle pipe =
157 invitation.AttachMessagePipe(ml::kBootstrapMojoConnectionChannelToken);
158 mojo::OutgoingInvitation::Send(std::move(invitation),
159 base::kNullProcessHandle,
160 platform_channel.TakeLocalEndpoint());
161
162 // Bind our end of |pipe| to our mojo::Remote<MachineLearningService>. The
163 // daemon should bind its end to a MachineLearningService implementation.
164 machine_learning_service_.Bind(
165 mojo::PendingRemote<machine_learning::mojom::MachineLearningService>(
166 std::move(pipe), 0u /* version */));
167 machine_learning_service_.set_disconnect_handler(base::BindOnce(
168 &ServiceConnectionImpl::OnMojoDisconnect, base::Unretained(this)));
169
170 // Send the file descriptor for the other end of |platform_channel| to the
171 // ML service daemon over D-Bus.
172 MachineLearningClient::Get()->BootstrapMojoConnection(
173 platform_channel.TakeRemoteEndpoint().TakePlatformHandle().TakeFD(),
174 base::BindOnce(&ServiceConnectionImpl::OnBootstrapMojoConnectionResponse,
175 base::Unretained(this)));
176 }
177
ServiceConnectionImpl()178 ServiceConnectionImpl::ServiceConnectionImpl() {
179 DETACH_FROM_SEQUENCE(sequence_checker_);
180 }
181
OnMojoDisconnect()182 void ServiceConnectionImpl::OnMojoDisconnect() {
183 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
184 // Connection errors are not expected so log a warning.
185 LOG(WARNING) << "ML Service Mojo connection closed";
186 machine_learning_service_.reset();
187 }
188
OnBootstrapMojoConnectionResponse(const bool success)189 void ServiceConnectionImpl::OnBootstrapMojoConnectionResponse(
190 const bool success) {
191 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
192 if (!success) {
193 LOG(WARNING) << "BootstrapMojoConnection D-Bus call failed";
194 machine_learning_service_.reset();
195 }
196 }
197
198 static ServiceConnection* g_fake_service_connection_for_testing = nullptr;
199
200 } // namespace
201
GetInstance()202 ServiceConnection* ServiceConnection::GetInstance() {
203 if (g_fake_service_connection_for_testing) {
204 return g_fake_service_connection_for_testing;
205 }
206 static base::NoDestructor<ServiceConnectionImpl> service_connection;
207 return service_connection.get();
208 }
209
UseFakeServiceConnectionForTesting(ServiceConnection * const fake_service_connection)210 void ServiceConnection::UseFakeServiceConnectionForTesting(
211 ServiceConnection* const fake_service_connection) {
212 g_fake_service_connection_for_testing = fake_service_connection;
213 }
214
215 } // namespace machine_learning
216 } // namespace chromeos
217