1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.h"
6 
7 #include <cmath>
8 #include <utility>
9 
10 #include "base/time/time.h"
11 #include "chrome/browser/ui/app_list/search/search_result_ranker/frecency_store.pb.h"
12 #include "chrome/browser/ui/app_list/search/search_result_ranker/histogram_util.h"
13 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
14 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_util.h"
15 
16 namespace app_list {
namespace {

// Hours in one day. HourBinPredictor adds this to the hour of day on weekends
// so that weekend hours fall into bins distinct from weekday hours.
constexpr int kHoursADay{24};

}  // namespace
22 
RecurrencePredictor(const std::string & model_identifier)23 RecurrencePredictor::RecurrencePredictor(const std::string& model_identifier)
24     : model_identifier_(model_identifier) {}
25 
FakePredictor(const std::string & model_identifier)26 FakePredictor::FakePredictor(const std::string& model_identifier)
27     : RecurrencePredictor(model_identifier) {
28   // The fake predictor should only be used for testing, not in production.
29   // Record an error so we know if it is being used.
30   LogInitializationStatus(model_identifier_,
31                           InitializationStatus::kFakePredictorUsed);
32 }
33 
FakePredictor(const FakePredictorConfig & config,const std::string & model_identifier)34 FakePredictor::FakePredictor(const FakePredictorConfig& config,
35                              const std::string& model_identifier)
36     : RecurrencePredictor(model_identifier) {
37   // The fake predictor should only be used for testing, not in production.
38   // Record an error so we know if it is being used.
39   LogInitializationStatus(model_identifier_,
40                           InitializationStatus::kFakePredictorUsed);
41 }
42 
43 FakePredictor::~FakePredictor() = default;
44 
45 const char FakePredictor::kPredictorName[] = "FakePredictor";
GetPredictorName() const46 const char* FakePredictor::GetPredictorName() const {
47   return kPredictorName;
48 }
49 
Train(unsigned int target,unsigned int condition)50 void FakePredictor::Train(unsigned int target, unsigned int condition) {
51   counts_[target] += 1.0f;
52 }
53 
Rank(unsigned int condition)54 std::map<unsigned int, float> FakePredictor::Rank(unsigned int condition) {
55   return counts_;
56 }
57 
Cleanup(const std::vector<unsigned int> & valid_targets)58 void FakePredictor::Cleanup(const std::vector<unsigned int>& valid_targets) {
59   std::map<unsigned int, float> new_counts;
60 
61   for (unsigned int id : valid_targets) {
62     const auto& it = counts_.find(id);
63     if (it != counts_.end())
64       new_counts[id] = it->second;
65   }
66 
67   counts_.swap(new_counts);
68 }
69 
ToProto(RecurrencePredictorProto * proto) const70 void FakePredictor::ToProto(RecurrencePredictorProto* proto) const {
71   auto* counts = proto->mutable_fake_predictor()->mutable_counts();
72   for (auto& pair : counts_)
73     (*counts)[pair.first] = pair.second;
74 }
75 
FromProto(const RecurrencePredictorProto & proto)76 void FakePredictor::FromProto(const RecurrencePredictorProto& proto) {
77   if (!proto.has_fake_predictor()) {
78     LogSerializationStatus(model_identifier_,
79                            SerializationStatus::kFakePredictorLoadingError);
80     return;
81   }
82 
83   for (const auto& pair : proto.fake_predictor().counts())
84     counts_[pair.first] = pair.second;
85 }
86 
DefaultPredictor(const DefaultPredictorConfig & config,const std::string & model_identifier)87 DefaultPredictor::DefaultPredictor(const DefaultPredictorConfig& config,
88                                    const std::string& model_identifier)
89     : RecurrencePredictor(model_identifier) {}
~DefaultPredictor()90 DefaultPredictor::~DefaultPredictor() {}
91 
Train(unsigned int target,unsigned int condition)92 void DefaultPredictor::Train(unsigned int target, unsigned int condition) {}
93 
Rank(unsigned int condition)94 std::map<unsigned int, float> DefaultPredictor::Rank(unsigned int condition) {
95   NOTREACHED();
96   return {};
97 }
98 
99 const char DefaultPredictor::kPredictorName[] = "DefaultPredictor";
GetPredictorName() const100 const char* DefaultPredictor::GetPredictorName() const {
101   return kPredictorName;
102 }
103 
ToProto(RecurrencePredictorProto * proto) const104 void DefaultPredictor::ToProto(RecurrencePredictorProto* proto) const {}
105 
FromProto(const RecurrencePredictorProto & proto)106 void DefaultPredictor::FromProto(const RecurrencePredictorProto& proto) {}
107 
ConditionalFrequencyPredictor(const std::string & model_identifier)108 ConditionalFrequencyPredictor::ConditionalFrequencyPredictor(
109     const std::string& model_identifier)
110     : RecurrencePredictor(model_identifier) {}
ConditionalFrequencyPredictor(const ConditionalFrequencyPredictorConfig & config,const std::string & model_identifier)111 ConditionalFrequencyPredictor::ConditionalFrequencyPredictor(
112     const ConditionalFrequencyPredictorConfig& config,
113     const std::string& model_identifier)
114     : RecurrencePredictor(model_identifier) {}
115 ConditionalFrequencyPredictor::~ConditionalFrequencyPredictor() = default;
116 
117 ConditionalFrequencyPredictor::Events::Events() = default;
118 ConditionalFrequencyPredictor::Events::~Events() = default;
119 ConditionalFrequencyPredictor::Events::Events(const Events& other) = default;
120 
121 const char ConditionalFrequencyPredictor::kPredictorName[] =
122     "ConditionalFrequencyPredictor";
GetPredictorName() const123 const char* ConditionalFrequencyPredictor::GetPredictorName() const {
124   return kPredictorName;
125 }
126 
Train(unsigned int target,unsigned int condition)127 void ConditionalFrequencyPredictor::Train(unsigned int target,
128                                           unsigned int condition) {
129   TrainWithDelta(target, condition, 1.0f);
130 }
131 
TrainWithDelta(unsigned int target,unsigned int condition,float delta)132 void ConditionalFrequencyPredictor::TrainWithDelta(unsigned int target,
133                                                    unsigned int condition,
134                                                    float delta) {
135   DCHECK_NE(delta, 0.0f);
136   auto& events = table_[condition];
137   events.freqs[target] += delta;
138   events.total += delta;
139 }
140 
Rank(unsigned int condition)141 std::map<unsigned int, float> ConditionalFrequencyPredictor::Rank(
142     unsigned int condition) {
143   const auto& it = table_.find(condition);
144   // If the total frequency is zero, we can't return any meaningful results, so
145   // return empty.
146   if (it == table_.end() || it->second.total == 0.0f)
147     return {};
148 
149   std::map<unsigned int, float> result;
150   const auto& events = it->second;
151   for (const auto& target_freq : events.freqs)
152     result[target_freq.first] = target_freq.second / events.total;
153   return result;
154 }
155 
Cleanup(const std::vector<unsigned int> & valid_targets)156 void ConditionalFrequencyPredictor::Cleanup(
157     const std::vector<unsigned int>& valid_targets) {
158   for (auto iter = table_.begin(); iter != table_.end();) {
159     auto& events = iter->second;
160 
161     std::map<unsigned int, float> new_freqs;
162     float new_total = 0.0f;
163     for (unsigned int id : valid_targets) {
164       const auto& it = events.freqs.find(id);
165       if (it != events.freqs.end()) {
166         new_freqs[id] = it->second;
167         new_total += it->second;
168       }
169     }
170 
171     // Delete the whole condition out of the table if it contains no valid
172     // targets.
173     if (new_freqs.empty()) {
174       // C++11: the return value of erase(iter) is an iterator pointing to the
175       // next element in the container.
176       iter = table_.erase(iter);
177     } else {
178       ++iter;
179       events.freqs.swap(new_freqs);
180       events.total = new_total;
181     }
182   }
183 }
184 
CleanupConditions(const std::vector<unsigned int> & valid_conditions)185 void ConditionalFrequencyPredictor::CleanupConditions(
186     const std::vector<unsigned int>& valid_conditions) {
187   std::map<unsigned int, ConditionalFrequencyPredictor::Events> new_table;
188 
189   for (unsigned int id : valid_conditions) {
190     const auto& it = table_.find(id);
191     if (it != table_.end()) {
192       new_table[id] = std::move(it->second);
193     }
194   }
195 
196   table_.swap(new_table);
197 }
198 
ToProto(RecurrencePredictorProto * proto) const199 void ConditionalFrequencyPredictor::ToProto(
200     RecurrencePredictorProto* proto) const {
201   auto* predictor = proto->mutable_conditional_frequency_predictor();
202   for (const auto& condition_events : table_) {
203     for (const auto& event_freq : condition_events.second.freqs) {
204       auto* event = predictor->add_events();
205       event->set_condition(condition_events.first);
206       event->set_event(event_freq.first);
207       event->set_freq(event_freq.second);
208     }
209   }
210 }
211 
FromProto(const RecurrencePredictorProto & proto)212 void ConditionalFrequencyPredictor::FromProto(
213     const RecurrencePredictorProto& proto) {
214   if (!proto.has_conditional_frequency_predictor()) {
215     LogSerializationStatus(
216         model_identifier_,
217         SerializationStatus::kConditionalFrequencyPredictorLoadingError);
218     return;
219   }
220 
221   for (const auto& event : proto.conditional_frequency_predictor().events()) {
222     auto& events = table_[event.condition()];
223     events.freqs[event.event()] = event.freq();
224     events.total += event.freq();
225   }
226 }
227 
FrecencyPredictor(const FrecencyPredictorConfig & config,const std::string & model_identifier)228 FrecencyPredictor::FrecencyPredictor(const FrecencyPredictorConfig& config,
229                                      const std::string& model_identifier)
230     : RecurrencePredictor(model_identifier),
231       decay_coeff_(config.decay_coeff()) {}
232 FrecencyPredictor::~FrecencyPredictor() = default;
233 
234 const char FrecencyPredictor::kPredictorName[] = "FrecencyPredictor";
GetPredictorName() const235 const char* FrecencyPredictor::GetPredictorName() const {
236   return kPredictorName;
237 }
238 
Train(unsigned int target,unsigned int condition)239 void FrecencyPredictor::Train(unsigned int target, unsigned int condition) {
240   ++num_updates_;
241   TargetData& data = targets_[target];
242   DecayScore(&data);
243   data.last_score += 1.0f - decay_coeff_;
244 }
245 
Rank(unsigned int condition)246 std::map<unsigned int, float> FrecencyPredictor::Rank(unsigned int condition) {
247   float total = 0.0f;
248   for (auto& pair : targets_) {
249     DecayScore(&pair.second);
250     total += pair.second.last_score;
251   }
252   if (total == 0.0f)
253     return {};
254 
255   std::map<unsigned int, float> result;
256   for (auto& pair : targets_) {
257     result[pair.first] = pair.second.last_score / total;
258   }
259   return result;
260 }
261 
Cleanup(const std::vector<unsigned int> & valid_targets)262 void FrecencyPredictor::Cleanup(
263     const std::vector<unsigned int>& valid_targets) {
264   std::map<unsigned int, FrecencyPredictor::TargetData> new_targets;
265 
266   for (unsigned int id : valid_targets) {
267     const auto& it = targets_.find(id);
268     if (it != targets_.end())
269       new_targets[id] = it->second;
270   }
271 
272   targets_.swap(new_targets);
273 }
274 
ToProto(RecurrencePredictorProto * proto) const275 void FrecencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
276   auto* predictor = proto->mutable_frecency_predictor();
277 
278   predictor->set_num_updates(num_updates_);
279 
280   for (const auto& pair : targets_) {
281     auto* target_data = predictor->add_targets();
282     target_data->set_id(pair.first);
283     target_data->set_last_score(pair.second.last_score);
284     target_data->set_last_num_updates(pair.second.last_num_updates);
285   }
286 }
287 
FromProto(const RecurrencePredictorProto & proto)288 void FrecencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
289   if (!proto.has_frecency_predictor()) {
290     LogSerializationStatus(model_identifier_,
291                            SerializationStatus::kFrecencyPredictorLoadingError);
292     return;
293   }
294   const auto& predictor = proto.frecency_predictor();
295 
296   num_updates_ = predictor.num_updates();
297 
298   std::map<unsigned int, TargetData> targets;
299   for (const auto& target_data : predictor.targets()) {
300     targets[target_data.id()] = {target_data.last_score(),
301                                  target_data.last_num_updates()};
302   }
303   targets_.swap(targets);
304 }
305 
DecayScore(TargetData * data)306 void FrecencyPredictor::DecayScore(TargetData* data) {
307   int time_since_update = num_updates_ - data->last_num_updates;
308 
309   if (time_since_update > 0) {
310     data->last_score *= std::pow(decay_coeff_, time_since_update);
311     data->last_num_updates = num_updates_;
312   }
313 }
314 
HourBinPredictor(const HourBinPredictorConfig & config,const std::string & model_identifier)315 HourBinPredictor::HourBinPredictor(const HourBinPredictorConfig& config,
316                                    const std::string& model_identifier)
317     : RecurrencePredictor(model_identifier),
318       weekly_decay_coeff_(config.weekly_decay_coeff()) {
319   for (const auto& pair : config.bin_weights())
320     bin_weights_[pair.bin()] = pair.weight();
321 
322   if (!proto_.has_last_decay_timestamp())
323     SetLastDecayTimestamp(
324         base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
325 }
326 
327 HourBinPredictor::~HourBinPredictor() = default;
328 
329 const char HourBinPredictor::kPredictorName[] = "HourBinPredictor";
330 
GetPredictorName() const331 const char* HourBinPredictor::GetPredictorName() const {
332   return kPredictorName;
333 }
334 
GetBinFromHourDifference(int hour_difference) const335 int HourBinPredictor::GetBinFromHourDifference(int hour_difference) const {
336   base::Time shifted_time =
337       base::Time::Now() + base::TimeDelta::FromHours(hour_difference);
338   base::Time::Exploded exploded_time;
339   shifted_time.LocalExplode(&exploded_time);
340 
341   const bool is_weekend =
342       exploded_time.day_of_week == 6 || exploded_time.day_of_week == 0;
343 
344   // To distinguish workdays from weekend, use now.hour for workdays and
345   // now.hour + 24 for weekend.
346   if (!is_weekend) {
347     return exploded_time.hour;
348   } else {
349     return exploded_time.hour + kHoursADay;
350   }
351 }
352 
GetBin() const353 int HourBinPredictor::GetBin() const {
354   return GetBinFromHourDifference(0);
355 }
356 
Train(unsigned int target,unsigned int condition)357 void HourBinPredictor::Train(unsigned int target, unsigned int condition) {
358   int hour = GetBin();
359   auto& frequency_table = (*proto_.mutable_binned_frequency_table())[hour];
360   frequency_table.set_total_counts(frequency_table.total_counts() + 1);
361   (*frequency_table.mutable_frequency())[target] += 1;
362 }
363 
Rank(unsigned int condition)364 std::map<unsigned int, float> HourBinPredictor::Rank(unsigned int condition) {
365   std::map<unsigned int, float> ranks;
366   const auto& frequency_table_map = proto_.binned_frequency_table();
367   for (const auto& hour_and_weight : bin_weights_) {
368     // Find adjacent bin and weight.
369     const int adj_bin = GetBinFromHourDifference(hour_and_weight.first);
370     const float weight = hour_and_weight.second;
371 
372     const auto find_frequency_table = frequency_table_map.find(adj_bin);
373     if (find_frequency_table == frequency_table_map.end())
374       continue;
375     const auto& frequency_table = find_frequency_table->second;
376 
377     // Accumulates the frequency to the output.
378     if (frequency_table.total_counts() > 0) {
379       const int total_counts = frequency_table.total_counts();
380       for (const auto& pair : frequency_table.frequency()) {
381         ranks[pair.first] +=
382             static_cast<float>(pair.second) / total_counts * weight;
383       }
384     }
385   }
386   return ranks;
387 }
388 
// TODO(921444): Unify the hour bin predictor with the cleanup system used by
// the other predictors. It currently differs from them so that it behaves
// exactly like the Roselle predictor.
392 
ToProto(RecurrencePredictorProto * proto) const393 void HourBinPredictor::ToProto(RecurrencePredictorProto* proto) const {
394   *proto->mutable_hour_bin_predictor() = proto_;
395 }
396 
FromProto(const RecurrencePredictorProto & proto)397 void HourBinPredictor::FromProto(const RecurrencePredictorProto& proto) {
398   if (!proto.has_hour_bin_predictor()) {
399     LogSerializationStatus(model_identifier_,
400                            SerializationStatus::kHourBinPredictorLoadingError);
401     return;
402   }
403 
404   proto_ = proto.hour_bin_predictor();
405   if (ShouldDecay())
406     DecayAll();
407 }
408 
ShouldDecay()409 bool HourBinPredictor::ShouldDecay() {
410   const int today = base::Time::Now().ToDeltaSinceWindowsEpoch().InDays();
411   // Check if we should decay the frequency
412   return today - proto_.last_decay_timestamp() > 7;
413 }
414 
// Decays every stored frequency by |weekly_decay_coeff_|:
//  - Resets the last decay timestamp to today (days since the Windows epoch).
//  - Multiplies each target's frequency by the coefficient and truncates the
//    result to an integer.
//  - Keeps each bin's total_counts equal to the sum of its frequencies.
//  - Removes targets whose frequency drops to zero, then removes bins whose
//    total_counts drops to zero.
void HourBinPredictor::DecayAll() {
  SetLastDecayTimestamp(base::Time::Now().ToDeltaSinceWindowsEpoch().InDays());
  auto& frequency_table_map = *proto_.mutable_binned_frequency_table();
  for (auto it_table = frequency_table_map.begin();
       it_table != frequency_table_map.end();) {
    auto& frequency_table = *it_table->second.mutable_frequency();
    for (auto it_freq = frequency_table.begin();
         it_freq != frequency_table.end();) {
      // Storing the product in an int truncates it toward zero.
      const int new_frequency = it_freq->second * weekly_decay_coeff_;
      // Swap this target's old contribution to the bin total for the decayed
      // one.
      it_table->second.set_total_counts(it_table->second.total_counts() -
                                        it_freq->second + new_frequency);
      it_freq->second = new_frequency;

      // Remove item that has zero frequency. The post-increment advances the
      // iterator before erase() runs; this relies on erase() not invalidating
      // iterators to other elements of the map.
      if (it_freq->second == 0) {
        frequency_table.erase(it_freq++);
      } else {
        it_freq++;
      }
    }

    // Remove bin that has zero total_counts (same erase idiom as above).
    if (it_table->second.total_counts() == 0) {
      frequency_table_map.erase(it_table++);
    } else {
      it_table++;
    }
  }
}
444 
MarkovPredictor(const MarkovPredictorConfig & config,const std::string & model_identifier)445 MarkovPredictor::MarkovPredictor(const MarkovPredictorConfig& config,
446                                  const std::string& model_identifier)
447     : RecurrencePredictor(model_identifier) {
448   frequencies_ =
449       std::make_unique<ConditionalFrequencyPredictor>(model_identifier);
450 }
451 MarkovPredictor::~MarkovPredictor() = default;
452 
453 const char MarkovPredictor::kPredictorName[] = "MarkovPredictor";
GetPredictorName() const454 const char* MarkovPredictor::GetPredictorName() const {
455   return kPredictorName;
456 }
457 
Train(unsigned int target,unsigned int condition)458 void MarkovPredictor::Train(unsigned int target, unsigned int condition) {
459   if (previous_target_)
460     frequencies_->Train(target, previous_target_.value());
461   previous_target_ = target;
462 }
463 
Rank(unsigned int condition)464 std::map<unsigned int, float> MarkovPredictor::Rank(unsigned int condition) {
465   if (previous_target_)
466     return frequencies_->Rank(previous_target_.value());
467   return std::map<unsigned int, float>();
468 }
469 
Cleanup(const std::vector<unsigned int> & valid_targets)470 void MarkovPredictor::Cleanup(const std::vector<unsigned int>& valid_targets) {
471   frequencies_->CleanupConditions(valid_targets);
472   frequencies_->Cleanup(valid_targets);
473 }
474 
ToProto(RecurrencePredictorProto * proto) const475 void MarkovPredictor::ToProto(RecurrencePredictorProto* proto) const {
476   auto* predictor = proto->mutable_markov_predictor();
477   frequencies_->ToProto(predictor->mutable_frequencies());
478 }
479 
FromProto(const RecurrencePredictorProto & proto)480 void MarkovPredictor::FromProto(const RecurrencePredictorProto& proto) {
481   if (!proto.has_markov_predictor()) {
482     LogSerializationStatus(model_identifier_,
483                            SerializationStatus::kMarkovPredictorLoadingError);
484     return;
485   }
486 
487   frequencies_->FromProto(proto.markov_predictor().frequencies());
488 }
489 
ExponentialWeightsEnsemble(const ExponentialWeightsEnsembleConfig & config,const std::string & model_identifier)490 ExponentialWeightsEnsemble::ExponentialWeightsEnsemble(
491     const ExponentialWeightsEnsembleConfig& config,
492     const std::string& model_identifier)
493     : RecurrencePredictor(model_identifier),
494       learning_rate_(config.learning_rate()) {
495   for (int i = 0; i < config.predictors_size(); ++i) {
496     predictors_.push_back(
497         {MakePredictor(config.predictors(i), model_identifier_),
498          1.0f / config.predictors_size()});
499   }
500 }
501 
502 ExponentialWeightsEnsemble::~ExponentialWeightsEnsemble() = default;
503 
504 const char ExponentialWeightsEnsemble::kPredictorName[] =
505     "ExponentialWeightsEnsemble";
GetPredictorName() const506 const char* ExponentialWeightsEnsemble::GetPredictorName() const {
507   return kPredictorName;
508 }
509 
Train(unsigned int target,unsigned int condition)510 void ExponentialWeightsEnsemble::Train(unsigned int target,
511                                        unsigned int condition) {
512   // Update predictor weights. Do this before training the constituent
513   // predictors to avoid biasing towards fast-adjusting predictors.
514   for (auto& predictor_weight : predictors_) {
515     const auto& ranks = predictor_weight.first->Rank(condition);
516 
517     // Find the normalized score associated with the ground-truth |target|.
518     // If the predictor didn't rank the ground truth target, consider that a
519     // score of 0.
520     float total_score = 0.0f;
521     for (const auto& target_score : ranks)
522       total_score += target_score.second;
523 
524     float score = 0.0f;
525     const auto& it = ranks.find(target);
526     if (total_score > 0.0f && it != ranks.end())
527       score = it->second / total_score;
528 
529     // Perform an exponential weights update.
530     predictor_weight.second *= std::exp(-learning_rate_ * (1 - score));
531   }
532 
533   // Re-normalize weights.
534   float total_weight = 0.0f;
535   for (const auto& predictor_weight : predictors_)
536     total_weight += predictor_weight.second;
537   for (auto& predictor_weight : predictors_) {
538     predictor_weight.second /= total_weight;
539   }
540 
541   // Train constituent predictors.
542   for (auto& predictor_weight : predictors_)
543     predictor_weight.first->Train(target, condition);
544 }
545 
Rank(unsigned int condition)546 std::map<unsigned int, float> ExponentialWeightsEnsemble::Rank(
547     unsigned int condition) {
548   std::map<unsigned int, float> result;
549   for (const auto& predictor_weight : predictors_) {
550     const auto& ranks = predictor_weight.first->Rank(condition);
551     for (const auto& target_score : ranks) {
552       // Weights are kept normalized by Train, so all scores remain in [0,1] if
553       // the predictors' scores are in [0,1].
554       result[target_score.first] +=
555           target_score.second * predictor_weight.second;
556     }
557   }
558   return result;
559 }
560 
ToProto(RecurrencePredictorProto * proto) const561 void ExponentialWeightsEnsemble::ToProto(
562     RecurrencePredictorProto* proto) const {
563   auto* ensemble = proto->mutable_exponential_weights_ensemble();
564 
565   for (const auto& predictor_weight : predictors_) {
566     predictor_weight.first->ToProto(ensemble->add_predictors());
567     ensemble->add_weights(predictor_weight.second);
568   }
569 }
570 
FromProto(const RecurrencePredictorProto & proto)571 void ExponentialWeightsEnsemble::FromProto(
572     const RecurrencePredictorProto& proto) {
573   if (!proto.has_exponential_weights_ensemble()) {
574     LogSerializationStatus(
575         model_identifier_,
576         SerializationStatus::kExponentialWeightsEnsembleLoadingError);
577     return;
578   }
579   const auto& ensemble = proto.exponential_weights_ensemble();
580   int num_predictors = static_cast<int>(predictors_.size());
581   DCHECK_EQ(num_predictors, ensemble.predictors_size());
582   DCHECK_EQ(num_predictors, ensemble.weights_size());
583 
584   for (int i = 0; i < num_predictors; ++i) {
585     predictors_[i].first->FromProto(ensemble.predictors(i));
586     predictors_[i].second = ensemble.weights(i);
587   }
588 }
589 
FrequencyPredictor(const std::string & model_identifier)590 FrequencyPredictor::FrequencyPredictor(const std::string& model_identifier)
591     : RecurrencePredictor(model_identifier) {}
592 
FrequencyPredictor(const FrequencyPredictorConfig & config,const std::string & model_identifier)593 FrequencyPredictor::FrequencyPredictor(const FrequencyPredictorConfig& config,
594                                        const std::string& model_identifier)
595     : RecurrencePredictor(model_identifier) {}
596 
597 FrequencyPredictor::~FrequencyPredictor() = default;
598 
599 const char FrequencyPredictor::kPredictorName[] = "FrequencyPredictor";
GetPredictorName() const600 const char* FrequencyPredictor::GetPredictorName() const {
601   return kPredictorName;
602 }
603 
Train(unsigned int target,unsigned int condition)604 void FrequencyPredictor::Train(unsigned int target, unsigned int condition) {
605   counts_[target] += 1.0f;
606 }
607 
Rank(unsigned int condition)608 std::map<unsigned int, float> FrequencyPredictor::Rank(unsigned int condition) {
609   float total = 0.0f;
610   for (const auto& pair : counts_)
611     total += pair.second;
612 
613   std::map<unsigned int, float> result;
614   for (const auto& pair : counts_)
615     result[pair.first] = pair.second / total;
616   return result;
617 }
618 
Cleanup(const std::vector<unsigned int> & valid_targets)619 void FrequencyPredictor::Cleanup(
620     const std::vector<unsigned int>& valid_targets) {
621   std::map<unsigned int, int> new_counts;
622 
623   for (unsigned int id : valid_targets) {
624     const auto& it = counts_.find(id);
625     if (it != counts_.end())
626       new_counts[id] = it->second;
627   }
628 
629   counts_.swap(new_counts);
630 }
631 
ToProto(RecurrencePredictorProto * proto) const632 void FrequencyPredictor::ToProto(RecurrencePredictorProto* proto) const {
633   auto* counts = proto->mutable_frequency_predictor()->mutable_counts();
634   for (auto& pair : counts_)
635     (*counts)[pair.first] = pair.second;
636 }
637 
FromProto(const RecurrencePredictorProto & proto)638 void FrequencyPredictor::FromProto(const RecurrencePredictorProto& proto) {
639   if (!proto.has_frequency_predictor()) {
640     LogSerializationStatus(
641         model_identifier_,
642         SerializationStatus::kFrequencyPredictorLoadingError);
643     return;
644   }
645 
646   for (const auto& pair : proto.frequency_predictor().counts())
647     counts_[pair.first] = pair.second;
648 }
649 
650 }  // namespace app_list
651