1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/chromeos/power/ml/smart_dim/download_worker.h"
6
7 #include "base/bind.h"
8 #include "base/task/task_traits.h"
9 #include "base/threading/sequenced_task_runner_handle.h"
10 #include "chrome/browser/chromeos/power/ml/smart_dim/metrics.h"
11 #include "chrome/browser/chromeos/power/ml/smart_dim/ml_agent_util.h"
12 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
13 #include "components/assist_ranker/proto/example_preprocessor.pb.h"
14 #include "content/public/browser/browser_task_traits.h"
15 #include "content/public/browser/browser_thread.h"
16 #include "ui/base/resource/resource_bundle.h"
17
18 namespace chromeos {
19 namespace power {
20 namespace ml {
21
22 namespace {
23 using chromeos::machine_learning::mojom::FlatBufferModelSpec;
24 } // namespace
25
DownloadWorker()26 DownloadWorker::DownloadWorker() : SmartDimWorker(), metrics_model_name_("") {}
27
28 DownloadWorker::~DownloadWorker() = default;
29
30 const assist_ranker::ExamplePreprocessorConfig*
GetPreprocessorConfig()31 DownloadWorker::GetPreprocessorConfig() {
32 return preprocessor_config_.get();
33 }
34
35 const mojo::Remote<chromeos::machine_learning::mojom::GraphExecutor>&
GetExecutor()36 DownloadWorker::GetExecutor() {
37 return executor_;
38 }
39
LoadModelCallback(chromeos::machine_learning::mojom::LoadModelResult result)40 void DownloadWorker::LoadModelCallback(
41 chromeos::machine_learning::mojom::LoadModelResult result) {
42 if (result != chromeos::machine_learning::mojom::LoadModelResult::OK) {
43 LogLoadComponentEvent(LoadComponentEvent::kLoadModelError);
44 DVLOG(1) << "Failed to load Smart Dim flatbuffer model.";
45 }
46 }
47
CreateGraphExecutorCallback(chromeos::machine_learning::mojom::CreateGraphExecutorResult result)48 void DownloadWorker::CreateGraphExecutorCallback(
49 chromeos::machine_learning::mojom::CreateGraphExecutorResult result) {
50 if (result !=
51 chromeos::machine_learning::mojom::CreateGraphExecutorResult::OK) {
52 LogLoadComponentEvent(LoadComponentEvent::kCreateGraphExecutorError);
53 DVLOG(1) << "Failed to create a Smart Dim graph executor.";
54 } else {
55 LogLoadComponentEvent(LoadComponentEvent::kSuccess);
56 }
57 }
58
IsReady()59 bool DownloadWorker::IsReady() {
60 return preprocessor_config_ && model_ && executor_ &&
61 expected_feature_size_ > 0 && metrics_model_name_ != "";
62 }
63
InitializeFromComponent(const ComponentFileContents & contents)64 void DownloadWorker::InitializeFromComponent(
65 const ComponentFileContents& contents) {
66 DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
67
68 std::string metadata_json, preprocessor_proto, model_flatbuffer;
69 std::tie(metadata_json, preprocessor_proto, model_flatbuffer) = contents;
70
71 preprocessor_config_ =
72 std::make_unique<assist_ranker::ExamplePreprocessorConfig>();
73 if (!preprocessor_config_->ParseFromString(preprocessor_proto)) {
74 LogLoadComponentEvent(LoadComponentEvent::kLoadPreprocessorError);
75 DVLOG(1) << "Failed to load preprocessor_config.";
76 preprocessor_config_.reset();
77 return;
78 }
79
80 // Meta data contains necessary info to construct FlatBufferModelSpec, and
81 // other optional info.
82 data_decoder::DataDecoder::ParseJsonIsolated(
83 std::move(metadata_json),
84 base::BindOnce(&DownloadWorker::OnJsonParsed, base::Unretained(this),
85 std::move(model_flatbuffer)));
86 }
87
OnJsonParsed(const std::string & model_flatbuffer,const data_decoder::DataDecoder::ValueOrError result)88 void DownloadWorker::OnJsonParsed(
89 const std::string& model_flatbuffer,
90 const data_decoder::DataDecoder::ValueOrError result) {
91 DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
92 if (!result.value || !result.value->is_dict() ||
93 !ParseMetaInfoFromJsonObject(result.value.value(), &metrics_model_name_,
94 &dim_threshold_, &expected_feature_size_,
95 &inputs_, &outputs_)) {
96 LogLoadComponentEvent(LoadComponentEvent::kLoadMetadataError);
97 DVLOG(1) << "Failed to parse meta info from metadata_json.";
98 return;
99 }
100 content::GetUIThreadTaskRunner({base::TaskPriority::BEST_EFFORT})
101 ->PostTask(
102 FROM_HERE,
103 base::BindOnce(&DownloadWorker::LoadModelAndCreateGraphExecutor,
104 base::Unretained(this), std::move(model_flatbuffer)));
105 }
106
LoadModelAndCreateGraphExecutor(const std::string & model_flatbuffer)107 void DownloadWorker::LoadModelAndCreateGraphExecutor(
108 const std::string& model_flatbuffer) {
109 DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
110 DCHECK(!model_.is_bound() && !executor_.is_bound());
111
112 chromeos::machine_learning::ServiceConnection::GetInstance()
113 ->LoadFlatBufferModel(
114 FlatBufferModelSpec::New(std::move(model_flatbuffer), inputs_,
115 outputs_, metrics_model_name_),
116 model_.BindNewPipeAndPassReceiver(),
117 base::BindOnce(&DownloadWorker::LoadModelCallback,
118 base::Unretained(this)));
119 model_->CreateGraphExecutor(
120 executor_.BindNewPipeAndPassReceiver(),
121 base::BindOnce(&DownloadWorker::CreateGraphExecutorCallback,
122 base::Unretained(this)));
123 executor_.set_disconnect_handler(base::BindOnce(
124 &DownloadWorker::OnConnectionError, base::Unretained(this)));
125 }
126
127 } // namespace ml
128 } // namespace power
129 } // namespace chromeos
130