1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event_logger.h"
6
7 #include <utility>
8
9 #include "ash/public/cpp/app_list/app_list_types.h"
10 #include "base/containers/flat_map.h"
11 #include "base/metrics/histogram_macros.h"
12 #include "base/no_destructor.h"
13 #include "base/strings/stringprintf.h"
14 #include "base/strings/utf_string_conversions.h"
15 #include "chrome/browser/ui/app_list/search/chrome_search_result.h"
16 #include "chrome/browser/ui/app_list/search/omnibox_result.h"
17 #include "chrome/browser/ui/app_list/search/search_result_ranker/search_ranking_event.pb.h"
18 #include "chrome/grit/browser_resources.h"
19 #include "chromeos/constants/devicetype.h"
20 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
21 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
22 #include "components/assist_ranker/example_preprocessing.h"
23 #include "components/crx_file/id_util.h"
24 #include "components/omnibox/browser/autocomplete_match_type.h"
25 #include "services/metrics/public/cpp/metrics_utils.h"
26 #include "services/metrics/public/cpp/ukm_builders.h"
27 #include "ui/base/resource/resource_bundle.h"
28 #include "url/gurl.h"
29 #include "url/origin.h"
30
31 namespace app_list {
32 namespace {
33
34 using chromeos::machine_learning::mojom::BuiltinModelId;
35 using chromeos::machine_learning::mojom::BuiltinModelSpec;
36 using chromeos::machine_learning::mojom::CreateGraphExecutorResult;
37 using chromeos::machine_learning::mojom::ExecuteResult;
38 using chromeos::machine_learning::mojom::FloatList;
39 using chromeos::machine_learning::mojom::Int64List;
40 using chromeos::machine_learning::mojom::LoadModelResult;
41 using chromeos::machine_learning::mojom::Tensor;
42 using chromeos::machine_learning::mojom::TensorPtr;
43 using chromeos::machine_learning::mojom::ValueList;
44 using ukm::GetExponentialBucketMinForCounts1000;
45
// How long to wait for a URL to enter the history service before querying it
// for a UKM source ID. Used as the delay of the task posted by Log() for a
// launched omnibox result.
constexpr base::TimeDelta kDelayForHistoryService =
    base::TimeDelta::FromSeconds(15);

// Bucket exponent applied by GetExponentialBucketMinForSeconds().
// Chosen so that the bucket at the 24 hour mark is ~60 minutes long. The bucket
// exponent used for counts that are not seconds is 1.15 (via
// ukm::GetExponentialBucketMinForCounts1000). The first value skipped by
// bucketing is 10.
constexpr float kBucketExponentForSeconds = 1.045f;

// The UMA histogram that records errors occurring in the inference code.
constexpr char kInferenceError[] = "Apps.AppList.AggregatedSearchRankerError";
59
// Represents type of error in inference. These values are persisted to logs.
// Entries should not be renumbered and numeric values should never be reused.
// Samples are recorded through LogInferenceError().
enum class InferenceError {
  // Not recorded by any code visible in this file.
  kUnknown = 0,
  // Loading the built-in model reported a result other than OK.
  kLoadModelFailed = 1,
  // Creating the graph executor reported a result other than OK.
  kCreateGraphFailed = 2,
  // The example preprocessor config could not be loaded or parsed.
  kLoadExamplePreprocessorConfigFailed = 3,
  // ExamplePreprocessor::Process returned an error other than
  // kNoFeatureIndexFound.
  kVectorizeFeaturesFailed = 4,
  // Executing the graph reported a result other than OK.
  kInferenceExecutionFailed = 5,
  // Alias of the highest valid value.
  kMaxValue = kInferenceExecutionFailed,
};
71
// Represents the type of a search result. The indices of these values
// persist to logs, so existing values should not be modified.
// CategoryFromResultType() decides which value a result gets.
enum class Category {
  // Result type is neither launcher nor omnibox.
  UNKNOWN = 0,
  // Launcher result type.
  FILE = 1,
  // Omnibox history matches (URL, title, body, keyword).
  HISTORY = 2,
  // Omnibox navigation suggestions.
  NAV_SUGGEST = 3,
  // Omnibox search history, search suggestions and other-engine searches.
  SEARCH = 4,
  // Omnibox bookmark title matches.
  BOOKMARK = 5,
  // Omnibox document suggestions.
  DOCUMENT = 6,
  // Omnibox match types whose names are marked *_DEPRECATED.
  OMNIBOX_DEPRECATED = 7,
  // Any other omnibox match type.
  OMNIBOX_GENERIC = 8
};
85
// Returns the logging index of |file_name|'s extension. The extension is the
// text from the last '.' to the end of the string and is compared
// case-sensitively against the table below.
//   * Returns -1 if |file_name| contains no '.'.
//   * Returns the extension's position in the table if it is listed.
//   * Returns the table size (one past the last valid index) for any other
//     extension, including upper-case spellings such as ".PDF".
int ExtensionTypeFromFileName(const std::string& file_name) {
  // This is a limited list of commonly used extensions. The index of an
  // extension in this list persists to logs, so existing values should not be
  // modified and new values should only be added to the end. This should be
  // kept in sync with AppListNonAppImpressionFileExtension in
  // histograms/enums.xml
  //
  // The table is a constant-initialized array of literals, so it needs neither
  // a lazily constructed static nor any heap allocation.
  static constexpr const char* kKnownExtensions[] = {
      ".3ga",        ".3gp",    ".aac",     ".alac", ".asf",  ".avi",
      ".bmp",        ".csv",    ".doc",     ".docx", ".flac", ".gif",
      ".jpeg",       ".jpg",    ".log",     ".m3u",  ".m3u8", ".m4a",
      ".m4v",        ".mid",    ".mkv",     ".mov",  ".mp3",  ".mp4",
      ".mpg",        ".odf",    ".odp",     ".ods",  ".odt",  ".oga",
      ".ogg",        ".ogv",    ".pdf",     ".png",  ".ppt",  ".pptx",
      ".ra",         ".ram",    ".rar",     ".rm",   ".rtf",  ".wav",
      ".webm",       ".webp",   ".wma",     ".wmv",  ".xls",  ".xlsx",
      ".crdownload", ".crx",    ".dmg",     ".exe",  ".html", ".htm",
      ".jar",        ".ps",     ".torrent", ".txt",  ".zip",  ".mhtml",
      ".gdoc",       ".gsheet", ".gslides"};
  constexpr int kNumKnownExtensions = static_cast<int>(
      sizeof(kKnownExtensions) / sizeof(kKnownExtensions[0]));

  const size_t dot = file_name.rfind('.');
  if (dot == std::string::npos)
    return -1;

  // Compare the suffix in place instead of allocating a substring.
  for (int index = 0; index < kNumKnownExtensions; ++index) {
    if (file_name.compare(dot, std::string::npos, kKnownExtensions[index]) ==
        0) {
      return index;
    }
  }

  // Unlisted extensions map to the first index past the table, which is what
  // searching the table and taking the distance to end() produced before.
  return kNumKnownExtensions;
}
113
CategoryFromResultType(ash::AppListSearchResultType type,int subtype)114 Category CategoryFromResultType(ash::AppListSearchResultType type,
115 int subtype) {
116 if (type == ash::AppListSearchResultType::kLauncher)
117 return Category::FILE;
118
119 if (type == ash::AppListSearchResultType::kOmnibox) {
120 switch (static_cast<AutocompleteMatchType::Type>(subtype)) {
121 case AutocompleteMatchType::Type::HISTORY_URL:
122 case AutocompleteMatchType::Type::HISTORY_TITLE:
123 case AutocompleteMatchType::Type::HISTORY_BODY:
124 case AutocompleteMatchType::Type::HISTORY_KEYWORD:
125 return Category::HISTORY;
126 case AutocompleteMatchType::Type::NAVSUGGEST:
127 case AutocompleteMatchType::Type::NAVSUGGEST_PERSONALIZED:
128 return Category::NAV_SUGGEST;
129 case AutocompleteMatchType::Type::SEARCH_HISTORY:
130 case AutocompleteMatchType::Type::SEARCH_SUGGEST:
131 case AutocompleteMatchType::Type::SEARCH_SUGGEST_ENTITY:
132 case AutocompleteMatchType::Type::SEARCH_SUGGEST_TAIL:
133 case AutocompleteMatchType::Type::SEARCH_SUGGEST_PERSONALIZED:
134 case AutocompleteMatchType::Type::SEARCH_SUGGEST_PROFILE:
135 case AutocompleteMatchType::Type::SEARCH_OTHER_ENGINE:
136 return Category::SEARCH;
137 case AutocompleteMatchType::Type::BOOKMARK_TITLE:
138 return Category::BOOKMARK;
139 case AutocompleteMatchType::Type::DOCUMENT_SUGGESTION:
140 return Category::DOCUMENT;
141 case AutocompleteMatchType::Type::EXTENSION_APP_DEPRECATED:
142 case AutocompleteMatchType::Type::CONTACT_DEPRECATED:
143 case AutocompleteMatchType::Type::PHYSICAL_WEB_DEPRECATED:
144 case AutocompleteMatchType::Type::PHYSICAL_WEB_OVERFLOW_DEPRECATED:
145 case AutocompleteMatchType::Type::TAB_SEARCH_DEPRECATED:
146 return Category::OMNIBOX_DEPRECATED;
147 default:
148 return Category::OMNIBOX_GENERIC;
149 }
150 }
151
152 return Category::UNKNOWN;
153 }
154
// Records |error| as one sample of the |kInferenceError| UMA enumeration
// histogram.
void LogInferenceError(const InferenceError& error) {
  UMA_HISTOGRAM_ENUMERATION(kInferenceError, error);
}
158
// Returns the result of ukm::GetExponentialBucketMin for |sample| using the
// |kBucketExponentForSeconds| exponent. The caller in this file passes a
// duration in seconds.
int GetExponentialBucketMinForSeconds(int64_t sample) {
  return ukm::GetExponentialBucketMin(sample, kBucketExponentForSeconds);
}
LoadModelCallback(LoadModelResult result)162 void LoadModelCallback(LoadModelResult result) {
163 if (result != LoadModelResult::OK) {
164 LOG(ERROR) << "Failed to load Search Ranker model.";
165 LogInferenceError(InferenceError::kLoadModelFailed);
166 }
167 }
168
CreateGraphExecutorCallback(CreateGraphExecutorResult result)169 void CreateGraphExecutorCallback(CreateGraphExecutorResult result) {
170 if (result != CreateGraphExecutorResult::OK) {
171 LOG(ERROR) << "Failed to create a Search Ranker Graph Executor.";
172 LogInferenceError(InferenceError::kCreateGraphFailed);
173 }
174 }
175
176 // Populates |example| using |features|.
PopulateRankerExample(const SearchRankingItem::Features & features,assist_ranker::RankerExample * example)177 void PopulateRankerExample(const SearchRankingItem::Features& features,
178 assist_ranker::RankerExample* example) {
179 CHECK(example);
180
181 auto& ranker_example_features = *example->mutable_features();
182 ranker_example_features["QueryLength"].set_int32_value(
183 features.query_length());
184 ranker_example_features["RelevanceScore"].set_int32_value(
185 features.relevance_score());
186 ranker_example_features["Category"].set_int32_value(features.category());
187 ranker_example_features["HourOfDay"].set_int32_value(features.hour_of_day());
188 ranker_example_features["DayOfWeek"].set_int32_value(features.day_of_week());
189 ranker_example_features["LaunchesThisSession"].set_int32_value(
190 features.launches_this_session());
191 if (features.has_file_extension()) {
192 ranker_example_features["FileExtension"].set_int32_value(
193 features.file_extension());
194 }
195 if (features.has_time_since_last_launch()) {
196 ranker_example_features["TimeSinceLastLaunch"].set_int32_value(
197 features.time_since_last_launch());
198 ranker_example_features["TimeOfLastLaunch"].set_int32_value(
199 features.time_of_last_launch());
200 }
201 const auto& launches = features.launches_at_hour();
202 for (int hour = 0; hour < launches.size(); hour++) {
203 ranker_example_features["LaunchesAtHour" + base::StringPrintf("%02d", hour)]
204 .set_int32_value(launches[hour]);
205 }
206 if (features.has_domain()) {
207 ranker_example_features["Domain"].set_string_value(features.domain());
208 ranker_example_features["HasDomain"].set_int32_value(1);
209 }
210 }
211
212 // Loads the preprocessor config protobuf, which will be used later to convert
213 // a RankerExample to a vectorized float for inactivity score calculation.
214 // Returns nullptr if cannot load or parse the config.
215 std::unique_ptr<assist_ranker::ExamplePreprocessorConfig>
LoadExamplePreprocessorConfig()216 LoadExamplePreprocessorConfig() {
217 auto config = std::make_unique<assist_ranker::ExamplePreprocessorConfig>();
218
219 const int res_id = IDR_SEARCH_RANKER_20190923_EXAMPLE_PREPROCESSOR_CONFIG_PB;
220
221 scoped_refptr<base::RefCountedMemory> raw_config =
222 ui::ResourceBundle::GetSharedInstance().LoadDataResourceBytes(res_id);
223 if (!raw_config || !raw_config->front()) {
224 LOG(ERROR) << "Failed to load SearchRanker example preprocessor config.";
225 LogInferenceError(InferenceError::kLoadExamplePreprocessorConfigFailed);
226 return nullptr;
227 }
228
229 if (!config->ParseFromArray(raw_config->front(), raw_config->size())) {
230 LOG(ERROR) << "Failed to parse SearchRanker example preprocessor config.";
231 LogInferenceError(InferenceError::kLoadExamplePreprocessorConfigFailed);
232 return nullptr;
233 }
234
235 return config;
236 }
237 } // namespace
238
// Stores |search_controller| (DCHECKed to be non-null), caches the UKM
// recorder returned by ukm::UkmRecorder::Get() and looks up the UKM background
// recorder for |profile|.
SearchRankingEventLogger::SearchRankingEventLogger(
    Profile* profile,
    SearchController* search_controller)
    : search_controller_(search_controller),
      ukm_recorder_(ukm::UkmRecorder::Get()),
      ukm_background_recorder_(
          ukm::UkmBackgroundRecorderFactory::GetForProfile(profile)),
      weak_factory_(this) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(search_controller_);
}
250
// The destructor and the ResultState special members are defaulted here, out
// of line, rather than in the header.
SearchRankingEventLogger::~SearchRankingEventLogger() = default;

