1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/ml_app_rank_provider.h"
6 
7 #include <utility>
8 
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/memory/ref_counted_memory.h"
13 #include "base/strings/stringprintf.h"
14 #include "base/task/post_task.h"
15 #include "base/task/task_traits.h"
16 #include "base/task/thread_pool.h"
17 #include "base/time/time.h"
18 #include "chrome/browser/chromeos/power/ml/user_activity_ukm_logger_helpers.h"
19 #include "chrome/browser/ui/app_list/search/search_result_ranker/app_launch_event_logger_helper.h"
20 #include "chrome/grit/browser_resources.h"
21 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
22 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
23 #include "components/assist_ranker/example_preprocessing.h"
24 #include "components/crx_file/id_util.h"
25 #include "content/public/browser/browser_task_traits.h"
26 #include "content/public/browser/browser_thread.h"
27 #include "ui/base/resource/resource_bundle.h"
28 
29 using ::chromeos::machine_learning::mojom::BuiltinModelId;
30 using ::chromeos::machine_learning::mojom::BuiltinModelSpec;
31 using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
32 using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
33 using ::chromeos::machine_learning::mojom::ExecuteResult;
34 using ::chromeos::machine_learning::mojom::FloatList;
35 using ::chromeos::machine_learning::mojom::Int64List;
36 using ::chromeos::machine_learning::mojom::LoadModelResult;
37 using ::chromeos::machine_learning::mojom::Tensor;
38 using ::chromeos::machine_learning::mojom::TensorPtr;
39 using ::chromeos::machine_learning::mojom::ValueList;
40 
41 namespace app_list {
42 
43 namespace {
44 
LoadModelCallback(LoadModelResult result)45 void LoadModelCallback(LoadModelResult result) {
46   if (result != LoadModelResult::OK) {
47     LOG(ERROR) << "Failed to load Top Cat model.";
48   }
49 }
50 
CreateGraphExecutorCallback(CreateGraphExecutorResult result)51 void CreateGraphExecutorCallback(CreateGraphExecutorResult result) {
52   if (result != CreateGraphExecutorResult::OK) {
53     LOG(ERROR) << "Failed to create a Top Cat Graph Executor.";
54   }
55 }
56 
57 // Returns: true if preprocessor config loaded, false if it could not be loaded.
LoadExamplePreprocessorConfig(assist_ranker::ExamplePreprocessorConfig * preprocessor_config)58 bool LoadExamplePreprocessorConfig(
59     assist_ranker::ExamplePreprocessorConfig* preprocessor_config) {
60   DCHECK(preprocessor_config);
61 
62   const int resource_id = IDR_TOP_CAT_20190722_EXAMPLE_PREPROCESSOR_CONFIG_PB;
63   const scoped_refptr<base::RefCountedMemory> raw_config =
64       ui::ResourceBundle::GetSharedInstance().LoadDataResourceBytes(
65           resource_id);
66   if (!raw_config || !raw_config->front()) {
67     LOG(ERROR) << "Failed to load TopCatModel example preprocessor config.";
68     return false;
69   }
70 
71   if (!preprocessor_config->ParseFromArray(raw_config->front(),
72                                            raw_config->size())) {
73     LOG(ERROR) << "Failed to parse TopCatModel example preprocessor config.";
74     return false;
75   }
76   return true;
77 }
78 
79 // Perform the inference given the |features| and |app_id| of an app.
80 // Posts |callback| to |task_runner| to perform the actual inference.
DoInference(const std::string & app_id,const std::vector<float> & features,scoped_refptr<base::SequencedTaskRunner> task_runner,const base::RepeatingCallback<void (base::flat_map<std::string,TensorPtr> inputs,const std::vector<std::string> outputs,const std::string app_id)> callback)81 void DoInference(const std::string& app_id,
82                  const std::vector<float>& features,
83                  scoped_refptr<base::SequencedTaskRunner> task_runner,
84                  const base::RepeatingCallback<
85                      void(base::flat_map<std::string, TensorPtr> inputs,
86                           const std::vector<std::string> outputs,
87                           const std::string app_id)> callback) {
88   // Prepare the input tensor.
89   base::flat_map<std::string, TensorPtr> inputs;
90   auto tensor = Tensor::New();
91   tensor->shape = Int64List::New();
92   tensor->shape->value = std::vector<int64_t>({1, features.size()});
93   tensor->data = ValueList::New();
94   tensor->data->set_float_list(FloatList::New());
95   tensor->data->get_float_list()->value =
96       std::vector<double>(std::begin(features), std::end(features));
97   inputs.emplace(std::string("input"), std::move(tensor));
98 
99   const std::vector<std::string> outputs({std::string("output")});
100   DCHECK(task_runner);
101   task_runner->PostTask(FROM_HERE, base::BindOnce(callback, std::move(inputs),
102                                                   std::move(outputs), app_id));
103 }
104 
105 // Process the RankerExample to vectorize the feature list for inference.
106 // Returns true on success.
RankerExampleToVectorizedFeatures(const assist_ranker::ExamplePreprocessorConfig & preprocessor_config,assist_ranker::RankerExample * example,std::vector<float> * vectorized_features)107 bool RankerExampleToVectorizedFeatures(
108     const assist_ranker::ExamplePreprocessorConfig& preprocessor_config,
109     assist_ranker::RankerExample* example,
110     std::vector<float>* vectorized_features) {
111   int preprocessor_error = assist_ranker::ExamplePreprocessor::Process(
112       preprocessor_config, example, true);
113   // kNoFeatureIndexFound can occur normally (e.g., when the app URL
114   // isn't known to the model or a rarely seen enum value is used).
115   if (preprocessor_error != assist_ranker::ExamplePreprocessor::kSuccess &&
116       preprocessor_error !=
117           assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
118     // TODO: Log to UMA.
119     return false;
120   }
121 
122   const auto& extracted_features =
123       example->features()
124           .at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
125           .float_list()
126           .float_value();
127   vectorized_features->assign(extracted_features.begin(),
128                               extracted_features.end());
129   return true;
130 }
131 
132 // Does the CPU-intensive part of CreateRankings (preparing the Tensor inputs
133 // from |app_features_map|, intended to be called on a low-priority
134 // background thread. Invokes |callback| on |task_runner| once for each app in
135 // |app_features_map|.
CreateRankingsImpl(base::flat_map<std::string,AppLaunchFeatures> app_features_map,int total_hours,int all_clicks_last_hour,int all_clicks_last_24_hours,scoped_refptr<base::SequencedTaskRunner> task_runner,const base::RepeatingCallback<void (base::flat_map<std::string,TensorPtr> inputs,const std::vector<std::string> outputs,const std::string app_id)> & callback)136 void CreateRankingsImpl(
137     base::flat_map<std::string, AppLaunchFeatures> app_features_map,
138     int total_hours,
139     int all_clicks_last_hour,
140     int all_clicks_last_24_hours,
141     scoped_refptr<base::SequencedTaskRunner> task_runner,
142     const base::RepeatingCallback<
143         void(base::flat_map<std::string, TensorPtr> inputs,
144              const std::vector<std::string> outputs,
145              const std::string app_id)>& callback) {
146   const base::Time now(base::Time::Now());
147   const int hour = HourOfDay(now);
148   const int day = DayOfWeek(now);
149 
150   assist_ranker::ExamplePreprocessorConfig preprocessor_config;
151   if (!LoadExamplePreprocessorConfig(&preprocessor_config)) {
152     return;
153   }
154   for (auto& app : app_features_map) {
155     assist_ranker::RankerExample example(
156         CreateRankerExample(app.second,
157                             now.ToDeltaSinceWindowsEpoch().InSeconds() -
158                                 app.second.time_of_last_click_sec(),
159                             total_hours, day, hour, all_clicks_last_hour,
160                             all_clicks_last_24_hours));
161     std::vector<float> vectorized_features;
162     if (RankerExampleToVectorizedFeatures(preprocessor_config, &example,
163                                           &vectorized_features)) {
164       DoInference(app.first, vectorized_features, task_runner, callback);
165     }
166   }
167 }
168 
169 }  // namespace
170 
CreateRankerExample(const AppLaunchFeatures & features,int time_since_last_click,int total_hours,int day_of_week,int hour_of_day,int all_clicks_last_hour,int all_clicks_last_24_hours)171 assist_ranker::RankerExample CreateRankerExample(
172     const AppLaunchFeatures& features,
173     int time_since_last_click,
174     int total_hours,
175     int day_of_week,
176     int hour_of_day,
177     int all_clicks_last_hour,
178     int all_clicks_last_24_hours) {
179   assist_ranker::RankerExample example;
180   auto& ranker_example_features = *example.mutable_features();
181 
182   ranker_example_features["DayOfWeek"].set_int32_value(day_of_week);
183   ranker_example_features["HourOfDay"].set_int32_value(hour_of_day);
184   ranker_example_features["AllClicksLastHour"].set_int32_value(
185       all_clicks_last_hour);
186   ranker_example_features["AllClicksLast24Hours"].set_int32_value(
187       all_clicks_last_24_hours);
188 
189   ranker_example_features["AppType"].set_int32_value(features.app_type());
190   ranker_example_features["ClickRank"].set_int32_value(features.click_rank());
191   ranker_example_features["ClicksLastHour"].set_int32_value(
192       features.clicks_last_hour());
193   ranker_example_features["ClicksLast24Hours"].set_int32_value(
194       features.clicks_last_24_hours());
195   ranker_example_features["LastLaunchedFrom"].set_int32_value(
196       features.last_launched_from());
197   ranker_example_features["HasClick"].set_bool_value(
198       features.has_most_recently_used_index());
199   ranker_example_features["MostRecentlyUsedIndex"].set_int32_value(
200       features.most_recently_used_index());
201   ranker_example_features["TimeSinceLastClick"].set_int32_value(
202       Bucketize(time_since_last_click, kTimeSinceLastClickBuckets));
203   ranker_example_features["TotalClicks"].set_int32_value(
204       features.total_clicks());
205   ranker_example_features["TotalClicksPerHour"].set_float_value(
206       static_cast<float>(features.total_clicks()) / (total_hours + 1));
207   ranker_example_features["TotalHours"].set_int32_value(total_hours);
208 
209   // Calculate FourHourClicksN and SixHourClicksN, which sum clicks for four
210   // and six hour periods respectively.
211   int four_hour_count = 0;
212   int six_hour_count = 0;
213   // Apps that have been clicked will have 24 clicks_each_hour values. Apps that
214   // have not been clicked will have no clicks_each_hour values, so can skip
215   // the FourHourClicksN and SixHourClicksN calculations.
216   if (features.clicks_each_hour_size() == 24) {
217     for (int hour = 0; hour < 24; hour++) {
218       int clicks = Bucketize(features.clicks_each_hour(hour), kClickBuckets);
219       ranker_example_features["ClicksEachHour" +
220                               base::StringPrintf("%02d", hour)]
221           .set_int32_value(clicks);
222       ranker_example_features["ClicksPerHour" +
223                               base::StringPrintf("%02d", hour)]
224           .set_float_value(static_cast<float>(clicks) / (total_hours + 1));
225       four_hour_count += clicks;
226       six_hour_count += clicks;
227       // Divide day into periods of 4 hours each.
228       if (hour % 4 == 3 && four_hour_count != 0) {
229         ranker_example_features["FourHourClicks" +
230                                 base::StringPrintf("%01d", hour / 4)]
231             .set_int32_value(four_hour_count);
232         four_hour_count = 0;
233       }
234       // Divide day into periods of 6 hours each.
235       if (hour % 6 == 5 && six_hour_count != 0) {
236         ranker_example_features["SixHourClicks" +
237                                 base::StringPrintf("%01d", hour / 6)]
238             .set_int32_value(six_hour_count);
239         six_hour_count = 0;
240       }
241     }
242   }
243 
244   if (features.app_type() == AppLaunchEvent_AppType_CHROME) {
245     ranker_example_features["URL"].set_string_value(
246         kExtensionSchemeWithDelimiter + features.app_id());
247   } else if (features.app_type() == AppLaunchEvent_AppType_PWA) {
248     ranker_example_features["URL"].set_string_value(features.pwa_url());
249   } else if (features.app_type() == AppLaunchEvent_AppType_PLAY) {
250     ranker_example_features["URL"].set_string_value(
251         kAppScheme +
252         crx_file::id_util::GenerateId(features.arc_package_name()));
253   } else {
254     // TODO(crbug.com/1027782): Add DCHECK that this branch is not reached.
255   }
256   return example;
257 }
258 
MlAppRankProvider()259 MlAppRankProvider::MlAppRankProvider()
260     : creation_task_runner_(base::SequencedTaskRunnerHandle::Get()),
261       background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
262           {base::TaskPriority::BEST_EFFORT,
263            base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})) {}
264 
265 MlAppRankProvider::~MlAppRankProvider() = default;
266 
CreateRankings(const base::flat_map<std::string,AppLaunchFeatures> & app_features_map,int total_hours,int all_clicks_last_hour,int all_clicks_last_24_hours)267 void MlAppRankProvider::CreateRankings(
268     const base::flat_map<std::string, AppLaunchFeatures>& app_features_map,
269     int total_hours,
270     int all_clicks_last_hour,
271     int all_clicks_last_24_hours) {
272   DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
273   // TODO(jennyz): Add start-to-end latency metrics for the work on each
274   // sequence.
275   background_task_runner_->PostTask(
276       FROM_HERE,
277       base::BindOnce(&CreateRankingsImpl, app_features_map, total_hours,
278                      all_clicks_last_hour, all_clicks_last_24_hours,
279                      creation_task_runner_,
280                      base::BindRepeating(&MlAppRankProvider::RunExecutor,
281                                          weak_factory_.GetWeakPtr())));
282 }
283 
RetrieveRankings()284 std::map<std::string, float> MlAppRankProvider::RetrieveRankings() {
285   DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
286   return ranking_map_;
287 }
288 
RunExecutor(base::flat_map<std::string,TensorPtr> inputs,const std::vector<std::string> outputs,const std::string app_id)289 void MlAppRankProvider::RunExecutor(
290     base::flat_map<std::string, TensorPtr> inputs,
291     const std::vector<std::string> outputs,
292     const std::string app_id) {
293   DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
294   BindGraphExecutorIfNeeded();
295   executor_->Execute(std::move(inputs), std::move(outputs),
296                      base::BindOnce(&MlAppRankProvider::ExecuteCallback,
297                                     base::Unretained(this), app_id));
298 }
299 
ExecuteCallback(std::string app_id,ExecuteResult result,const base::Optional<std::vector<TensorPtr>> outputs)300 void MlAppRankProvider::ExecuteCallback(
301     std::string app_id,
302     ExecuteResult result,
303     const base::Optional<std::vector<TensorPtr>> outputs) {
304   DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
305   if (result != ExecuteResult::OK) {
306     LOG(ERROR) << "Top Cat inference execution failed.";
307     return;
308   }
309   ranking_map_[app_id] = outputs.value()[0]->data->get_float_list()->value[0];
310 }
311 
BindGraphExecutorIfNeeded()312 void MlAppRankProvider::BindGraphExecutorIfNeeded() {
313   DCHECK_CALLED_ON_VALID_SEQUENCE(creation_sequence_checker_);
314   if (!model_) {
315     // Load the model.
316     BuiltinModelSpecPtr spec =
317         BuiltinModelSpec::New(BuiltinModelId::TOP_CAT_20190722);
318     chromeos::machine_learning::ServiceConnection::GetInstance()
319         ->LoadBuiltinModel(std::move(spec), model_.BindNewPipeAndPassReceiver(),
320                            base::BindOnce(&LoadModelCallback));
321   }
322 
323   if (!executor_) {
324     // Get the graph executor.
325     model_->CreateGraphExecutor(executor_.BindNewPipeAndPassReceiver(),
326                                 base::BindOnce(&CreateGraphExecutorCallback));
327     executor_.set_disconnect_handler(base::BindOnce(
328         &MlAppRankProvider::OnConnectionError, base::Unretained(this)));
329   }
330 }
331 
OnConnectionError()332 void MlAppRankProvider::OnConnectionError() {
333   LOG(WARNING) << "Mojo connection for ML service closed.";
334   executor_.reset();
335   model_.reset();
336 }
337 
338 }  // namespace app_list
339