1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/chromeos/power/ml/smart_dim/download_worker.h"
6 
7 #include "base/bind.h"
8 #include "base/task/task_traits.h"
9 #include "base/threading/sequenced_task_runner_handle.h"
10 #include "chrome/browser/chromeos/power/ml/smart_dim/metrics.h"
11 #include "chrome/browser/chromeos/power/ml/smart_dim/ml_agent_util.h"
12 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
13 #include "components/assist_ranker/proto/example_preprocessor.pb.h"
14 #include "content/public/browser/browser_task_traits.h"
15 #include "content/public/browser/browser_thread.h"
16 #include "ui/base/resource/resource_bundle.h"
17 
18 namespace chromeos {
19 namespace power {
20 namespace ml {
21 
22 namespace {
23 using chromeos::machine_learning::mojom::FlatBufferModelSpec;
24 }  // namespace
25 
DownloadWorker()26 DownloadWorker::DownloadWorker() : SmartDimWorker(), metrics_model_name_("") {}
27 
28 DownloadWorker::~DownloadWorker() = default;
29 
30 const assist_ranker::ExamplePreprocessorConfig*
GetPreprocessorConfig()31 DownloadWorker::GetPreprocessorConfig() {
32   return preprocessor_config_.get();
33 }
34 
35 const mojo::Remote<chromeos::machine_learning::mojom::GraphExecutor>&
GetExecutor()36 DownloadWorker::GetExecutor() {
37   return executor_;
38 }
39 
LoadModelCallback(chromeos::machine_learning::mojom::LoadModelResult result)40 void DownloadWorker::LoadModelCallback(
41     chromeos::machine_learning::mojom::LoadModelResult result) {
42   if (result != chromeos::machine_learning::mojom::LoadModelResult::OK) {
43     LogLoadComponentEvent(LoadComponentEvent::kLoadModelError);
44     DVLOG(1) << "Failed to load Smart Dim flatbuffer model.";
45   }
46 }
47 
CreateGraphExecutorCallback(chromeos::machine_learning::mojom::CreateGraphExecutorResult result)48 void DownloadWorker::CreateGraphExecutorCallback(
49     chromeos::machine_learning::mojom::CreateGraphExecutorResult result) {
50   if (result !=
51       chromeos::machine_learning::mojom::CreateGraphExecutorResult::OK) {
52     LogLoadComponentEvent(LoadComponentEvent::kCreateGraphExecutorError);
53     DVLOG(1) << "Failed to create a Smart Dim graph executor.";
54   } else {
55     LogLoadComponentEvent(LoadComponentEvent::kSuccess);
56   }
57 }
58 
IsReady()59 bool DownloadWorker::IsReady() {
60   return preprocessor_config_ && model_ && executor_ &&
61          expected_feature_size_ > 0 && metrics_model_name_ != "";
62 }
63 
InitializeFromComponent(const ComponentFileContents & contents)64 void DownloadWorker::InitializeFromComponent(
65     const ComponentFileContents& contents) {
66   DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
67 
68   std::string metadata_json, preprocessor_proto, model_flatbuffer;
69   std::tie(metadata_json, preprocessor_proto, model_flatbuffer) = contents;
70 
71   preprocessor_config_ =
72       std::make_unique<assist_ranker::ExamplePreprocessorConfig>();
73   if (!preprocessor_config_->ParseFromString(preprocessor_proto)) {
74     LogLoadComponentEvent(LoadComponentEvent::kLoadPreprocessorError);
75     DVLOG(1) << "Failed to load preprocessor_config.";
76     preprocessor_config_.reset();
77     return;
78   }
79 
80   // Meta data contains necessary info to construct FlatBufferModelSpec, and
81   // other optional info.
82   data_decoder::DataDecoder::ParseJsonIsolated(
83       std::move(metadata_json),
84       base::BindOnce(&DownloadWorker::OnJsonParsed, base::Unretained(this),
85                      std::move(model_flatbuffer)));
86 }
87 
OnJsonParsed(const std::string & model_flatbuffer,const data_decoder::DataDecoder::ValueOrError result)88 void DownloadWorker::OnJsonParsed(
89     const std::string& model_flatbuffer,
90     const data_decoder::DataDecoder::ValueOrError result) {
91   DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
92   if (!result.value || !result.value->is_dict() ||
93       !ParseMetaInfoFromJsonObject(result.value.value(), &metrics_model_name_,
94                                    &dim_threshold_, &expected_feature_size_,
95                                    &inputs_, &outputs_)) {
96     LogLoadComponentEvent(LoadComponentEvent::kLoadMetadataError);
97     DVLOG(1) << "Failed to parse meta info from metadata_json.";
98     return;
99   }
100   content::GetUIThreadTaskRunner({base::TaskPriority::BEST_EFFORT})
101       ->PostTask(
102           FROM_HERE,
103           base::BindOnce(&DownloadWorker::LoadModelAndCreateGraphExecutor,
104                          base::Unretained(this), std::move(model_flatbuffer)));
105 }
106 
LoadModelAndCreateGraphExecutor(const std::string & model_flatbuffer)107 void DownloadWorker::LoadModelAndCreateGraphExecutor(
108     const std::string& model_flatbuffer) {
109   DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
110   DCHECK(!model_.is_bound() && !executor_.is_bound());
111 
112   chromeos::machine_learning::ServiceConnection::GetInstance()
113       ->LoadFlatBufferModel(
114           FlatBufferModelSpec::New(std::move(model_flatbuffer), inputs_,
115                                    outputs_, metrics_model_name_),
116           model_.BindNewPipeAndPassReceiver(),
117           base::BindOnce(&DownloadWorker::LoadModelCallback,
118                          base::Unretained(this)));
119   model_->CreateGraphExecutor(
120       executor_.BindNewPipeAndPassReceiver(),
121       base::BindOnce(&DownloadWorker::CreateGraphExecutorCallback,
122                      base::Unretained(this)));
123   executor_.set_disconnect_handler(base::BindOnce(
124       &DownloadWorker::OnConnectionError, base::Unretained(this)));
125 }
126 
127 }  // namespace ml
128 }  // namespace power
129 }  // namespace chromeos
130