SearchRankingEventLogger::ResultState::ResultState() = default;
SearchRankingEventLogger::ResultState::~ResultState() = default;
255
// Stores |closure|, replacing any previously stored one. LogEvent() runs the
// stored closure after recording an event; running it consumes it.
void SearchRankingEventLogger::SetEventRecordedForTesting(
    base::OnceClosure closure) {
  event_recorded_for_testing_ = std::move(closure);
}
260
PopulateSearchRankingItem(SearchRankingItem * proto,ChromeSearchResult * search_result,int query_length,bool use_for_logging)261 void SearchRankingEventLogger::PopulateSearchRankingItem(
262 SearchRankingItem* proto,
263 ChromeSearchResult* search_result,
264 int query_length,
265 bool use_for_logging) {
266 const base::Time now = base::Time::Now();
267 base::Time::Exploded now_exploded;
268 now.LocalExplode(&now_exploded);
269
270 auto& features = *proto->mutable_features();
271 features.set_category(static_cast<int>(CategoryFromResultType(
272 search_result->result_type(), search_result->result_subtype())));
273
274 // Note this is the search provider's original relevance score, not
275 // tweaked by any search ranking models. Scores are floats in 0 to 1, and
276 // we map this to ints 0 to 100.
277 features.set_relevance_score(
278 static_cast<int>(100 * search_result->relevance()));
279 features.set_hour_of_day(now_exploded.hour);
280 features.set_day_of_week(now_exploded.day_of_week);
281 features.set_query_length(GetExponentialBucketMinForCounts1000(query_length));
282
283 if (features.category() == static_cast<int>(Category::FILE)) {
284 features.set_file_extension(ExtensionTypeFromFileName(search_result->id()));
285 }
286
287 if (search_result->result_type() == ash::AppListSearchResultType::kOmnibox) {
288 // The id metadata of an OmniboxResult is a stripped URL, which does not
289 // correspond to the URL that will be navigated to.
290 proto->set_target(
291 static_cast<OmniboxResult*>(search_result)->DestinationURL().spec());
292 } else {
293 proto->set_target(search_result->id());
294 }
295
296 const std::string& domain = GURL(search_result->id()).host();
297 if (!domain.empty()) {
298 features.set_domain(domain);
299 }
300
301 // If the proto is created for logging purposes, create a new item in the map.
302 // Otherwise lookup the map for event info and create a "dummy" event info if
303 // doesn't nothing found.
304 ResultState* event_info;
305 ResultState dummy_event_info;
306 if (use_for_logging) {
307 event_info = &id_to_result_state_[proto->target()];
308 } else {
309 const auto& it = id_to_result_state_.find(proto->target());
310 if (it != id_to_result_state_.end()) {
311 event_info = &it->second;
312 } else {
313 event_info = &dummy_event_info;
314 }
315 }
316
317 if (event_info->last_launch.has_value()) {
318 base::Time last_launch = event_info->last_launch.value();
319 base::Time::Exploded last_launch_exploded;
320 last_launch.LocalExplode(&last_launch_exploded);
321
322 features.set_time_since_last_launch(
323 GetExponentialBucketMinForSeconds((now - last_launch).InSeconds()));
324 features.set_time_of_last_launch(last_launch_exploded.hour);
325
326 // Reset the number of launches this hour to 0 if this is the first
327 // launch today of this event, to account for user sessions spanning
328 // multiple days.
329 if (features.has_is_launched() && features.is_launched() == 1 &&
330 now - event_info->last_launch.value() >=
331 base::TimeDelta::FromHours(23)) {
332 event_info->launches_per_hour[now_exploded.hour] = 0;
333 }
334 }
335
336 features.set_launches_this_session(
337 GetExponentialBucketMinForCounts1000(event_info->launches_this_session));
338
339 const auto& launches = event_info->launches_per_hour;
340 features.add_launches_at_hour(
341 GetExponentialBucketMinForCounts1000(launches[0]));
342 features.add_launches_at_hour(
343 GetExponentialBucketMinForCounts1000(launches[1]));
344 features.add_launches_at_hour(
345 GetExponentialBucketMinForCounts1000(launches[2]));
346 features.add_launches_at_hour(
347 GetExponentialBucketMinForCounts1000(launches[3]));
348 features.add_launches_at_hour(
349 GetExponentialBucketMinForCounts1000(launches[4]));
350 features.add_launches_at_hour(
351 GetExponentialBucketMinForCounts1000(launches[5]));
352 features.add_launches_at_hour(
353 GetExponentialBucketMinForCounts1000(launches[6]));
354 features.add_launches_at_hour(
355 GetExponentialBucketMinForCounts1000(launches[7]));
356 features.add_launches_at_hour(
357 GetExponentialBucketMinForCounts1000(launches[8]));
358 features.add_launches_at_hour(
359 GetExponentialBucketMinForCounts1000(launches[9]));
360 features.add_launches_at_hour(
361 GetExponentialBucketMinForCounts1000(launches[10]));
362 features.add_launches_at_hour(
363 GetExponentialBucketMinForCounts1000(launches[11]));
364 features.add_launches_at_hour(
365 GetExponentialBucketMinForCounts1000(launches[12]));
366 features.add_launches_at_hour(
367 GetExponentialBucketMinForCounts1000(launches[13]));
368 features.add_launches_at_hour(
369 GetExponentialBucketMinForCounts1000(launches[14]));
370 features.add_launches_at_hour(
371 GetExponentialBucketMinForCounts1000(launches[15]));
372 features.add_launches_at_hour(
373 GetExponentialBucketMinForCounts1000(launches[16]));
374 features.add_launches_at_hour(
375 GetExponentialBucketMinForCounts1000(launches[17]));
376 features.add_launches_at_hour(
377 GetExponentialBucketMinForCounts1000(launches[18]));
378 features.add_launches_at_hour(
379 GetExponentialBucketMinForCounts1000(launches[19]));
380 features.add_launches_at_hour(
381 GetExponentialBucketMinForCounts1000(launches[20]));
382 features.add_launches_at_hour(
383 GetExponentialBucketMinForCounts1000(launches[21]));
384 features.add_launches_at_hour(
385 GetExponentialBucketMinForCounts1000(launches[22]));
386 features.add_launches_at_hour(
387 GetExponentialBucketMinForCounts1000(launches[23]));
388
389 if (features.has_is_launched() && features.is_launched() == 1) {
390 event_info->last_launch = now;
391 event_info->launches_this_session += 1;
392 event_info->launches_per_hour[now_exploded.hour] += 1;
393 }
394 }
395
Log(const base::string16 & trimmed_query,const ash::SearchResultIdWithPositionIndices & search_results,int launched_index)396 void SearchRankingEventLogger::Log(
397 const base::string16& trimmed_query,
398 const ash::SearchResultIdWithPositionIndices& search_results,
399 int launched_index) {
400 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
401 for (const auto& id_index : search_results) {
402 auto* result = search_controller_->FindSearchResult(id_index.id);
403 if (!result)
404 continue;
405
406 SearchRankingItem proto;
407 proto.mutable_features()->set_position(id_index.position_index);
408 proto.set_event_id(next_event_id_);
409 proto.mutable_features()->set_is_launched(
410 id_index.position_index == launched_index ? 1 : 0);
411 PopulateSearchRankingItem(&proto, result, trimmed_query.size(),
412 true /*use_for_logging*/);
413
414 // Omnibox results have associated URLs, so are logged keyed on the URL
415 // after validating that it exists in the history service. Other results
416 // have no associated URL, so use a blank source ID.
417 if (result->result_type() == ash::AppListSearchResultType::kOmnibox) {
418 // When an omnibox result is launched, we need to retrieve a source ID
419 // using the history service. This may be the first time the URL is used
420 // and so it must be committed to the history service database before we
421 // retrieve it, which happens once the page has loaded. So we delay our
422 // check for long enough that most pages will have loaded.
423 if (launched_index == id_index.position_index) {
424 base::SequencedTaskRunnerHandle::Get()->PostDelayedTask(
425 FROM_HERE,
426 base::BindOnce(
427 &SearchRankingEventLogger::GetBackgroundSourceIdAndLogEvent,
428 weak_factory_.GetWeakPtr(), proto),
429 kDelayForHistoryService);
430 } else {
431 GetBackgroundSourceIdAndLogEvent(proto);
432 }
433 } else {
434 LogEvent(proto, base::nullopt);
435 }
436 }
437
438 ++next_event_id_;
439 }
440
GetBackgroundSourceIdAndLogEvent(const SearchRankingItem & result)441 void SearchRankingEventLogger::GetBackgroundSourceIdAndLogEvent(
442 const SearchRankingItem& result) {
443 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
444 ukm_background_recorder_->GetBackgroundSourceIdIfAllowed(
445 url::Origin::Create(GURL(result.target())),
446 base::BindOnce(&SearchRankingEventLogger::LogEvent,
447 base::Unretained(this), result));
448 }
449
LogEvent(const SearchRankingItem & result,base::Optional<ukm::SourceId> source_id)450 void SearchRankingEventLogger::LogEvent(
451 const SearchRankingItem& result,
452 base::Optional<ukm::SourceId> source_id) {
453 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
454 if (!source_id)
455 source_id = ukm_recorder_->GetNewSourceID();
456
457 ukm::builders::AppListNonAppImpression event(source_id.value());
458 event.SetEventId(result.event_id())
459 .SetPosition(result.features().position())
460 .SetIsLaunched(result.features().is_launched())
461 .SetQueryLength(result.features().query_length())
462 .SetRelevanceScore(result.features().relevance_score())
463 .SetCategory(result.features().category())
464 .SetHourOfDay(result.features().hour_of_day())
465 .SetDayOfWeek(result.features().day_of_week())
466 .SetLaunchesThisSession(result.features().launches_this_session());
467
468 if (result.features().has_file_extension()) {
469 event.SetFileExtension(result.features().file_extension());
470 }
471
472 if (result.features().has_time_since_last_launch()) {
473 event.SetTimeSinceLastLaunch(result.features().time_since_last_launch());
474 event.SetTimeOfLastLaunch(result.features().time_of_last_launch());
475 }
476
477 const auto& launches = result.features().launches_at_hour();
478 event.SetLaunchesAtHour00(launches[0]);
479 event.SetLaunchesAtHour01(launches[1]);
480 event.SetLaunchesAtHour02(launches[2]);
481 event.SetLaunchesAtHour03(launches[3]);
482 event.SetLaunchesAtHour04(launches[4]);
483 event.SetLaunchesAtHour05(launches[5]);
484 event.SetLaunchesAtHour06(launches[6]);
485 event.SetLaunchesAtHour07(launches[7]);
486 event.SetLaunchesAtHour08(launches[8]);
487 event.SetLaunchesAtHour09(launches[9]);
488 event.SetLaunchesAtHour10(launches[10]);
489 event.SetLaunchesAtHour11(launches[11]);
490 event.SetLaunchesAtHour12(launches[12]);
491 event.SetLaunchesAtHour13(launches[13]);
492 event.SetLaunchesAtHour14(launches[14]);
493 event.SetLaunchesAtHour15(launches[15]);
494 event.SetLaunchesAtHour16(launches[16]);
495 event.SetLaunchesAtHour17(launches[17]);
496 event.SetLaunchesAtHour18(launches[18]);
497 event.SetLaunchesAtHour19(launches[19]);
498 event.SetLaunchesAtHour20(launches[20]);
499 event.SetLaunchesAtHour21(launches[21]);
500 event.SetLaunchesAtHour22(launches[22]);
501 event.SetLaunchesAtHour23(launches[23]);
502
503 event.Record(ukm_recorder_);
504
505 if (event_recorded_for_testing_)
506 std::move(event_recorded_for_testing_).Run();
507 }
508
CreateRankings(Mixer::SortedResults * results,int query_length)509 void SearchRankingEventLogger::CreateRankings(Mixer::SortedResults* results,
510 int query_length) {
511 for (const auto& result : *results) {
512 if (!result.result) {
513 continue;
514 }
515 SearchRankingItem proto;
516 std::vector<float> vectorized_features;
517
518 PopulateSearchRankingItem(&proto, result.result, query_length,
519 false /*use_for_logging*/);
520 if (!PreprocessInput(proto.features(), &vectorized_features)) {
521 return;
522 }
523 DoInference(vectorized_features, result.result->id());
524 }
525 }
526
// Returns a copy of |prediction_|, the map that ExecuteCallback() fills with
// one score per id.
std::map<std::string, float> SearchRankingEventLogger::RetrieveRankings() {
  return prediction_;
}
530
LazyInitialize()531 void SearchRankingEventLogger::LazyInitialize() {
532 if (!preprocessor_config_) {
533 preprocessor_config_ = LoadExamplePreprocessorConfig();
534 }
535 }
536
PreprocessInput(const SearchRankingItem::Features & features,std::vector<float> * vectorized_features)537 bool SearchRankingEventLogger::PreprocessInput(
538 const SearchRankingItem::Features& features,
539 std::vector<float>* vectorized_features) {
540 DCHECK(vectorized_features);
541 LazyInitialize();
542
543 if (!preprocessor_config_) {
544 LOG(ERROR) << "Failed to create preprocessor config.";
545 LogInferenceError(InferenceError::kLoadExamplePreprocessorConfigFailed);
546 return false;
547 }
548
549 assist_ranker::RankerExample ranker_example;
550 PopulateRankerExample(features, &ranker_example);
551
552 int preprocessor_error = assist_ranker::ExamplePreprocessor::Process(
553 *preprocessor_config_, &ranker_example, true);
554 // kNoFeatureIndexFound can occur normally (e.g., when the domain name
555 // isn't known to the model or a rarely seen enum value is used).
556 if (preprocessor_error != assist_ranker::ExamplePreprocessor::kSuccess &&
557 preprocessor_error !=
558 assist_ranker::ExamplePreprocessor::kNoFeatureIndexFound) {
559 LOG(ERROR) << "Failed to vectorize features using ExamplePreprocessor.";
560 LogInferenceError(InferenceError::kVectorizeFeaturesFailed);
561 return false;
562 }
563
564 const auto& extracted_features =
565 ranker_example.features()
566 .at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
567 .float_list()
568 .float_value();
569 vectorized_features->assign(extracted_features.begin(),
570 extracted_features.end());
571 return true;
572 }
573
DoInference(const std::vector<float> & features,const std::string & id)574 void SearchRankingEventLogger::DoInference(const std::vector<float>& features,
575 const std::string& id) {
576 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
577 BindGraphExecutorIfNeeded();
578
579 // Prepare the input tensor.
580 base::flat_map<std::string, TensorPtr> inputs;
581 auto tensor = Tensor::New();
582 tensor->shape = Int64List::New();
583 tensor->shape->value = std::vector<int64_t>({1, features.size()});
584 tensor->data = ValueList::New();
585 tensor->data->set_float_list(FloatList::New());
586 tensor->data->get_float_list()->value =
587 std::vector<double>(std::begin(features), std::end(features));
588 inputs.emplace(std::string("input"), std::move(tensor));
589
590 const std::vector<std::string> outputs({std::string("output")});
591 // Execute
592 executor_->Execute(std::move(inputs), std::move(outputs),
593 base::BindOnce(&SearchRankingEventLogger::ExecuteCallback,
594 weak_factory_.GetWeakPtr(), id));
595 }
596
BindGraphExecutorIfNeeded()597 void SearchRankingEventLogger::BindGraphExecutorIfNeeded() {
598 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
599 if (!model_) {
600 // Load the model.
601 auto spec = BuiltinModelSpec::New(BuiltinModelId::SEARCH_RANKER_20190923);
602 chromeos::machine_learning::ServiceConnection::GetInstance()
603 ->LoadBuiltinModel(std::move(spec), model_.BindNewPipeAndPassReceiver(),
604 base::BindOnce(&LoadModelCallback));
605 }
606
607 if (!executor_) {
608 // Get the graph executor.
609 model_->CreateGraphExecutor(executor_.BindNewPipeAndPassReceiver(),
610 base::BindOnce(&CreateGraphExecutorCallback));
611 executor_.set_disconnect_handler(base::BindOnce(
612 &SearchRankingEventLogger::OnConnectionError, base::Unretained(this)));
613 }
614 }
615
// Disconnect handler installed on |executor_|. Logs a warning and resets both
// |executor_| and |model_|, so the next BindGraphExecutorIfNeeded() call binds
// them again.
void SearchRankingEventLogger::OnConnectionError() {
  LOG(WARNING) << "Mojo connection for ML service closed.";
  executor_.reset();
  model_.reset();
}
621
ExecuteCallback(const std::string & id,ExecuteResult result,const base::Optional<std::vector<TensorPtr>> outputs)622 void SearchRankingEventLogger::ExecuteCallback(
623 const std::string& id,
624 ExecuteResult result,
625 const base::Optional<std::vector<TensorPtr>> outputs) {
626 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
627 if (result != ExecuteResult::OK) {
628 LOG(ERROR) << "Search Ranker inference execution failed.";
629 LogInferenceError(InferenceError::kInferenceExecutionFailed);
630 return;
631 }
632 prediction_[id] = outputs.value()[0]->data->get_float_list()->value[0];
633 }
634
635 } // namespace app_list
636