1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_util.h"
6 
7 #include <utility>
8 #include <vector>
9 
10 #include "base/json/json_reader.h"
11 #include "base/json/json_value_converter.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/string_piece.h"
14 #include "base/values.h"
15 #include "chrome/browser/ui/app_list/search/search_result_ranker/histogram_util.h"
16 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.h"
17 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
18 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker.pb.h"
19 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
20 
21 namespace app_list {
22 namespace {
23 
24 using base::Optional;
25 using base::Value;
26 
27 using FakePredictorConfig = RecurrencePredictorConfigProto::FakePredictorConfig;
28 using DefaultPredictorConfig =
29     RecurrencePredictorConfigProto::DefaultPredictorConfig;
30 using ConditionalFrequencyPredictorConfig =
31     RecurrencePredictorConfigProto::ConditionalFrequencyPredictorConfig;
32 using FrecencyPredictorConfig =
33     RecurrencePredictorConfigProto::FrecencyPredictorConfig;
34 using HourBinPredictorConfig =
35     RecurrencePredictorConfigProto::HourBinPredictorConfig;
36 using MarkovPredictorConfig =
37     RecurrencePredictorConfigProto::MarkovPredictorConfig;
38 using ExponentialWeightsEnsembleConfig =
39     RecurrencePredictorConfigProto::ExponentialWeightsEnsembleConfig;
40 
41 //---------------------
42 // Conversion utilities
43 //---------------------
44 
GetNestedField(const Value * value,const std::string & key)45 base::Optional<const Value*> GetNestedField(const Value* value,
46                                             const std::string& key) {
47   const Value* field = value->FindKey(key);
48   if (!field || !field->is_dict())
49     return base::nullopt;
50   return base::Optional<const Value*>(field);
51 }
52 
GetList(const Value * value,const std::string & key)53 Optional<const Value*> GetList(const Value* value, const std::string& key) {
54   const Value* field = value->FindKey(key);
55   if (!field || !field->is_list())
56     return base::nullopt;
57   return base::Optional<const Value*>(field);
58 }
59 
GetInt(const Value * value,const std::string & key)60 Optional<int> GetInt(const Value* value, const std::string& key) {
61   const Value* field = value->FindKey(key);
62   if (!field || !field->is_int())
63     return base::nullopt;
64   return field->GetInt();
65 }
66 
GetDouble(const Value * value,const std::string & key)67 base::Optional<double> GetDouble(const Value* value, const std::string& key) {
68   const Value* field = value->FindKey(key);
69   if (!field || !field->is_double())
70     return base::nullopt;
71   return field->GetDouble();
72 }
73 
GetString(const Value * value,const std::string & key)74 base::Optional<std::string> GetString(const Value* value,
75                                       const std::string& key) {
76   const Value* field = value->FindKey(key);
77   if (!field || !field->is_string())
78     return base::nullopt;
79   return field->GetString();
80 }
81 
82 //----------------------
83 // Predictor conversions
84 //----------------------
85 
86 bool ConvertRecurrencePredictor(const Value*,
87                                 RecurrencePredictorConfigProto* proto);
88 
ConvertFrecencyPredictor(const Value * value,FrecencyPredictorConfig * proto)89 bool ConvertFrecencyPredictor(const Value* value,
90                               FrecencyPredictorConfig* proto) {
91   const auto& decay_coeff = GetDouble(value, "decay_coeff");
92   if (!decay_coeff)
93     return false;
94   proto->set_decay_coeff(decay_coeff.value());
95   return true;
96 }
97 
ConvertHourBinPredictor(const Value * value,HourBinPredictorConfig * proto)98 bool ConvertHourBinPredictor(const Value* value,
99                              HourBinPredictorConfig* proto) {
100   const auto& bin_weights = GetList(value, "bin_weights");
101 
102   if (!bin_weights)
103     return false;
104 
105   for (const Value& bin_weight : bin_weights.value()->GetList()) {
106     const auto& bin = GetInt(&bin_weight, "bin");
107     const auto& weight = GetDouble(&bin_weight, "weight");
108     if (!bin || !weight)
109       return false;
110 
111     auto* proto_bin_weight = proto->add_bin_weights();
112     proto_bin_weight->set_bin(bin.value());
113     proto_bin_weight->set_weight(weight.value());
114   }
115   return true;
116 }
117 
ConvertExponentialWeightsEnsemble(const Value * value,ExponentialWeightsEnsembleConfig * proto)118 bool ConvertExponentialWeightsEnsemble(
119     const Value* value,
120     ExponentialWeightsEnsembleConfig* proto) {
121   const auto& learning_rate = GetDouble(value, "learning_rate");
122   const auto& predictors = GetList(value, "predictors");
123 
124   if (!learning_rate || !predictors)
125     return false;
126 
127   proto->set_learning_rate(learning_rate.value());
128 
129   bool success = true;
130   for (const Value& predictor : predictors.value()->GetList())
131     success &= ConvertRecurrencePredictor(&predictor, proto->add_predictors());
132   return success;
133 }
134 
135 //----------------------
136 // Framework conversions
137 //----------------------
138 
ConvertRecurrenceRanker(const Value * value,RecurrenceRankerConfigProto * proto)139 bool ConvertRecurrenceRanker(const Value* value,
140                              RecurrenceRankerConfigProto* proto) {
141   const auto& min_seconds_between_saves =
142       GetInt(value, "min_seconds_between_saves");
143   const auto& target_limit = GetInt(value, "target_limit");
144   const auto& target_decay = GetDouble(value, "target_decay");
145   const auto& condition_limit = GetInt(value, "condition_limit");
146   const auto& condition_decay = GetDouble(value, "condition_decay");
147   const auto& predictor = GetNestedField(value, "predictor");
148 
149   if (!min_seconds_between_saves || !target_limit || !target_decay ||
150       !condition_limit || !condition_decay || !predictor)
151     return false;
152 
153   proto->set_min_seconds_between_saves(min_seconds_between_saves.value());
154   proto->set_target_limit(target_limit.value());
155   proto->set_target_decay(target_decay.value());
156   proto->set_condition_limit(condition_limit.value());
157   proto->set_condition_decay(condition_decay.value());
158 
159   return ConvertRecurrencePredictor(predictor.value(),
160                                     proto->mutable_predictor());
161 }
162 
ConvertRecurrencePredictor(const Value * value,RecurrencePredictorConfigProto * proto)163 bool ConvertRecurrencePredictor(const Value* value,
164                                 RecurrencePredictorConfigProto* proto) {
165   const auto& predictor_type = GetString(value, "predictor_type");
166   if (!predictor_type)
167     return false;
168 
169   // Add new predictor converters here. Predictors with parameters should call a
170   // ConvertX function, and predictors without parameters should just set an
171   // empty message for that predictor. The empty message is important because
172   // its existence determines which predictor to use.
173   if (predictor_type == "fake") {
174     proto->mutable_fake_predictor();
175     return true;
176   } else if (predictor_type == "default") {
177     proto->mutable_default_predictor();
178     return true;
179   } else if (predictor_type == "conditional frequency") {
180     proto->mutable_conditional_frequency_predictor();
181     return true;
182   } else if (predictor_type == "frecency") {
183     return ConvertFrecencyPredictor(value, proto->mutable_frecency_predictor());
184   } else if (predictor_type == "hour bin") {
185     return ConvertHourBinPredictor(value, proto->mutable_hour_bin_predictor());
186   } else if (predictor_type == "markov") {
187     proto->mutable_markov_predictor();
188     return true;
189   } else if (predictor_type == "exponential weights ensemble") {
190     return ConvertExponentialWeightsEnsemble(
191         value, proto->mutable_exponential_weights_ensemble());
192   } else {
193     return false;
194   }
195 }
196 
197 }  // namespace
198 
MakePredictor(const RecurrencePredictorConfigProto & config,const std::string & model_identifier)199 std::unique_ptr<RecurrencePredictor> MakePredictor(
200     const RecurrencePredictorConfigProto& config,
201     const std::string& model_identifier) {
202   if (config.has_fake_predictor())
203     return std::make_unique<FakePredictor>(config.fake_predictor(),
204                                            model_identifier);
205   if (config.has_default_predictor())
206     return std::make_unique<DefaultPredictor>(config.default_predictor(),
207                                               model_identifier);
208   if (config.has_conditional_frequency_predictor())
209     return std::make_unique<ConditionalFrequencyPredictor>(
210 
211         config.conditional_frequency_predictor(), model_identifier);
212   if (config.has_frecency_predictor())
213     return std::make_unique<FrecencyPredictor>(config.frecency_predictor(),
214                                                model_identifier);
215   if (config.has_hour_bin_predictor())
216     return std::make_unique<HourBinPredictor>(config.hour_bin_predictor(),
217                                               model_identifier);
218   if (config.has_markov_predictor())
219     return std::make_unique<MarkovPredictor>(config.markov_predictor(),
220                                              model_identifier);
221   if (config.has_exponential_weights_ensemble())
222     return std::make_unique<ExponentialWeightsEnsemble>(
223         config.exponential_weights_ensemble(), model_identifier);
224 
225   LogInitializationStatus(model_identifier,
226                           InitializationStatus::kInvalidConfigPredictor);
227   NOTREACHED();
228   return nullptr;
229 }
230 
Convert(const std::string & json_string,const std::string & model_identifier,OnConfigLoadedCallback callback)231 std::unique_ptr<JsonConfigConverter> JsonConfigConverter::Convert(
232     const std::string& json_string,
233     const std::string& model_identifier,
234     OnConfigLoadedCallback callback) {
235   // We don't use make_unique because the ctor is private.
236   std::unique_ptr<JsonConfigConverter> converter(new JsonConfigConverter());
237   converter->Start(json_string, model_identifier, std::move(callback));
238   return converter;
239 }
240 
241 JsonConfigConverter::JsonConfigConverter() = default;
242 
243 JsonConfigConverter::~JsonConfigConverter() = default;
244 
Start(const std::string & json_string,const std::string & model_identifier,OnConfigLoadedCallback callback)245 void JsonConfigConverter::Start(const std::string& json_string,
246                                 const std::string& model_identifier,
247                                 OnConfigLoadedCallback callback) {
248   data_decoder::DataDecoder::ParseJsonIsolated(
249       json_string, base::BindOnce(&JsonConfigConverter::OnJsonParsed,
250                                   weak_ptr_factory_.GetWeakPtr(),
251                                   std::move(callback), model_identifier));
252 }
253 
OnJsonParsed(OnConfigLoadedCallback callback,const std::string & model_identifier,data_decoder::DataDecoder::ValueOrError result)254 void JsonConfigConverter::OnJsonParsed(
255     OnConfigLoadedCallback callback,
256     const std::string& model_identifier,
257     data_decoder::DataDecoder::ValueOrError result) {
258   RecurrenceRankerConfigProto proto;
259   if (result.value && ConvertRecurrenceRanker(&result.value.value(), &proto)) {
260     LogJsonConfigConversionStatus(model_identifier,
261                                   JsonConfigConversionStatus::kSuccess);
262     std::move(callback).Run(std::move(proto));
263   } else {
264     LogJsonConfigConversionStatus(model_identifier,
265                                   JsonConfigConversionStatus::kFailure);
266     std::move(callback).Run(base::nullopt);
267   }
268 }
269 
270 }  // namespace app_list
271