1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.h"
6
7 #include <cmath>
8 #include <utility>
9
10 #include "base/time/time.h"
11 #include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.pb.h"
12 #include "chrome/browser/ui/app_list/search/search_result_ranker/histogram_util.h"
13 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
14 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_util.h"
15
16 namespace app_list {
namespace {

// Number of hours in a day. HourBinPredictor adds this to the local hour so
// that weekend hours land in bins distinct from weekday hours.
constexpr int kHoursADay = 24;

}  // namespace
22
RecurrencePredictor(const std::string & model_identifier)23 RecurrencePredictor::RecurrencePredictor(const std::string& model_identifier)
24 : model_identifier_(model_identifier) {}
25
FakePredictor(const std::string & model_identifier)26 FakePredictor::FakePredictor(const std::string& model_identifier)
27 : RecurrencePredictor(model_identifier) {
28 // The fake predictor should only be used for testing, not in production.
29 // Record an error so we know if it is being used.
30 LogInitializationStatus(model_identifier_,
31 InitializationStatus::kFakePredictorUsed);
32 }
33
FakePredictor(const FakePredictorConfig & config,const std::string & model_identifier)34 FakePredictor::FakePredictor(const FakePredictorConfig& config,
35 const std::string& model_identifier)
36 : RecurrencePredictor(model_identifier) {
37 // The fake predictor should only be used for testing, not in production.
38 // Record an error so we know if it is being used.
39 LogInitializationStatus(model_identifier_,
40 InitializationStatus::kFakePredictorUsed);
41 }
42
43 FakePredictor::~FakePredictor() = default;
44
45 const char FakePredictor::kPredictorName[] = "FakePredictor";
GetPredictorName() const46 const char* FakePredictor::GetPredictorName() const {
47 return kPredictorName;
48 }
49
Train(unsigned int target,unsigned int condition)50 void FakePredictor::Train(unsigned int target, unsigned int condition) {
51 counts_[target] += 1.0f;
52 }
53
Rank(unsigned int condition)54 std::map<unsigned int, float> FakePredictor::Rank(unsigned int condition) {
55 return counts_;
56 }
57
Cleanup(const std::vector<unsigned int> & valid_targets)58 void FakePredictor::Cleanup(const std::vector<unsigned int>& valid_targets) {
59 std::map<unsigned int, float> new_counts;
60
61 for (unsigned int id : valid_targets) {
62 const auto& it = counts_.find(id);
63 if (it != counts_.end())
64 new_counts[id] = it->second;
65 }
66
67 counts_.swap(new_counts);
68 }
69
ToProto(RecurrencePredictorProto * proto) const70 void FakePredictor::ToProto(RecurrencePredictorProto* proto) const {
71 auto* counts = proto->mutable_fake_predictor()->mutable_counts();
72 for (auto& pair : counts_)
73 (*counts)[pair.first] = pair.second;
74 }
75
FromProto(const RecurrencePredictorProto & proto)76 void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
77 if (!proto.has_fake_predictor()) {
78 LogSerializationStatus(model_identifier_,
79 SerializationStatus::kFakePredictorLoadingError);
80 return;
81 }
82
83 for (const auto& pair : proto.fake_predictor().counts())
84 counts_[pair.first] = pair.second;
85 }
86
DefaultPredictor(const DefaultPredictorConfig & config,const std::string & model_identifier)87 DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config,
88 const std::string& model_identifier)
89 : RecurrencePredictor(model_identifier) {}
~DefaultPredictor()90 DefaultPredictor::~DefaultPredictor() {}
91
Train(unsigned int target,unsigned int condition)92 void DefaultPredictor::Train(unsigned int target, unsigned int condition) {}
93
Rank(unsigned int condition)94 std::map<unsigned int, float> DefaultPredictor::Rank(unsigned int condition) {
95 NOTREACHED();
96 return {};
97 }
98
99 const char DefaultPredictor::kPredictorName[] = "DefaultPredictor";
GetPredictorName() const100 const char* DefaultPredictor::GetPredictorName() const {
101 return kPredictorName;
102 }
103
ToProto(RecurrencePredictorProto * proto) const104 void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {}
105
FromProto(const RecurrencePredictorProto & proto)106 void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {}
107
ConditionalFrequencyPredictor(const std::string & model_identifier)108 ConditionalFrequencyPredictor::ConditionalFrequencyPredictor(
109 const std::string& model_identifier)
110 : RecurrencePredictor(model_identifier) {}
ConditionalFrequencyPredictor(const ConditionalFrequencyPredictorConfig & config,const std::string & model_identifier)111 ConditionalFrequencyPredictor::ConditionalFrequencyPredictor(
112 const ConditionalFrequencyPredictorConfig& config,
113 const std::string& model_identifier)
114 : RecurrencePredictor(model_identifier) {}
115 ConditionalFrequencyPredictor::~ConditionalFrequencyPredictor() = default;
116
117 ConditionalFrequencyPredictor::Events::Events() = default;
118 ConditionalFrequencyPredictor::Events::~Events() = default;
119 ConditionalFrequencyPredictor::Events::Events(const Events& other) = default;
120
121 const char ConditionalFrequencyPredictor::kPredictorName[] =
122 "ConditionalFrequencyPredictor";
GetPredictorName() const123 const char* ConditionalFrequencyPredictor::GetPredictorName() const {
124 return kPredictorName;
125 }
126
Train(unsigned int target,unsigned int condition)127 void ConditionalFrequencyPredictor::Train(unsigned int target,
128 unsigned int condition) {
129 TrainWithDelta(target, condition, 1.0f);
130 }
131
TrainWithDelta(unsigned int target,unsigned int condition,float delta)132 void ConditionalFrequencyPredictor::TrainWithDelta(unsigned int target,
133 unsigned int condition,
134 float delta) {
135 DCHECK_NE(delta, 0.0f);
136 auto& events = table_[condition];
137 events.freqs[target] += delta;
138 events.total += delta;
139 }
140
Rank(unsigned int condition)141 std::map<unsigned int, float> ConditionalFrequencyPredictor::Rank(
142 unsigned int condition) {
143 const auto& it = table_.find(condition);
144 // If the total frequency is zero, we can't return any meaningful results, so
145 // return empty.
146 if (it == table_.end() || it->second.total == 0.0f)
147 return {};
148
149 std::map<unsigned int, float> result;
150 const auto& events = it->second;
151 for (const auto& target_freq : events.freqs)
152 result[target_freq.first] = target_freq.second / events.total;
153 return result;
154 }
155
Cleanup(const std::vector<unsigned int> & valid_targets)156 void ConditionalFrequencyPredictor::Cleanup(
157 const std::vector<unsigned int>& valid_targets) {
158 for (auto iter = table_.begin(); iter != table_.end();) {
159 auto& events = iter->second;
160
161 std::map<unsigned int, float> new_freqs;
162 float new_total = 0.0f;
163 for (unsigned int id : valid_targets) {
164 const auto& it = events.freqs.find(id);
165 if (it != events.freqs.end()) {
166 new_freqs[id] = it->second;
167 new_total += it->second;
168 }
169 }
170
171 // Delete the whole condition out of the table if it contains no valid
172 // targets.
173 if (new_freqs.empty()) {
174 // C++11: the return value of erase(iter) is an iterator pointing to the
175 // next element in the container.
176 iter = table_.erase(iter);
177 } else {
178 ++iter;
179 events.freqs.swap(new_freqs);
180 events.total = new_total;
181 }
182 }
183 }
184
CleanupConditions(const std::vector<unsigned int> & valid_conditions)185 void ConditionalFrequencyPredictor::CleanupConditions(
186 const std::vector<unsigned int>& valid_conditions) {
187 std::map<unsigned int, ConditionalFrequencyPredictor::Events> new_table;
188
189 for (unsigned int id : valid_conditions) {
190 const auto& it = table_.find(id);
191 if (it != table_.end()) {
192 new_table[id] = std::move(it->second);
193 }
194 }
195
196 table_.swap(new_table);
197 }
198
ToProto(RecurrencePredictorProto * proto) const199 void ConditionalFrequencyPredictor::ToProto(
200 RecurrencePredictorProto* proto) const {
201 auto* predictor = proto->mutable_conditional_frequency_predictor();
202 for (const auto& condition_events : table_) {
203 for (const auto& event_freq : condition_events.second.freqs) {
204 auto* event = predictor->add_events();
205 event->set_condition(condition_events.first);
206 event->set_event(event_freq.first);
207 event->set_freq(event_freq.second);
208 }
209 }
210 }
211
FromProto(const RecurrencePredictorProto & proto)212 void ConditionalFrequencyPredictor::FromProto(
213 const RecurrencePredictorProto& proto) {
214 if (!proto.has_conditional_frequency_predictor()) {
215 LogSerializationStatus(
216 model_identifier_,
217 SerializationStatus::kConditionalFrequencyPredictorLoadingError);
218 return;
219 }
220
221 for (const auto& event : proto.conditional_frequency_predictor().events()) {
222 auto& events = table_[event.condition()];
223 events.freqs[event.event()] = event.freq();
224 events.total += event.freq();
225 }
226 }
227
FrecencyPredictor(const FrecencyPredictorConfig & config,const std::string & model_identifier)228 FrecencyPredictor::FrecencyPredictor(const FrecencyPredictorConfig& config,
229 const std::string& model_identifier)
230 : RecurrencePredictor(model_identifier),
231 decay_coeff_(config.decay_coeff()) {}
232 FrecencyPredictor::~FrecencyPredictor() = default;
233
234 const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
GetPredictorName() const235 const char* FrecencyPredictor::GetPredictorName() const {
236 return kPredictorName;
237 }
238
Train(unsigned int target,unsigned int condition)239 void FrecencyPredictor::Train(unsigned int target, unsigned int condition) {
240 ++num_updates_;
241 TargetData& data = targets_[target];
242 DecayScore(&data);
243 data.last_score += 1.0f - decay_coeff_;
244 }
245
Rank(unsigned int condition)246 std::map<unsigned int, float> FrecencyPredictor::Rank(unsigned int condition) {
247 float total = 0.0f;
248 for (auto& pair : targets_) {
249 DecayScore(&pair.second);
250 total += pair.second.last_score;
251 }
252 if (total == 0.0f)
253 return {};
254
255 std::map<unsigned int, float> result;
256 for (auto& pair : targets_) {
257 result[pair.first] = pair.second.last_score / total;
258 }
259 return result;
260 }
261
Cleanup(const std::vector<unsigned int> & valid_targets)262 void FrecencyPredictor::Cleanup(
263 const std::vector<unsigned int>& valid_targets) {
264 std::map<unsigned int, FrecencyPredictor::TargetData> new_targets;
265
266 for (unsigned int id : valid_targets) {
267 const auto& it = targets_.find(id);
268 if (it != targets_.end())
269 new_targets[id] = it->second;
270 }
271
272 targets_.swap(new_targets);
273 }
274
ToProto(RecurrencePredictorProto * proto) const275 void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
276 auto* predictor = proto->mutable_frecency_predictor();
277
278 predictor->set_num_updates(num_updates_);
279
280 for (const auto& pair : targets_) {
281 auto* target_data = predictor->add_targets();
282 target_data->set_id(pair.first);
283 target_data->set_last_score(pair.second.last_score);
284 target_data->set_last_num_updates(pair.second.last_num_updates);
285 }
286 }
287
FromProto(const RecurrencePredictorProto & proto)288 void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
289 if (!proto.has_frecency_predictor()) {
290 LogSerializationStatus(model_identifier_,
291 SerializationStatus::kFrecencyPredictorLoadingError);
292 return;
293 }
294 const auto& predictor = proto.frecency_predictor();
295
296 num_updates_ = predictor.num_updates();
297
298 std::map<unsigned int, TargetData> targets;
299 for (const auto& target_data : predictor.targets()) {
300 targets[target_data.id()] = {target_data.last_score(),
301 target_data.last_num_updates()};
302 }
303 targets_.swap(targets);
304 }
305
DecayScore(TargetData * data)306 void FrecencyPredictor::DecayScore(TargetData* data) {
307 int time_since_update = num_updates_ - data->last_num_updates;
308
309 if (time_since_update > 0) {
310 data->last_score *= std::pow(decay_coeff_, time_since_update);
311 data->last_num_updates = num_updates_;
312 }
313 }
314
HourBinPredictor(const HourBinPredictorConfig & config,const std::string & model_identifier)315 HourBinPredictor::HourBinPredictor(const HourBinPredictorConfig& config,
316 const std::string& model_identifier)
317 : RecurrencePredictor(model_identifier),
318 weekly_decay_coeff_(config.weekly_decay_coeff()) {
319 for (const auto& pair : config.bin_weights())
320 bin_weights_[pair.bin()] = pair.weight();
321
322 if (!proto_.has_last_decay_timestamp())
323 SetLastDecayTimestamp(
324 base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
325 }
326
327 HourBinPredictor::~HourBinPredictor() = default;
328
329 const char HourBinPredictor::kPredictorName[] = "HourBinPredictor";
330
GetPredictorName() const331 const char* HourBinPredictor::GetPredictorName() const {
332 return kPredictorName;
333 }
334
GetBinFromHourDifference(int hour_difference) const335 int HourBinPredictor::GetBinFromHourDifference(int hour_difference) const {
336 base::Time shifted_time =
337 base::Time::Now() + base::TimeDelta::FromHours(hour_difference);
338 base::Time::Exploded exploded_time;
339 shifted_time.LocalExplode(&exploded_time);
340
341 const bool is_weekend =
342 exploded_time.day_of_week == 6 || exploded_time.day_of_week == 0;
343
344 // To distinguish workdays from weekend, use now.hour for workdays and
345 // now.hour + 24 for weekend.
346 if (!is_weekend) {
347 return exploded_time.hour;
348 } else {
349 return exploded_time.hour + kHoursADay;
350 }
351 }
352
GetBin() const353 int HourBinPredictor::GetBin() const {
354 return GetBinFromHourDifference(0);
355 }
356
Train(unsigned int target,unsigned int condition)357 void HourBinPredictor::Train(unsigned int target, unsigned int condition) {
358 int hour = GetBin();
359 auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour];
360 frequency_table.set_total_counts(frequency_table.total_counts() + 1);
361 (*frequency_table.mutable_frequency())[target] += 1;
362 }
363
Rank(unsigned int condition)364 std::map<unsigned int, float> HourBinPredictor::Rank(unsigned int condition) {
365 std::map<unsigned int, float> ranks;
366 const auto& frequency_table_map = proto_.binned_frequency_table();
367 for (const auto& hour_and_weight : bin_weights_) {
368 // Find adjacent bin and weight.
369 const int adj_bin = GetBinFromHourDifference(hour_and_weight.first);
370 const float weight = hour_and_weight.second;
371
372 const auto find_frequency_table = frequency_table_map.find(adj_bin);
373 if (find_frequency_table == frequency_table_map.end())
374 continue;
375 const auto& frequency_table = find_frequency_table->second;
376
377 // Accumulates the frequency to the output.
378 if (frequency_table.total_counts() > 0) {
379 const int total_counts = frequency_table.total_counts();
380 for (const auto& pair : frequency_table.frequency()) {
381 ranks[pair.first] +=
382 static_cast<float>(pair.second) / total_counts * weight;
383 }
384 }
385 }
386 return ranks;
387 }
388
389 // TODO(921444): Unify the hour bin predictor with the cleanup system used for
390 // other predictors. This is different than other predictors so as to be exactly
391 // the same as the Roselle predictor.
392
ToProto(RecurrencePredictorProto * proto) const393 void HourBinPredictor::ToProto(RecurrencePredictorProto* proto) const {
394 *proto->mutable_hour_bin_predictor() = proto_;
395 }
396
FromProto(const RecurrencePredictorProto & proto)397 void HourBinPredictor::FromProto(const RecurrencePredictorProto& proto) {
398 if (!proto.has_hour_bin_predictor()) {
399 LogSerializationStatus(model_identifier_,
400 SerializationStatus::kHourBinPredictorLoadingError);
401 return;
402 }
403
404 proto_ = proto.hour_bin_predictor();
405 if (ShouldDecay())
406 DecayAll();
407 }
408
ShouldDecay()409 bool HourBinPredictor::ShouldDecay() {
410 const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays();
411 // Check if we should decay the frequency
412 return today - proto_.last_decay_timestamp() > 7;
413 }
414
DecayAll()415 void HourBinPredictor::DecayAll() {
416 SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
417 auto& frequency_table_map = *proto_.mutable_binned_frequency_table();
418 for (auto it_table = frequency_table_map.begin();
419 it_table != frequency_table_map.end();) {
420 auto& frequency_table = *it_table->second.mutable_frequency();
421 for (auto it_freq = frequency_table.begin();
422 it_freq != frequency_table.end();) {
423 const int new_frequency = it_freq->second * weekly_decay_coeff_;
424 it_table->second.set_total_counts(it_table->second.total_counts() -
425 it_freq->second + new_frequency);
426 it_freq->second = new_frequency;
427
428 // Remove item that has zero frequency
429 if (it_freq->second == 0) {
430 frequency_table.erase(it_freq++);
431 } else {
432 it_freq++;
433 }
434 }
435
436 // Remove bin that has zero total_counts
437 if (it_table->second.total_counts() == 0) {
438 frequency_table_map.erase(it_table++);
439 } else {
440 it_table++;
441 }
442 }
443 }
444
MarkovPredictor(const MarkovPredictorConfig & config,const std::string & model_identifier)445 MarkovPredictor::MarkovPredictor(const MarkovPredictorConfig& config,
446 const std::string& model_identifier)
447 : RecurrencePredictor(model_identifier) {
448 frequencies_ =
449 std::make_unique<ConditionalFrequencyPredictor>(model_identifier);
450 }
451 MarkovPredictor::~MarkovPredictor() = default;
452
453 const char MarkovPredictor::kPredictorName[] = "MarkovPredictor";
GetPredictorName() const454 const char* MarkovPredictor::GetPredictorName() const {
455 return kPredictorName;
456 }
457
Train(unsigned int target,unsigned int condition)458 void MarkovPredictor::Train(unsigned int target, unsigned int condition) {
459 if (previous_target_)
460 frequencies_->Train(target, previous_target_.value());
461 previous_target_ = target;
462 }
463
Rank(unsigned int condition)464 std::map<unsigned int, float> MarkovPredictor::Rank(unsigned int condition) {
465 if (previous_target_)
466 return frequencies_->Rank(previous_target_.value());
467 return std::map<unsigned int, float>();
468 }
469
Cleanup(const std::vector<unsigned int> & valid_targets)470 void MarkovPredictor::Cleanup(const std::vector<unsigned int>& valid_targets) {
471 frequencies_->CleanupConditions(valid_targets);
472 frequencies_->Cleanup(valid_targets);
473 }
474
ToProto(RecurrencePredictorProto * proto) const475 void MarkovPredictor::ToProto(RecurrencePredictorProto* proto) const {
476 auto* predictor = proto->mutable_markov_predictor();
477 frequencies_->ToProto(predictor->mutable_frequencies());
478 }
479
FromProto(const RecurrencePredictorProto & proto)480 void MarkovPredictor::FromProto(const RecurrencePredictorProto& proto) {
481 if (!proto.has_markov_predictor()) {
482 LogSerializationStatus(model_identifier_,
483 SerializationStatus::kMarkovPredictorLoadingError);
484 return;
485 }
486
487 frequencies_->FromProto(proto.markov_predictor().frequencies());
488 }
489
ExponentialWeightsEnsemble(const ExponentialWeightsEnsembleConfig & config,const std::string & model_identifier)490 ExponentialWeightsEnsemble::ExponentialWeightsEnsemble(
491 const ExponentialWeightsEnsembleConfig& config,
492 const std::string& model_identifier)
493 : RecurrencePredictor(model_identifier),
494 learning_rate_(config.learning_rate()) {
495 for (int i = 0; i < config.predictors_size(); ++i) {
496 predictors_.push_back(
497 {MakePredictor(config.predictors(i), model_identifier_),
498 1.0f / config.predictors_size()});
499 }
500 }
501
502 ExponentialWeightsEnsemble::~ExponentialWeightsEnsemble() = default;
503
504 const char ExponentialWeightsEnsemble::kPredictorName[] =
505 "ExponentialWeightsEnsemble";
GetPredictorName() const506 const char* ExponentialWeightsEnsemble::GetPredictorName() const {
507 return kPredictorName;
508 }
509
Train(unsigned int target,unsigned int condition)510 void ExponentialWeightsEnsemble::Train(unsigned int target,
511 unsigned int condition) {
512 // Update predictor weights. Do this before training the constituent
513 // predictors to avoid biasing towards fast-adjusting predictors.
514 for (auto& predictor_weight : predictors_) {
515 const auto& ranks = predictor_weight.first->Rank(condition);
516
517 // Find the normalized score associated with the ground-truth |target|.
518 // If the predictor didn't rank the ground truth target, consider that a
519 // score of 0.
520 float total_score = 0.0f;
521 for (const auto& target_score : ranks)
522 total_score += target_score.second;
523
524 float score = 0.0f;
525 const auto& it = ranks.find(target);
526 if (total_score > 0.0f && it != ranks.end())
527 score = it->second / total_score;
528
529 // Perform an exponential weights update.
530 predictor_weight.second *= std::exp(-learning_rate_ * (1 - score));
531 }
532
533 // Re-normalize weights.
534 float total_weight = 0.0f;
535 for (const auto& predictor_weight : predictors_)
536 total_weight += predictor_weight.second;
537 for (auto& predictor_weight : predictors_) {
538 predictor_weight.second /= total_weight;
539 }
540
541 // Train constituent predictors.
542 for (auto& predictor_weight : predictors_)
543 predictor_weight.first->Train(target, condition);
544 }
545
Rank(unsigned int condition)546 std::map<unsigned int, float> ExponentialWeightsEnsemble::Rank(
547 unsigned int condition) {
548 std::map<unsigned int, float> result;
549 for (const auto& predictor_weight : predictors_) {
550 const auto& ranks = predictor_weight.first->Rank(condition);
551 for (const auto& target_score : ranks) {
552 // Weights are kept normalized by Train, so all scores remain in [0,1] if
553 // the predictors' scores are in [0,1].
554 result[target_score.first] +=
555 target_score.second * predictor_weight.second;
556 }
557 }
558 return result;
559 }
560
ToProto(RecurrencePredictorProto * proto) const561 void ExponentialWeightsEnsemble::ToProto(
562 RecurrencePredictorProto* proto) const {
563 auto* ensemble = proto->mutable_exponential_weights_ensemble();
564
565 for (const auto& predictor_weight : predictors_) {
566 predictor_weight.first->ToProto(ensemble->add_predictors());
567 ensemble->add_weights(predictor_weight.second);
568 }
569 }
570
FromProto(const RecurrencePredictorProto & proto)571 void ExponentialWeightsEnsemble::FromProto(
572 const RecurrencePredictorProto& proto) {
573 if (!proto.has_exponential_weights_ensemble()) {
574 LogSerializationStatus(
575 model_identifier_,
576 SerializationStatus::kExponentialWeightsEnsembleLoadingError);
577 return;
578 }
579 const auto& ensemble = proto.exponential_weights_ensemble();
580 int num_predictors = static_cast<int>(predictors_.size());
581 DCHECK_EQ(num_predictors, ensemble.predictors_size());
582 DCHECK_EQ(num_predictors, ensemble.weights_size());
583
584 for (int i = 0; i < num_predictors; ++i) {
585 predictors_[i].first->FromProto(ensemble.predictors(i));
586 predictors_[i].second = ensemble.weights(i);
587 }
588 }
589
FrequencyPredictor(const std::string & model_identifier)590 FrequencyPredictor::FrequencyPredictor(const std::string& model_identifier)
591 : RecurrencePredictor(model_identifier) {}
592
FrequencyPredictor(const FrequencyPredictorConfig & config,const std::string & model_identifier)593 FrequencyPredictor::FrequencyPredictor(const FrequencyPredictorConfig& config,
594 const std::string& model_identifier)
595 : RecurrencePredictor(model_identifier) {}
596
597 FrequencyPredictor::~FrequencyPredictor() = default;
598
599 const char FrequencyPredictor::kPredictorName[] = "FrequencyPredictor";
GetPredictorName() const600 const char* FrequencyPredictor::GetPredictorName() const {
601 return kPredictorName;
602 }
603
Train(unsigned int target,unsigned int condition)604 void FrequencyPredictor::Train(unsigned int target, unsigned int condition) {
605 counts_[target] += 1.0f;
606 }
607
Rank(unsigned int condition)608 std::map<unsigned int, float> FrequencyPredictor::Rank(unsigned int condition) {
609 float total = 0.0f;
610 for (const auto& pair : counts_)
611 total += pair.second;
612
613 std::map<unsigned int, float> result;
614 for (const auto& pair : counts_)
615 result[pair.first] = pair.second / total;
616 return result;
617 }
618
Cleanup(const std::vector<unsigned int> & valid_targets)619 void FrequencyPredictor::Cleanup(
620 const std::vector<unsigned int>& valid_targets) {
621 std::map<unsigned int, int> new_counts;
622
623 for (unsigned int id : valid_targets) {
624 const auto& it = counts_.find(id);
625 if (it != counts_.end())
626 new_counts[id] = it->second;
627 }
628
629 counts_.swap(new_counts);
630 }
631
ToProto(RecurrencePredictorProto * proto) const632 void FrequencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
633 auto* counts = proto->mutable_frequency_predictor()->mutable_counts();
634 for (auto& pair : counts_)
635 (*counts)[pair.first] = pair.second;
636 }
637
FromProto(const RecurrencePredictorProto & proto)638 void FrequencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
639 if (!proto.has_frequency_predictor()) {
640 LogSerializationStatus(
641 model_identifier_,
642 SerializationStatus::kFrequencyPredictorLoadingError);
643 return;
644 }
645
646 for (const auto& pair : proto.frequency_predictor().counts())
647 counts_[pair.first] = pair.second;
648 }
649
650 } // namespace app_list
651