1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
6 
7 #include "base/bind.h"
8 #include "base/macros.h"
9 #include "base/no_destructor.h"
10 #include "base/sequence_checker.h"
11 #include "chromeos/dbus/machine_learning/machine_learning_client.h"
12 #include "chromeos/services/machine_learning/public/mojom/handwriting_recognizer.mojom.h"
13 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
14 #include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
15 #include "mojo/public/cpp/bindings/remote.h"
16 #include "mojo/public/cpp/platform/platform_channel.h"
17 #include "mojo/public/cpp/system/invitation.h"
18 #include "third_party/cros_system_api/dbus/service_constants.h"
19 
20 namespace chromeos {
21 namespace machine_learning {
22 
23 namespace {
24 
25 // Real Impl of ServiceConnection
26 class ServiceConnectionImpl : public ServiceConnection {
27  public:
28   ServiceConnectionImpl();
29   ~ServiceConnectionImpl() override = default;
30 
31   void LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,
32                         mojo::PendingReceiver<mojom::Model> receiver,
33                         mojom::MachineLearningService::LoadBuiltinModelCallback
34                             result_callback) override;
35 
36   void LoadFlatBufferModel(
37       mojom::FlatBufferModelSpecPtr spec,
38       mojo::PendingReceiver<mojom::Model> receiver,
39       mojom::MachineLearningService::LoadFlatBufferModelCallback
40           result_callback) override;
41 
42   void LoadTextClassifier(
43       mojo::PendingReceiver<mojom::TextClassifier> receiver,
44       mojom::MachineLearningService::LoadTextClassifierCallback
45           result_callback) override;
46 
47   void LoadHandwritingModel(
48       mojom::HandwritingRecognizerSpecPtr spec,
49       mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
50       mojom::MachineLearningService::LoadHandwritingModelCallback
51           result_callback) override;
52 
53   void LoadHandwritingModelWithSpec(
54       mojom::HandwritingRecognizerSpecPtr spec,
55       mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
56       mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
57           result_callback) override;
58 
59   void LoadGrammarChecker(
60       mojo::PendingReceiver<mojom::GrammarChecker> receiver,
61       mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback)
62       override;
63 
64  private:
65   // Binds the top level interface |machine_learning_service_| to an
66   // implementation in the ML Service daemon, if it is not already bound. The
67   // binding is accomplished via D-Bus bootstrap.
68   void BindMachineLearningServiceIfNeeded();
69 
70   // Mojo disconnect handler. Resets |machine_learning_service_|, which
71   // will be reconnected upon next use.
72   void OnMojoDisconnect();
73 
74   // Response callback for MlClient::BootstrapMojoConnection.
75   void OnBootstrapMojoConnectionResponse(bool success);
76 
77   mojo::Remote<mojom::MachineLearningService> machine_learning_service_;
78 
79   SEQUENCE_CHECKER(sequence_checker_);
80 
81   DISALLOW_COPY_AND_ASSIGN(ServiceConnectionImpl);
82 };
83 
LoadBuiltinModel(mojom::BuiltinModelSpecPtr spec,mojo::PendingReceiver<mojom::Model> receiver,mojom::MachineLearningService::LoadBuiltinModelCallback result_callback)84 void ServiceConnectionImpl::LoadBuiltinModel(
85     mojom::BuiltinModelSpecPtr spec,
86     mojo::PendingReceiver<mojom::Model> receiver,
87     mojom::MachineLearningService::LoadBuiltinModelCallback result_callback) {
88   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
89   BindMachineLearningServiceIfNeeded();
90   machine_learning_service_->LoadBuiltinModel(
91       std::move(spec), std::move(receiver), std::move(result_callback));
92 }
93 
LoadFlatBufferModel(mojom::FlatBufferModelSpecPtr spec,mojo::PendingReceiver<mojom::Model> receiver,mojom::MachineLearningService::LoadFlatBufferModelCallback result_callback)94 void ServiceConnectionImpl::LoadFlatBufferModel(
95     mojom::FlatBufferModelSpecPtr spec,
96     mojo::PendingReceiver<mojom::Model> receiver,
97     mojom::MachineLearningService::LoadFlatBufferModelCallback
98         result_callback) {
99   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
100   BindMachineLearningServiceIfNeeded();
101   machine_learning_service_->LoadFlatBufferModel(
102       std::move(spec), std::move(receiver), std::move(result_callback));
103 }
104 
LoadTextClassifier(mojo::PendingReceiver<mojom::TextClassifier> receiver,mojom::MachineLearningService::LoadTextClassifierCallback result_callback)105 void ServiceConnectionImpl::LoadTextClassifier(
106     mojo::PendingReceiver<mojom::TextClassifier> receiver,
107     mojom::MachineLearningService::LoadTextClassifierCallback result_callback) {
108   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
109   BindMachineLearningServiceIfNeeded();
110   machine_learning_service_->LoadTextClassifier(std::move(receiver),
111                                                 std::move(result_callback));
112 }
113 
LoadHandwritingModel(mojom::HandwritingRecognizerSpecPtr spec,mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,mojom::MachineLearningService::LoadHandwritingModelCallback result_callback)114 void ServiceConnectionImpl::LoadHandwritingModel(
115     mojom::HandwritingRecognizerSpecPtr spec,
116     mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
117     mojom::MachineLearningService::LoadHandwritingModelCallback
118         result_callback) {
119   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
120   BindMachineLearningServiceIfNeeded();
121   machine_learning_service_->LoadHandwritingModel(
122       std::move(spec), std::move(receiver), std::move(result_callback));
123 }
124 
LoadHandwritingModelWithSpec(mojom::HandwritingRecognizerSpecPtr spec,mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback result_callback)125 void ServiceConnectionImpl::LoadHandwritingModelWithSpec(
126     mojom::HandwritingRecognizerSpecPtr spec,
127     mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
128     mojom::MachineLearningService::LoadHandwritingModelWithSpecCallback
129         result_callback) {
130   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
131   BindMachineLearningServiceIfNeeded();
132   machine_learning_service_->LoadHandwritingModelWithSpec(
133       std::move(spec), std::move(receiver), std::move(result_callback));
134 }
135 
LoadGrammarChecker(mojo::PendingReceiver<mojom::GrammarChecker> receiver,mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback)136 void ServiceConnectionImpl::LoadGrammarChecker(
137     mojo::PendingReceiver<mojom::GrammarChecker> receiver,
138     mojom::MachineLearningService::LoadGrammarCheckerCallback result_callback) {
139   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
140   BindMachineLearningServiceIfNeeded();
141   machine_learning_service_->LoadGrammarChecker(std::move(receiver),
142                                                 std::move(result_callback));
143 }
144 
BindMachineLearningServiceIfNeeded()145 void ServiceConnectionImpl::BindMachineLearningServiceIfNeeded() {
146   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
147   if (machine_learning_service_) {
148     return;
149   }
150 
151   mojo::PlatformChannel platform_channel;
152 
153   // Prepare a Mojo invitation to send through |platform_channel|.
154   mojo::OutgoingInvitation invitation;
155   // Include an initial Mojo pipe in the invitation.
156   mojo::ScopedMessagePipeHandle pipe =
157       invitation.AttachMessagePipe(ml::kBootstrapMojoConnectionChannelToken);
158   mojo::OutgoingInvitation::Send(std::move(invitation),
159                                  base::kNullProcessHandle,
160                                  platform_channel.TakeLocalEndpoint());
161 
162   // Bind our end of |pipe| to our mojo::Remote<MachineLearningService>. The
163   // daemon should bind its end to a MachineLearningService implementation.
164   machine_learning_service_.Bind(
165       mojo::PendingRemote<machine_learning::mojom::MachineLearningService>(
166           std::move(pipe), 0u /* version */));
167   machine_learning_service_.set_disconnect_handler(base::BindOnce(
168       &ServiceConnectionImpl::OnMojoDisconnect, base::Unretained(this)));
169 
170   // Send the file descriptor for the other end of |platform_channel| to the
171   // ML service daemon over D-Bus.
172   MachineLearningClient::Get()->BootstrapMojoConnection(
173       platform_channel.TakeRemoteEndpoint().TakePlatformHandle().TakeFD(),
174       base::BindOnce(&ServiceConnectionImpl::OnBootstrapMojoConnectionResponse,
175                      base::Unretained(this)));
176 }
177 
ServiceConnectionImpl()178 ServiceConnectionImpl::ServiceConnectionImpl() {
179   DETACH_FROM_SEQUENCE(sequence_checker_);
180 }
181 
OnMojoDisconnect()182 void ServiceConnectionImpl::OnMojoDisconnect() {
183   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
184   // Connection errors are not expected so log a warning.
185   LOG(WARNING) << "ML Service Mojo connection closed";
186   machine_learning_service_.reset();
187 }
188 
OnBootstrapMojoConnectionResponse(const bool success)189 void ServiceConnectionImpl::OnBootstrapMojoConnectionResponse(
190     const bool success) {
191   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
192   if (!success) {
193     LOG(WARNING) << "BootstrapMojoConnection D-Bus call failed";
194     machine_learning_service_.reset();
195   }
196 }
197 
198 static ServiceConnection* g_fake_service_connection_for_testing = nullptr;
199 
200 }  // namespace
201 
GetInstance()202 ServiceConnection* ServiceConnection::GetInstance() {
203   if (g_fake_service_connection_for_testing) {
204     return g_fake_service_connection_for_testing;
205   }
206   static base::NoDestructor<ServiceConnectionImpl> service_connection;
207   return service_connection.get();
208 }
209 
UseFakeServiceConnectionForTesting(ServiceConnection * const fake_service_connection)210 void ServiceConnection::UseFakeServiceConnectionForTesting(
211     ServiceConnection* const fake_service_connection) {
212   g_fake_service_connection_for_testing = fake_service_connection;
213 }
214 
215 }  // namespace machine_learning
216 }  // namespace chromeos
217