1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/ml_app_rank_provider.h"
6
7 #include <utility>
8
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/memory/ref_counted_memory.h"
13 #include "base/strings/stringprintf.h"
14 #include "base/task/post_task.h"
15 #include "base/task/task_traits.h"
16 #include "base/task/thread_pool.h"
17 #include "base/time/time.h"
18 #include "chrome/browser/chromeos/power/ml/user_activity_ukm_logger_helpers.h"
19 #include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger_helper.h"
20 #include "chrome/grit/browser_resources.h"
21 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
22 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
23 #include "components/assist_ranker/example_preprocessing.h"
24 #include "components/crx_file/id_util.h"
25 #include "content/public/browser/browser_task_traits.h"
26 #include "content/public/browser/browser_thread.h"
27 #include "ui/base/resource/resource_bundle.h"
28
29 using ::chromeos::machine_learning::mojom::BuiltinModelId;
30 using ::chromeos::machine_learning::mojom::BuiltinModelSpec;
31 using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
32 using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
33 using ::chromeos::machine_learning::mojom::ExecuteResult;
34 using ::chromeos::machine_learning::mojom::FloatList;
35 using ::chromeos::machine_learning::mojom::Int64List;
36 using ::chromeos::machine_learning::mojom::LoadModelResult;
37 using ::chromeos::machine_learning::mojom::Tensor;
38 using ::chromeos::machine_learning::mojom::TensorPtr;
39 using ::chromeos::machine_learning::mojom::ValueList;
40
41 namespace app_list {
42
43 namespace {
44
LoadModelCallback(LoadModelResult result)45 void LoadModelCallback(LoadModelResult result) {
46 if (result != LoadModelResult::OK) {
47 LOG(ERROR) << "Failed to load Top Cat model.";
48 }
49 }
50
CreateGraphExecutorCallback(CreateGraphExecutorResult result)51 void CreateGraphExecutorCallback(CreateGraphExecutorResult result) {
52 if (result != CreateGraphExecutorResult::OK) {
53 LOG(ERROR) << "Failed to create a Top Cat Graph Executor.";
54 }
55 }
56
57 // Returns: true if preprocessor config loaded, false if it could not be loaded.
LoadExamplePreprocessorConfig(assist_ranker::ExamplePreprocessorConfig * preprocessor_config)58 bool LoadExamplePreprocessorConfig(
59 assist_ranker::ExamplePreprocessorConfig* preprocessor_config) {
60 DCHECK(preprocessor_config);
61
62 const int resource_id = IDR_TOP_CAT_20190722_EXAMPLE_PREPROCESSOR_CONFIG_PB;
63 const scoped_refptr<base::RefCountedMemory> raw_config =
64 ui::ResourceBundle::GetSharedInstance().LoadDataResourceBytes(
65 resource_id);
66 if (!raw_config || !raw_config->front()) {
67 LOG(ERROR) << "Failed to load TopCatModel example preprocessor config.";
68 return false;
69 }
70
71 if (!preprocessor_config->ParseFromArray(raw_config->front(),
72 raw_config->size())) {
73 LOG(ERROR) << "Failed to parse TopCatModel example preprocessor config.";
74 return false;
75 }
76 return true;
77 }
78
79 // Perform the inference given the |features| and |app_id| of an app.
80 // Posts |callback| to |task_runner| to perform the actual inference.
DoInference(const std::string & app_id,const std::vector<float> & features,scoped_refptr<base::SequencedTaskRunner> task_runner,const base::RepeatingCallback<void (base::flat_map<std::string,TensorPtr> inputs,const std::vector<std::string> outputs,const std::string app_id)> callback)81 void DoInference(const std::string& app_id,
82 const std::vector<float>& features,
83 scoped_refptr<base::SequencedTaskRunner> task_runner,
84 const base::RepeatingCallback<
85 void(base::flat_map<std::string, TensorPtr> inputs,
86 const std::vector<std::string> outputs,
87 const std::string app_id)> callback) {
88 // Prepare the input tensor.
89 base::flat_map<std::string, TensorPtr> inputs;
90 auto tensor = Tensor::New();
91 tensor->shape = Int64List::New();
92 tensor->shape->value = std::vector<int64_t>({1, features.size()});
93 tensor->data = ValueList::New();
94 tensor->data->set_float_list(FloatList::New());
95 tensor->data->get_float_list()->value =
96 std::vector<double>(std::begin(features), std::end(features));
97 inputs.emplace(std::string("input"), std::move(tensor));
98
99 const std::vector<std::string> outputs({std::string("output")});
100 DCHECK(task_runner);
101 task_runner->PostTask(FROM_HERE, base::BindOnce(callback, std::move(inputs),
102 std::move(outputs), app_id));
103 }
104
105 // Process the RankerExample to vectorize the feature list for inference.
106 // Returns true on success.
RankerExampleToVectorizedFeatures(const assist_ranker::ExamplePreprocessorConfig & preprocessor_config,assist_ranker::RankerExample * example,std::vector<float> * vectorized_features)107 bool RankerExampleToVectorizedFeatures(
108 const assist_ranker::ExamplePreprocessorConfig& preprocessor_config,
109 assist_ranker::RankerExample* example,
110 std::vector<float>* vectorized_features) {
111 int preprocessor_error = assist_ranker::ExamplePreprocessor::Process(
112 preprocessor_config, example, true);
113 // kNoFeatureIndexFound can occur normally (e.g., when the app URL
114 // isn't known to the model or a rarely seen enum value is used).
115 if (preprocessor_error != assist_ranker::ExamplePreprocessor::kSuccess &&
116 preprocessor_error !=
117 assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
118 // TODO: Log to UMA.
119 return false;
120 }
121
122 const auto& extracted_features =
123 example->features()
124 .at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
125 .float_list()
126 .float_value();
127 vectorized_features->assign(extracted_features.begin(),
128 extracted_features.end());
129 return true;
130 }
131
132 // Does the CPU-intensive part of CreateRankings (preparing the Tensor inputs
133 // from |app_features_map|, intended to be called on a low-priority
134 // background thread. Invokes |callback| on |task_runner| once for each app in
135 // |app_features_map|.
CreateRankingsImpl(base::flat_map<std::string,AppLaunchFeatures> app_features_map,int total_hours,int all_clicks_last_hour,int all_clicks_last_24_hours,scoped_refptr<base::SequencedTaskRunner> task_runner,const base::RepeatingCallback<void (base::flat_map<std::string,TensorPtr> inputs,const std::vector<std::string> outputs,const std::string app_id)> & callback)136 void CreateRankingsImpl(
137 base::flat_map<std::string, AppLaunchFeatures> app_features_map,
138 int total_hours,
139 int all_clicks_last_hour,
140 int all_clicks_last_24_hours,
141 scoped_refptr<base::SequencedTaskRunner> task_runner,
142 const base::RepeatingCallback<
143 void(base::flat_map<std::string, TensorPtr> inputs,
144 const std::vector<std::string> outputs,
145 const std::string app_id)>& callback) {
146 const base::Time now(base::Time::Now());
147 const int hour = HourOfDay(now);
148 const int day = DayOfWeek(now);
149
150 assist_ranker::ExamplePreprocessorConfig preprocessor_config;
151 if (!LoadExamplePreprocessorConfig(&preprocessor_config)) {
152 return;
153 }
154 for (auto& app : app_features_map) {
155 assist_ranker::RankerExample example(
156 CreateRankerExample(app.second,
157 now.ToDeltaSinceWindowsEpoch().InSeconds() -
158 app.second.time_of_last_click_sec(),
159 total_hours, day, hour, all_clicks_last_hour,
160 all_clicks_last_24_hours));
161 std::vector<float> vectorized_features;
162 if (RankerExampleToVectorizedFeatures(preprocessor_config, &example,
163 &vectorized_features)) {
164 DoInference(app.first, vectorized_features, task_runner, callback);
165 }
166 }
167 }
168
169 } // namespace
170
CreateRankerExample(const AppLaunchFeatures & features,int time_since_last_click,int total_hours,int day_of_week,int hour_of_day,int all_clicks_last_hour,int all_clicks_last_24_hours)171 assist_ranker::RankerExample CreateRankerExample(
172 const AppLaunchFeatures& features,
173 int time_since_last_click,
174 int total_hours,
175 int day_of_week,
176 int hour_of_day,
177 int all_clicks_last_hour,
178 int all_clicks_last_24_hours) {
179 assist_ranker::RankerExample example;
180 auto& ranker_example_features = *example.mutable_features();
181
182 ranker_example_features["DayOfWeek"].set_int32_value(day_of_week);
183 ranker_example_features["HourOfDay"].set_int32_value(hour_of_day);
184 ranker_example_features["AllClicksLastHour"].set_int32_value(
185 all_clicks_last_hour);
186 ranker_example_features["AllClicksLast24Hours"].set_int32_value(
187 all_clicks_last_24_hours);
188
189 ranker_example_features["AppType"].set_int32_value(features.app_type());
190 ranker_example_features["ClickRank"].set_int32_value(features.click_rank());
191 ranker_example_features["ClicksLastHour"].set_int32_value(
192 features.clicks_last_hour());
193 ranker_example_features["ClicksLast24Hours"].set_int32_value(
194 features.clicks_last_24_hours());
195 ranker_example_features["LastLaunchedFrom"].set_int32_value(
196 features.last_launched_from());
197 ranker_example_features["HasClick"].set_bool_value(
198 features.has_most_recently_used_index());
199 ranker_example_features["MostRecentlyUsedIndex"].set_int32_value(
200 features.most_recently_used_index());
201 ranker_example_features["TimeSinceLastClick"].set_int32_value(
202 Bucketize(time_since_last_click, kTimeSinceLastClickBuckets));
203 ranker_example_features["TotalClicks"].set_int32_value(
204 features.total_clicks());
205 ranker_example_features["TotalClicksPerHour"].set_float_value(
206 static_cast<float>(features.total_clicks()) / (total_hours + 1));
207 ranker_example_features["TotalHours"].set_int32_value(total_hours);
208
209 // Calculate FourHourClicksN and SixHourClicksN, which sum clicks for four
210 // and six hour periods respectively.
211 int four_hour_count = 0;
212 int six_hour_count = 0;
213 // Apps that have been clicked will have 24 clicks_each_hour values. Apps that
214 // have not been clicked will have no clicks_each_hour values, so can skip
215 // the FourHourClicksN and SixHourClicksN calculations.
216 if (features.clicks_each_hour_size() == 24) {
217 for (int hour = 0; hour < 24; hour++) {
218 int clicks = Bucketize(features.clicks_each_hour(hour), kClickBuckets);
219 ranker_example_features["ClicksEachHour" +
220 base::StringPrintf("%02d", hour)]
221 .set_int32_value(clicks);
222 ranker_example_features["ClicksPerHour" +
223 base::StringPrintf("%02d", hour)]
224 .set_float_value(static_cast<float>(clicks) / (total_hours + 1));
225 four_hour_count += clicks;
226 six_hour_count += clicks;
227 // Divide day into periods of 4 hours each.
228 if (hour % 4 == 3 && four_hour_count != 0) {
229 ranker_example_features["FourHourClicks" +
230 base::StringPrintf("%01d", hour / 4)]
231 .set_int32_value(four_hour_count);
232 four_hour_count = 0;
233 }
234 // Divide day into periods of 6 hours each.
235 if (hour % 6 == 5 && six_hour_count != 0) {
236 ranker_example_features["SixHourClicks" +
237 base::StringPrintf("%01d", hour / 6)]
238 .set_int32_value(six_hour_count);
239 six_hour_count = 0;
240 }
241 }
242 }
243
244 if (features.app_type() == AppLaunchEvent_AppType_CHROME) {
245 ranker_example_features["URL"].set_string_value(
246 kExtensionSchemeWithDelimiter + features.app_id());
247 } else if (features.app_type() == AppLaunchEvent_AppType_PWA) {
248 ranker_example_features["URL"].set_string_value(features.pwa_url());
249 } else if (features.app_type() == AppLaunchEvent_AppType_PLAY) {
250 ranker_example_features["URL"].set_string_value(
251 kAppScheme +
252 crx_file::id_util::GenerateId(features.arc_package_name()));
253 } else {
254 // TODO(crbug.com/1027782): Add DCHECK that this branch is not reached.
255 }
256 return example;
257 }
258
MlAppRankProvider()259 MlAppRankProvider::MlAppRankProvider()
260 : creation_task_runner_(base::SequencedTaskRunnerHandle::Get()),
261 background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
262 {base::TaskPriority::BEST_EFFORT,
263 base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})) {}
264
265 MlAppRankProvider::~MlAppRankProvider() = default;
266
CreateRankings(const base::flat_map<std::string,AppLaunchFeatures> & app_features_map,int total_hours,int all_clicks_last_hour,int all_clicks_last_24_hours)267 void MlAppRankProvider::CreateRankings(
268 const base::flat_map<std::string, AppLaunchFeatures>& app_features_map,
269 int total_hours,
270 int all_clicks_last_hour,
271 int all_clicks_last_24_hours) {
272 DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
273 // TODO(jennyz): Add start-to-end latency metrics for the work on each
274 // sequence.
275 background_task_runner_->PostTask(
276 FROM_HERE,
277 base::BindOnce(&CreateRankingsImpl, app_features_map, total_hours,
278 all_clicks_last_hour, all_clicks_last_24_hours,
279 creation_task_runner_,
280 base::BindRepeating(&MlAppRankProvider::RunExecutor,
281 weak_factory_.GetWeakPtr())));
282 }
283
RetrieveRankings()284 std::map<std::string, float> MlAppRankProvider::RetrieveRankings() {
285 DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
286 return ranking_map_;
287 }
288
RunExecutor(base::flat_map<std::string,TensorPtr> inputs,const std::vector<std::string> outputs,const std::string app_id)289 void MlAppRankProvider::RunExecutor(
290 base::flat_map<std::string, TensorPtr> inputs,
291 const std::vector<std::string> outputs,
292 const std::string app_id) {
293 DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
294 BindGraphExecutorIfNeeded();
295 executor_->Execute(std::move(inputs), std::move(outputs),
296 base::BindOnce(&MlAppRankProvider::ExecuteCallback,
297 base::Unretained(this), app_id));
298 }
299
ExecuteCallback(std::string app_id,ExecuteResult result,const base::Optional<std::vector<TensorPtr>> outputs)300 void MlAppRankProvider::ExecuteCallback(
301 std::string app_id,
302 ExecuteResult result,
303 const base::Optional<std::vector<TensorPtr>> outputs) {
304 DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
305 if (result != ExecuteResult::OK) {
306 LOG(ERROR) << "Top Cat inference execution failed.";
307 return;
308 }
309 ranking_map_[app_id] = outputs.value()[0]->data->get_float_list()->value[0];
310 }
311
BindGraphExecutorIfNeeded()312 void MlAppRankProvider::BindGraphExecutorIfNeeded() {
313 DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
314 if (!model_) {
315 // Load the model.
316 BuiltinModelSpecPtr spec =
317 BuiltinModelSpec::New(BuiltinModelId::TOP_CAT_20190722);
318 chromeos::machine_learning::ServiceConnection::GetInstance()
319 ->LoadBuiltinModel(std::move(spec), model_.BindNewPipeAndPassReceiver(),
320 base::BindOnce(&LoadModelCallback));
321 }
322
323 if (!executor_) {
324 // Get the graph executor.
325 model_->CreateGraphExecutor(executor_.BindNewPipeAndPassReceiver(),
326 base::BindOnce(&CreateGraphExecutorCallback));
327 executor_.set_disconnect_handler(base::BindOnce(
328 &MlAppRankProvider::OnConnectionError, base::Unretained(this)));
329 }
330 }
331
OnConnectionError()332 void MlAppRankProvider::OnConnectionError() {
333 LOG(WARNING) << "Mojo connection for ML service closed.";
334 executor_.reset();
335 model_.reset();
336 }
337
338 } // namespace app_list
339