1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event_logger.h"

#include <algorithm>
#include <iterator>
#include <utility>

#include "ash/public/cpp/app_list/app_list_types.h"
#include "base/containers/flat_map.h"
#include "base/metrics/histogram_macros.h"
#include "base/no_destructor.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome/browser/ui/app_list/search/chrome_search_result.h"
#include "chrome/browser/ui/app_list/search/omnibox_result.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event.pb.h"
#include "chrome/grit/browser_resources.h"
#include "chromeos/constants/devicetype.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "components/crx_file/id_util.h"
#include "components/omnibox/browser/autocomplete_match_type.h"
#include "services/metrics/public/cpp/metrics_utils.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "ui/base/resource/resource_bundle.h"
#include "url/gurl.h"
#include "url/origin.h"
30 
31 namespace app_list {
32 namespace {
33 
34 using chromeos::machine_learning::mojom::BuiltinModelId;
35 using chromeos::machine_learning::mojom::BuiltinModelSpec;
36 using chromeos::machine_learning::mojom::CreateGraphExecutorResult;
37 using chromeos::machine_learning::mojom::ExecuteResult;
38 using chromeos::machine_learning::mojom::FloatList;
39 using chromeos::machine_learning::mojom::Int64List;
40 using chromeos::machine_learning::mojom::LoadModelResult;
41 using chromeos::machine_learning::mojom::Tensor;
42 using chromeos::machine_learning::mojom::TensorPtr;
43 using chromeos::machine_learning::mojom::ValueList;
44 using ukm::GetExponentialBucketMinForCounts1000;
45 
46 // How long to wait for a URL to enter the history service before querying it
47 // for a UKM source ID.
48 constexpr base::TimeDelta kDelayForHistoryService =
49     base::TimeDelta::FromSeconds(15);
50 
51 // Chosen so that the bucket at the 24 hour mark is ~60 minutes long. The bucket
52 // exponent used for counts that are not seconds is 1.15 (via
53 // ukm::GetExponentialBucketMinForCounts1000). The first value skipped by
54 // bucketing is 10.
55 constexpr float kBucketExponentForSeconds = 1.045f;
56 
57 // The UMA histogram that logs the error occurs in inference code.
58 constexpr char kInferenceError[] = "Apps.AppList.AggregatedSearchRankerError";
59 
// Failure kinds recorded in the inference-error UMA histogram. These values
// are persisted to logs: never renumber entries and never reuse a numeric
// value.
enum class InferenceError {
  kUnknown = 0,
  kLoadModelFailed = 1,
  kCreateGraphFailed = 2,
  kLoadExamplePreprocessorConfigFailed = 3,
  kVectorizeFeaturesFailed = 4,
  kInferenceExecutionFailed = 5,
  kMaxValue = kInferenceExecutionFailed,
};

// Coarse kind of a search result, logged as the "Category" feature. The
// numeric values persist to logs, so existing values must not be modified.
enum class Category {
  UNKNOWN = 0,
  FILE = 1,
  HISTORY = 2,
  NAV_SUGGEST = 3,
  SEARCH = 4,
  BOOKMARK = 5,
  DOCUMENT = 6,
  OMNIBOX_DEPRECATED = 7,
  OMNIBOX_GENERIC = 8,
};
85 
// Returns the index of |file_name|'s extension (the text from the last '.'
// onwards, dot included) in a fixed list of common extensions, or -1 when
// |file_name| has no '.' or its extension is not in the list. Matching is
// case-sensitive.
int ExtensionTypeFromFileName(const std::string& file_name) {
  // This is a limited list of commonly used extensions. The index of an
  // extension in this list persists to logs, so existing values should not be
  // modified and new values should only be added to the end. This should be
  // kept in sync with AppListNonAppImpressionFileExtension in
  // histograms/enums.xml
  static constexpr const char* kKnownExtensions[] = {
      ".3ga",        ".3gp",    ".aac",     ".alac", ".asf",  ".avi",
      ".bmp",        ".csv",    ".doc",     ".docx", ".flac", ".gif",
      ".jpeg",       ".jpg",    ".log",     ".m3u",  ".m3u8", ".m4a",
      ".m4v",        ".mid",    ".mkv",     ".mov",  ".mp3",  ".mp4",
      ".mpg",        ".odf",    ".odp",     ".ods",  ".odt",  ".oga",
      ".ogg",        ".ogv",    ".pdf",     ".png",  ".ppt",  ".pptx",
      ".ra",         ".ram",    ".rar",     ".rm",   ".rtf",  ".wav",
      ".webm",       ".webp",   ".wma",     ".wmv",  ".xls",  ".xlsx",
      ".crdownload", ".crx",    ".dmg",     ".exe",  ".html", ".htm",
      ".jar",        ".ps",     ".torrent", ".txt",  ".zip",  ".mhtml",
      ".gdoc",       ".gsheet", ".gslides"};

  const size_t dot = file_name.rfind('.');
  if (dot == std::string::npos)
    return -1;

  const std::string extension = file_name.substr(dot);
  const auto* const known_begin = std::begin(kKnownExtensions);
  const auto* const known_end = std::end(kKnownExtensions);
  const auto* const match = std::find(known_begin, known_end, extension);

  // An unrecognized extension must not map to the list size: that index is
  // exactly what the next extension appended to the list would receive,
  // making old and new logs ambiguous. Treat it like a missing extension.
  // NOTE(review): confirm enums.xml and the ranker's preprocessor config
  // handle -1 for this case.
  if (match == known_end)
    return -1;
  return static_cast<int>(match - known_begin);
}
113 
CategoryFromResultType(ash::AppListSearchResultType type,int subtype)114 Category CategoryFromResultType(ash::AppListSearchResultType type,
115                                 int subtype) {
116   if (type == ash::AppListSearchResultType::kLauncher)
117     return Category::FILE;
118 
119   if (type == ash::AppListSearchResultType::kOmnibox) {
120     switch (static_cast<AutocompleteMatchType::Type>(subtype)) {
121       case AutocompleteMatchType::Type::HISTORY_URL:
122       case AutocompleteMatchType::Type::HISTORY_TITLE:
123       case AutocompleteMatchType::Type::HISTORY_BODY:
124       case AutocompleteMatchType::Type::HISTORY_KEYWORD:
125         return Category::HISTORY;
126       case AutocompleteMatchType::Type::NAVSUGGEST:
127       case AutocompleteMatchType::Type::NAVSUGGEST_PERSONALIZED:
128         return Category::NAV_SUGGEST;
129       case AutocompleteMatchType::Type::SEARCH_HISTORY:
130       case AutocompleteMatchType::Type::SEARCH_SUGGEST:
131       case AutocompleteMatchType::Type::SEARCH_SUGGEST_ENTITY:
132       case AutocompleteMatchType::Type::SEARCH_SUGGEST_TAIL:
133       case AutocompleteMatchType::Type::SEARCH_SUGGEST_PERSONALIZED:
134       case AutocompleteMatchType::Type::SEARCH_SUGGEST_PROFILE:
135       case AutocompleteMatchType::Type::SEARCH_OTHER_ENGINE:
136         return Category::SEARCH;
137       case AutocompleteMatchType::Type::BOOKMARK_TITLE:
138         return Category::BOOKMARK;
139       case AutocompleteMatchType::Type::DOCUMENT_SUGGESTION:
140         return Category::DOCUMENT;
141       case AutocompleteMatchType::Type::EXTENSION_APP_DEPRECATED:
142       case AutocompleteMatchType::Type::CONTACT_DEPRECATED:
143       case AutocompleteMatchType::Type::PHYSICAL_WEB_DEPRECATED:
144       case AutocompleteMatchType::Type::PHYSICAL_WEB_OVERFLOW_DEPRECATED:
145       case AutocompleteMatchType::Type::TAB_SEARCH_DEPRECATED:
146         return Category::OMNIBOX_DEPRECATED;
147       default:
148         return Category::OMNIBOX_GENERIC;
149     }
150   }
151 
152   return Category::UNKNOWN;
153 }
154 
LogInferenceError(const InferenceError & error)155 void LogInferenceError(const InferenceError& error) {
156   UMA_HISTOGRAM_ENUMERATION(kInferenceError, error);
157 }
158 
GetExponentialBucketMinForSeconds(int64_t sample)159 int GetExponentialBucketMinForSeconds(int64_t sample) {
160   return ukm::GetExponentialBucketMin(sample, kBucketExponentForSeconds);
161 }
LoadModelCallback(LoadModelResult result)162 void LoadModelCallback(LoadModelResult result) {
163   if (result != LoadModelResult::OK) {
164     LOG(ERROR) << "Failed to load Search Ranker model.";
165     LogInferenceError(InferenceError::kLoadModelFailed);
166   }
167 }
168 
CreateGraphExecutorCallback(CreateGraphExecutorResult result)169 void CreateGraphExecutorCallback(CreateGraphExecutorResult result) {
170   if (result != CreateGraphExecutorResult::OK) {
171     LOG(ERROR) << "Failed to create a Search Ranker Graph Executor.";
172     LogInferenceError(InferenceError::kCreateGraphFailed);
173   }
174 }
175 
176 // Populates |example| using |features|.
PopulateRankerExample(const SearchRankingItem::Features & features,assist_ranker::RankerExample * example)177 void PopulateRankerExample(const SearchRankingItem::Features& features,
178                            assist_ranker::RankerExample* example) {
179   CHECK(example);
180 
181   auto& ranker_example_features = *example->mutable_features();
182   ranker_example_features["QueryLength"].set_int32_value(
183       features.query_length());
184   ranker_example_features["RelevanceScore"].set_int32_value(
185       features.relevance_score());
186   ranker_example_features["Category"].set_int32_value(features.category());
187   ranker_example_features["HourOfDay"].set_int32_value(features.hour_of_day());
188   ranker_example_features["DayOfWeek"].set_int32_value(features.day_of_week());
189   ranker_example_features["LaunchesThisSession"].set_int32_value(
190       features.launches_this_session());
191   if (features.has_file_extension()) {
192     ranker_example_features["FileExtension"].set_int32_value(
193         features.file_extension());
194   }
195   if (features.has_time_since_last_launch()) {
196     ranker_example_features["TimeSinceLastLaunch"].set_int32_value(
197         features.time_since_last_launch());
198     ranker_example_features["TimeOfLastLaunch"].set_int32_value(
199         features.time_of_last_launch());
200   }
201   const auto& launches = features.launches_at_hour();
202   for (int hour = 0; hour < launches.size(); hour++) {
203     ranker_example_features["LaunchesAtHour" + base::StringPrintf("%02d", hour)]
204         .set_int32_value(launches[hour]);
205   }
206   if (features.has_domain()) {
207     ranker_example_features["Domain"].set_string_value(features.domain());
208     ranker_example_features["HasDomain"].set_int32_value(1);
209   }
210 }
211 
212 // Loads the preprocessor config protobuf, which will be used later to convert
213 // a RankerExample to a vectorized float for inactivity score calculation.
214 // Returns nullptr if cannot load or parse the config.
215 std::unique_ptr<assist_ranker::ExamplePreprocessorConfig>
LoadExamplePreprocessorConfig()216 LoadExamplePreprocessorConfig() {
217   auto config = std::make_unique<assist_ranker::ExamplePreprocessorConfig>();
218 
219   const int res_id = IDR_SEARCH_RANKER_20190923_EXAMPLE_PREPROCESSOR_CONFIG_PB;
220 
221   scoped_refptr<base::RefCountedMemory> raw_config =
222       ui::ResourceBundle::GetSharedInstance().LoadDataResourceBytes(res_id);
223   if (!raw_config || !raw_config->front()) {
224     LOG(ERROR) << "Failed to load SearchRanker example preprocessor config.";
225     LogInferenceError(InferenceError::kLoadExamplePreprocessorConfigFailed);
226     return nullptr;
227   }
228 
229   if (!config->ParseFromArray(raw_config->front(), raw_config->size())) {
230     LOG(ERROR) << "Failed to parse SearchRanker example preprocessor config.";
231     LogInferenceError(InferenceError::kLoadExamplePreprocessorConfigFailed);
232     return nullptr;
233   }
234 
235   return config;
236 }
237 }  // namespace
238 
SearchRankingEventLogger(Profile * profile,SearchController * search_controller)239 SearchRankingEventLogger::SearchRankingEventLogger(
240     Profile* profile,
241     SearchController* search_controller)
242     : search_controller_(search_controller),
243       ukm_recorder_(ukm::UkmRecorder::Get()),
244       ukm_background_recorder_(
245           ukm::UkmBackgroundRecorderFactory::GetForProfile(profile)),
246       weak_factory_(this) {
247   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
248   DCHECK(search_controller_);
249 }
250 
251 SearchRankingEventLogger::~SearchRankingEventLogger() = default;
252 
253 SearchRankingEventLogger::ResultState::ResultState() = default;
254 SearchRankingEventLogger::ResultState::~ResultState() = default;
255 
SetEventRecordedForTesting(base::OnceClosure closure)256 void SearchRankingEventLogger::SetEventRecordedForTesting(
257     base::OnceClosure closure) {
258   event_recorded_for_testing_ = std::move(closure);
259 }
260 
PopulateSearchRankingItem(SearchRankingItem * proto,ChromeSearchResult * search_result,int query_length,bool use_for_logging)261 void SearchRankingEventLogger::PopulateSearchRankingItem(
262     SearchRankingItem* proto,
263     ChromeSearchResult* search_result,
264     int query_length,
265     bool use_for_logging) {
266   const base::Time now = base::Time::Now();
267   base::Time::Exploded now_exploded;
268   now.LocalExplode(&now_exploded);
269 
270   auto& features = *proto->mutable_features();
271   features.set_category(static_cast<int>(CategoryFromResultType(
272       search_result->result_type(), search_result->result_subtype())));
273 
274   // Note this is the search provider's original relevance score, not
275   // tweaked by any search ranking models. Scores are floats in 0 to 1, and
276   // we map this to ints 0 to 100.
277   features.set_relevance_score(
278       static_cast<int>(100 * search_result->relevance()));
279   features.set_hour_of_day(now_exploded.hour);
280   features.set_day_of_week(now_exploded.day_of_week);
281   features.set_query_length(GetExponentialBucketMinForCounts1000(query_length));
282 
283   if (features.category() == static_cast<int>(Category::FILE)) {
284     features.set_file_extension(ExtensionTypeFromFileName(search_result->id()));
285   }
286 
287   if (search_result->result_type() == ash::AppListSearchResultType::kOmnibox) {
288     // The id metadata of an OmniboxResult is a stripped URL, which does not
289     // correspond to the URL that will be navigated to.
290     proto->set_target(
291         static_cast<OmniboxResult*>(search_result)->DestinationURL().spec());
292   } else {
293     proto->set_target(search_result->id());
294   }
295 
296   const std::string& domain = GURL(search_result->id()).host();
297   if (!domain.empty()) {
298     features.set_domain(domain);
299   }
300 
301   // If the proto is created for logging purposes, create a new item in the map.
302   // Otherwise lookup the map for event info and create a "dummy" event info if
303   // doesn't nothing found.
304   ResultState* event_info;
305   ResultState dummy_event_info;
306   if (use_for_logging) {
307     event_info = &id_to_result_state_[proto->target()];
308   } else {
309     const auto& it = id_to_result_state_.find(proto->target());
310     if (it != id_to_result_state_.end()) {
311       event_info = &it->second;
312     } else {
313       event_info = &dummy_event_info;
314     }
315   }
316 
317   if (event_info->last_launch.has_value()) {
318     base::Time last_launch = event_info->last_launch.value();
319     base::Time::Exploded last_launch_exploded;
320     last_launch.LocalExplode(&last_launch_exploded);
321 
322     features.set_time_since_last_launch(
323         GetExponentialBucketMinForSeconds((now - last_launch).InSeconds()));
324     features.set_time_of_last_launch(last_launch_exploded.hour);
325 
326     // Reset the number of launches this hour to 0 if this is the first
327     // launch today of this event, to account for user sessions spanning
328     // multiple days.
329     if (features.has_is_launched() && features.is_launched() == 1 &&
330         now - event_info->last_launch.value() >=
331             base::TimeDelta::FromHours(23)) {
332       event_info->launches_per_hour[now_exploded.hour] = 0;
333     }
334   }
335 
336   features.set_launches_this_session(
337       GetExponentialBucketMinForCounts1000(event_info->launches_this_session));
338 
339   const auto& launches = event_info->launches_per_hour;
340   features.add_launches_at_hour(
341       GetExponentialBucketMinForCounts1000(launches[0]));
342   features.add_launches_at_hour(
343       GetExponentialBucketMinForCounts1000(launches[1]));
344   features.add_launches_at_hour(
345       GetExponentialBucketMinForCounts1000(launches[2]));
346   features.add_launches_at_hour(
347       GetExponentialBucketMinForCounts1000(launches[3]));
348   features.add_launches_at_hour(
349       GetExponentialBucketMinForCounts1000(launches[4]));
350   features.add_launches_at_hour(
351       GetExponentialBucketMinForCounts1000(launches[5]));
352   features.add_launches_at_hour(
353       GetExponentialBucketMinForCounts1000(launches[6]));
354   features.add_launches_at_hour(
355       GetExponentialBucketMinForCounts1000(launches[7]));
356   features.add_launches_at_hour(
357       GetExponentialBucketMinForCounts1000(launches[8]));
358   features.add_launches_at_hour(
359       GetExponentialBucketMinForCounts1000(launches[9]));
360   features.add_launches_at_hour(
361       GetExponentialBucketMinForCounts1000(launches[10]));
362   features.add_launches_at_hour(
363       GetExponentialBucketMinForCounts1000(launches[11]));
364   features.add_launches_at_hour(
365       GetExponentialBucketMinForCounts1000(launches[12]));
366   features.add_launches_at_hour(
367       GetExponentialBucketMinForCounts1000(launches[13]));
368   features.add_launches_at_hour(
369       GetExponentialBucketMinForCounts1000(launches[14]));
370   features.add_launches_at_hour(
371       GetExponentialBucketMinForCounts1000(launches[15]));
372   features.add_launches_at_hour(
373       GetExponentialBucketMinForCounts1000(launches[16]));
374   features.add_launches_at_hour(
375       GetExponentialBucketMinForCounts1000(launches[17]));
376   features.add_launches_at_hour(
377       GetExponentialBucketMinForCounts1000(launches[18]));
378   features.add_launches_at_hour(
379       GetExponentialBucketMinForCounts1000(launches[19]));
380   features.add_launches_at_hour(
381       GetExponentialBucketMinForCounts1000(launches[20]));
382   features.add_launches_at_hour(
383       GetExponentialBucketMinForCounts1000(launches[21]));
384   features.add_launches_at_hour(
385       GetExponentialBucketMinForCounts1000(launches[22]));
386   features.add_launches_at_hour(
387       GetExponentialBucketMinForCounts1000(launches[23]));
388 
389   if (features.has_is_launched() && features.is_launched() == 1) {
390     event_info->last_launch = now;
391     event_info->launches_this_session += 1;
392     event_info->launches_per_hour[now_exploded.hour] += 1;
393   }
394 }
395 
Log(const base::string16 & trimmed_query,const ash::SearchResultIdWithPositionIndices & search_results,int launched_index)396 void SearchRankingEventLogger::Log(
397     const base::string16& trimmed_query,
398     const ash::SearchResultIdWithPositionIndices& search_results,
399     int launched_index) {
400   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
401   for (const auto& id_index : search_results) {
402     auto* result = search_controller_->FindSearchResult(id_index.id);
403     if (!result)
404       continue;
405 
406     SearchRankingItem proto;
407     proto.mutable_features()->set_position(id_index.position_index);
408     proto.set_event_id(next_event_id_);
409     proto.mutable_features()->set_is_launched(
410         id_index.position_index == launched_index ? 1 : 0);
411     PopulateSearchRankingItem(&proto, result, trimmed_query.size(),
412                               true /*use_for_logging*/);
413 
414     // Omnibox results have associated URLs, so are logged keyed on the URL
415     // after validating that it exists in the history service. Other results
416     // have no associated URL, so use a blank source ID.
417     if (result->result_type() == ash::AppListSearchResultType::kOmnibox) {
418       // When an omnibox result is launched, we need to retrieve a source ID
419       // using the history service. This may be the first time the URL is used
420       // and so it must be committed to the history service database before we
421       // retrieve it, which happens once the page has loaded. So we delay our
422       // check for long enough that most pages will have loaded.
423       if (launched_index == id_index.position_index) {
424         base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
425             FROM_HERE,
426             base::BindOnce(
427                 &SearchRankingEventLogger::GetBackgroundSourceIdAndLogEvent,
428                 weak_factory_.GetWeakPtr(), proto),
429             kDelayForHistoryService);
430       } else {
431         GetBackgroundSourceIdAndLogEvent(proto);
432       }
433     } else {
434       LogEvent(proto, base::nullopt);
435     }
436   }
437 
438   ++next_event_id_;
439 }
440 
GetBackgroundSourceIdAndLogEvent(const SearchRankingItem & result)441 void SearchRankingEventLogger::GetBackgroundSourceIdAndLogEvent(
442     const SearchRankingItem& result) {
443   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
444   ukm_background_recorder_->GetBackgroundSourceIdIfAllowed(
445       url::Origin::Create(GURL(result.target())),
446       base::BindOnce(&SearchRankingEventLogger::LogEvent,
447                      base::Unretained(this), result));
448 }
449 
LogEvent(const SearchRankingItem & result,base::Optional<ukm::SourceId> source_id)450 void SearchRankingEventLogger::LogEvent(
451     const SearchRankingItem& result,
452     base::Optional<ukm::SourceId> source_id) {
453   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
454   if (!source_id)
455     source_id = ukm_recorder_->GetNewSourceID();
456 
457   ukm::builders::AppListNonAppImpression event(source_id.value());
458   event.SetEventId(result.event_id())
459       .SetPosition(result.features().position())
460       .SetIsLaunched(result.features().is_launched())
461       .SetQueryLength(result.features().query_length())
462       .SetRelevanceScore(result.features().relevance_score())
463       .SetCategory(result.features().category())
464       .SetHourOfDay(result.features().hour_of_day())
465       .SetDayOfWeek(result.features().day_of_week())
466       .SetLaunchesThisSession(result.features().launches_this_session());
467 
468   if (result.features().has_file_extension()) {
469     event.SetFileExtension(result.features().file_extension());
470   }
471 
472   if (result.features().has_time_since_last_launch()) {
473     event.SetTimeSinceLastLaunch(result.features().time_since_last_launch());
474     event.SetTimeOfLastLaunch(result.features().time_of_last_launch());
475   }
476 
477   const auto& launches = result.features().launches_at_hour();
478   event.SetLaunchesAtHour00(launches[0]);
479   event.SetLaunchesAtHour01(launches[1]);
480   event.SetLaunchesAtHour02(launches[2]);
481   event.SetLaunchesAtHour03(launches[3]);
482   event.SetLaunchesAtHour04(launches[4]);
483   event.SetLaunchesAtHour05(launches[5]);
484   event.SetLaunchesAtHour06(launches[6]);
485   event.SetLaunchesAtHour07(launches[7]);
486   event.SetLaunchesAtHour08(launches[8]);
487   event.SetLaunchesAtHour09(launches[9]);
488   event.SetLaunchesAtHour10(launches[10]);
489   event.SetLaunchesAtHour11(launches[11]);
490   event.SetLaunchesAtHour12(launches[12]);
491   event.SetLaunchesAtHour13(launches[13]);
492   event.SetLaunchesAtHour14(launches[14]);
493   event.SetLaunchesAtHour15(launches[15]);
494   event.SetLaunchesAtHour16(launches[16]);
495   event.SetLaunchesAtHour17(launches[17]);
496   event.SetLaunchesAtHour18(launches[18]);
497   event.SetLaunchesAtHour19(launches[19]);
498   event.SetLaunchesAtHour20(launches[20]);
499   event.SetLaunchesAtHour21(launches[21]);
500   event.SetLaunchesAtHour22(launches[22]);
501   event.SetLaunchesAtHour23(launches[23]);
502 
503   event.Record(ukm_recorder_);
504 
505   if (event_recorded_for_testing_)
506     std::move(event_recorded_for_testing_).Run();
507 }
508 
CreateRankings(Mixer::SortedResults * results,int query_length)509 void SearchRankingEventLogger::CreateRankings(Mixer::SortedResults* results,
510                                               int query_length) {
511   for (const auto& result : *results) {
512     if (!result.result) {
513       continue;
514     }
515     SearchRankingItem proto;
516     std::vector<float> vectorized_features;
517 
518     PopulateSearchRankingItem(&proto, result.result, query_length,
519                               false /*use_for_logging*/);
520     if (!PreprocessInput(proto.features(), &vectorized_features)) {
521       return;
522     }
523     DoInference(vectorized_features, result.result->id());
524   }
525 }
526 
RetrieveRankings()527 std::map<std::string, float> SearchRankingEventLogger::RetrieveRankings() {
528   return prediction_;
529 }
530 
LazyInitialize()531 void SearchRankingEventLogger::LazyInitialize() {
532   if (!preprocessor_config_) {
533     preprocessor_config_ = LoadExamplePreprocessorConfig();
534   }
535 }
536 
PreprocessInput(const SearchRankingItem::Features & features,std::vector<float> * vectorized_features)537 bool SearchRankingEventLogger::PreprocessInput(
538     const SearchRankingItem::Features& features,
539     std::vector<float>* vectorized_features) {
540   DCHECK(vectorized_features);
541   LazyInitialize();
542 
543   if (!preprocessor_config_) {
544     LOG(ERROR) << "Failed to create preprocessor config.";
545     LogInferenceError(InferenceError::kLoadExamplePreprocessorConfigFailed);
546     return false;
547   }
548 
549   assist_ranker::RankerExample ranker_example;
550   PopulateRankerExample(features, &ranker_example);
551 
552   int preprocessor_error = assist_ranker::ExamplePreprocessor::Process(
553       *preprocessor_config_, &ranker_example, true);
554   // kNoFeatureIndexFound can occur normally (e.g., when the domain name
555   // isn't known to the model or a rarely seen enum value is used).
556   if (preprocessor_error != assist_ranker::ExamplePreprocessor::kSuccess &&
557       preprocessor_error !=
558           assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
559     LOG(ERROR) << "Failed to vectorize features using ExamplePreprocessor.";
560     LogInferenceError(InferenceError::kVectorizeFeaturesFailed);
561     return false;
562   }
563 
564   const auto& extracted_features =
565       ranker_example.features()
566           .at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
567           .float_list()
568           .float_value();
569   vectorized_features->assign(extracted_features.begin(),
570                               extracted_features.end());
571   return true;
572 }
573 
DoInference(const std::vector<float> & features,const std::string & id)574 void SearchRankingEventLogger::DoInference(const std::vector<float>& features,
575                                            const std::string& id) {
576   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
577   BindGraphExecutorIfNeeded();
578 
579   // Prepare the input tensor.
580   base::flat_map<std::string, TensorPtr> inputs;
581   auto tensor = Tensor::New();
582   tensor->shape = Int64List::New();
583   tensor->shape->value = std::vector<int64_t>({1, features.size()});
584   tensor->data = ValueList::New();
585   tensor->data->set_float_list(FloatList::New());
586   tensor->data->get_float_list()->value =
587       std::vector<double>(std::begin(features), std::end(features));
588   inputs.emplace(std::string("input"), std::move(tensor));
589 
590   const std::vector<std::string> outputs({std::string("output")});
591   // Execute
592   executor_->Execute(std::move(inputs), std::move(outputs),
593                      base::BindOnce(&SearchRankingEventLogger::ExecuteCallback,
594                                     weak_factory_.GetWeakPtr(), id));
595 }
596 
BindGraphExecutorIfNeeded()597 void SearchRankingEventLogger::BindGraphExecutorIfNeeded() {
598   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
599   if (!model_) {
600     // Load the model.
601     auto spec = BuiltinModelSpec::New(BuiltinModelId::SEARCH_RANKER_20190923);
602     chromeos::machine_learning::ServiceConnection::GetInstance()
603         ->LoadBuiltinModel(std::move(spec), model_.BindNewPipeAndPassReceiver(),
604                            base::BindOnce(&LoadModelCallback));
605   }
606 
607   if (!executor_) {
608     // Get the graph executor.
609     model_->CreateGraphExecutor(executor_.BindNewPipeAndPassReceiver(),
610                                 base::BindOnce(&CreateGraphExecutorCallback));
611     executor_.set_disconnect_handler(base::BindOnce(
612         &SearchRankingEventLogger::OnConnectionError, base::Unretained(this)));
613   }
614 }
615 
OnConnectionError()616 void SearchRankingEventLogger::OnConnectionError() {
617   LOG(WARNING) << "Mojo connection for ML service closed.";
618   executor_.reset();
619   model_.reset();
620 }
621 
ExecuteCallback(const std::string & id,ExecuteResult result,const base::Optional<std::vector<TensorPtr>> outputs)622 void SearchRankingEventLogger::ExecuteCallback(
623     const std::string& id,
624     ExecuteResult result,
625     const base::Optional<std::vector<TensorPtr>> outputs) {
626   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
627   if (result != ExecuteResult::OK) {
628     LOG(ERROR) << "Search Ranker inference execution failed.";
629     LogInferenceError(InferenceError::kInferenceExecutionFailed);
630     return;
631   }
632   prediction_[id] = outputs.value()[0]->data->get_float_list()->value[0];
633 }
634 
635 }  // namespace app_list
636