1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_util.h"
6
7 #include <utility>
8 #include <vector>
9
10 #include "base/json/json_reader.h"
11 #include "base/json/json_value_converter.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/strings/string_piece.h"
14 #include "base/values.h"
15 #include "chrome/browser/ui/app_list/search/search_result_ranker/histogram_util.h"
16 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.h"
17 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_predictor.pb.h"
18 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker.pb.h"
19 #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker_config.pb.h"
20
21 namespace app_list {
22 namespace {
23
24 using base::Optional;
25 using base::Value;
26
27 using FakePredictorConfig = RecurrencePredictorConfigProto::FakePredictorConfig;
28 using DefaultPredictorConfig =
29 RecurrencePredictorConfigProto::DefaultPredictorConfig;
30 using ConditionalFrequencyPredictorConfig =
31 RecurrencePredictorConfigProto::ConditionalFrequencyPredictorConfig;
32 using FrecencyPredictorConfig =
33 RecurrencePredictorConfigProto::FrecencyPredictorConfig;
34 using HourBinPredictorConfig =
35 RecurrencePredictorConfigProto::HourBinPredictorConfig;
36 using MarkovPredictorConfig =
37 RecurrencePredictorConfigProto::MarkovPredictorConfig;
38 using ExponentialWeightsEnsembleConfig =
39 RecurrencePredictorConfigProto::ExponentialWeightsEnsembleConfig;
40
41 //---------------------
42 // Conversion utilities
43 //---------------------
44
GetNestedField(const Value * value,const std::string & key)45 base::Optional<const Value*> GetNestedField(const Value* value,
46 const std::string& key) {
47 const Value* field = value->FindKey(key);
48 if (!field || !field->is_dict())
49 return base::nullopt;
50 return base::Optional<const Value*>(field);
51 }
52
GetList(const Value * value,const std::string & key)53 Optional<const Value*> GetList(const Value* value, const std::string& key) {
54 const Value* field = value->FindKey(key);
55 if (!field || !field->is_list())
56 return base::nullopt;
57 return base::Optional<const Value*>(field);
58 }
59
GetInt(const Value * value,const std::string & key)60 Optional<int> GetInt(const Value* value, const std::string& key) {
61 const Value* field = value->FindKey(key);
62 if (!field || !field->is_int())
63 return base::nullopt;
64 return field->GetInt();
65 }
66
GetDouble(const Value * value,const std::string & key)67 base::Optional<double> GetDouble(const Value* value, const std::string& key) {
68 const Value* field = value->FindKey(key);
69 if (!field || !field->is_double())
70 return base::nullopt;
71 return field->GetDouble();
72 }
73
GetString(const Value * value,const std::string & key)74 base::Optional<std::string> GetString(const Value* value,
75 const std::string& key) {
76 const Value* field = value->FindKey(key);
77 if (!field || !field->is_string())
78 return base::nullopt;
79 return field->GetString();
80 }
81
82 //----------------------
83 // Predictor conversions
84 //----------------------
85
86 bool ConvertRecurrencePredictor(const Value*,
87 RecurrencePredictorConfigProto* proto);
88
ConvertFrecencyPredictor(const Value * value,FrecencyPredictorConfig * proto)89 bool ConvertFrecencyPredictor(const Value* value,
90 FrecencyPredictorConfig* proto) {
91 const auto& decay_coeff = GetDouble(value, "decay_coeff");
92 if (!decay_coeff)
93 return false;
94 proto->set_decay_coeff(decay_coeff.value());
95 return true;
96 }
97
ConvertHourBinPredictor(const Value * value,HourBinPredictorConfig * proto)98 bool ConvertHourBinPredictor(const Value* value,
99 HourBinPredictorConfig* proto) {
100 const auto& bin_weights = GetList(value, "bin_weights");
101
102 if (!bin_weights)
103 return false;
104
105 for (const Value& bin_weight : bin_weights.value()->GetList()) {
106 const auto& bin = GetInt(&bin_weight, "bin");
107 const auto& weight = GetDouble(&bin_weight, "weight");
108 if (!bin || !weight)
109 return false;
110
111 auto* proto_bin_weight = proto->add_bin_weights();
112 proto_bin_weight->set_bin(bin.value());
113 proto_bin_weight->set_weight(weight.value());
114 }
115 return true;
116 }
117
ConvertExponentialWeightsEnsemble(const Value * value,ExponentialWeightsEnsembleConfig * proto)118 bool ConvertExponentialWeightsEnsemble(
119 const Value* value,
120 ExponentialWeightsEnsembleConfig* proto) {
121 const auto& learning_rate = GetDouble(value, "learning_rate");
122 const auto& predictors = GetList(value, "predictors");
123
124 if (!learning_rate || !predictors)
125 return false;
126
127 proto->set_learning_rate(learning_rate.value());
128
129 bool success = true;
130 for (const Value& predictor : predictors.value()->GetList())
131 success &= ConvertRecurrencePredictor(&predictor, proto->add_predictors());
132 return success;
133 }
134
135 //----------------------
136 // Framework conversions
137 //----------------------
138
ConvertRecurrenceRanker(const Value * value,RecurrenceRankerConfigProto * proto)139 bool ConvertRecurrenceRanker(const Value* value,
140 RecurrenceRankerConfigProto* proto) {
141 const auto& min_seconds_between_saves =
142 GetInt(value, "min_seconds_between_saves");
143 const auto& target_limit = GetInt(value, "target_limit");
144 const auto& target_decay = GetDouble(value, "target_decay");
145 const auto& condition_limit = GetInt(value, "condition_limit");
146 const auto& condition_decay = GetDouble(value, "condition_decay");
147 const auto& predictor = GetNestedField(value, "predictor");
148
149 if (!min_seconds_between_saves || !target_limit || !target_decay ||
150 !condition_limit || !condition_decay || !predictor)
151 return false;
152
153 proto->set_min_seconds_between_saves(min_seconds_between_saves.value());
154 proto->set_target_limit(target_limit.value());
155 proto->set_target_decay(target_decay.value());
156 proto->set_condition_limit(condition_limit.value());
157 proto->set_condition_decay(condition_decay.value());
158
159 return ConvertRecurrencePredictor(predictor.value(),
160 proto->mutable_predictor());
161 }
162
ConvertRecurrencePredictor(const Value * value,RecurrencePredictorConfigProto * proto)163 bool ConvertRecurrencePredictor(const Value* value,
164 RecurrencePredictorConfigProto* proto) {
165 const auto& predictor_type = GetString(value, "predictor_type");
166 if (!predictor_type)
167 return false;
168
169 // Add new predictor converters here. Predictors with parameters should call a
170 // ConvertX function, and predictors without parameters should just set an
171 // empty message for that predictor. The empty message is important because
172 // its existence determines which predictor to use.
173 if (predictor_type == "fake") {
174 proto->mutable_fake_predictor();
175 return true;
176 } else if (predictor_type == "default") {
177 proto->mutable_default_predictor();
178 return true;
179 } else if (predictor_type == "conditional frequency") {
180 proto->mutable_conditional_frequency_predictor();
181 return true;
182 } else if (predictor_type == "frecency") {
183 return ConvertFrecencyPredictor(value, proto->mutable_frecency_predictor());
184 } else if (predictor_type == "hour bin") {
185 return ConvertHourBinPredictor(value, proto->mutable_hour_bin_predictor());
186 } else if (predictor_type == "markov") {
187 proto->mutable_markov_predictor();
188 return true;
189 } else if (predictor_type == "exponential weights ensemble") {
190 return ConvertExponentialWeightsEnsemble(
191 value, proto->mutable_exponential_weights_ensemble());
192 } else {
193 return false;
194 }
195 }
196
197 } // namespace
198
MakePredictor(const RecurrencePredictorConfigProto & config,const std::string & model_identifier)199 std::unique_ptr<RecurrencePredictor> MakePredictor(
200 const RecurrencePredictorConfigProto& config,
201 const std::string& model_identifier) {
202 if (config.has_fake_predictor())
203 return std::make_unique<FakePredictor>(config.fake_predictor(),
204 model_identifier);
205 if (config.has_default_predictor())
206 return std::make_unique<DefaultPredictor>(config.default_predictor(),
207 model_identifier);
208 if (config.has_conditional_frequency_predictor())
209 return std::make_unique<ConditionalFrequencyPredictor>(
210
211 config.conditional_frequency_predictor(), model_identifier);
212 if (config.has_frecency_predictor())
213 return std::make_unique<FrecencyPredictor>(config.frecency_predictor(),
214 model_identifier);
215 if (config.has_hour_bin_predictor())
216 return std::make_unique<HourBinPredictor>(config.hour_bin_predictor(),
217 model_identifier);
218 if (config.has_markov_predictor())
219 return std::make_unique<MarkovPredictor>(config.markov_predictor(),
220 model_identifier);
221 if (config.has_exponential_weights_ensemble())
222 return std::make_unique<ExponentialWeightsEnsemble>(
223 config.exponential_weights_ensemble(), model_identifier);
224
225 LogInitializationStatus(model_identifier,
226 InitializationStatus::kInvalidConfigPredictor);
227 NOTREACHED();
228 return nullptr;
229 }
230
Convert(const std::string & json_string,const std::string & model_identifier,OnConfigLoadedCallback callback)231 std::unique_ptr<JsonConfigConverter> JsonConfigConverter::Convert(
232 const std::string& json_string,
233 const std::string& model_identifier,
234 OnConfigLoadedCallback callback) {
235 // We don't use make_unique because the ctor is private.
236 std::unique_ptr<JsonConfigConverter> converter(new JsonConfigConverter());
237 converter->Start(json_string, model_identifier, std::move(callback));
238 return converter;
239 }
240
241 JsonConfigConverter::JsonConfigConverter() = default;
242
243 JsonConfigConverter::~JsonConfigConverter() = default;
244
Start(const std::string & json_string,const std::string & model_identifier,OnConfigLoadedCallback callback)245 void JsonConfigConverter::Start(const std::string& json_string,
246 const std::string& model_identifier,
247 OnConfigLoadedCallback callback) {
248 data_decoder::DataDecoder::ParseJsonIsolated(
249 json_string, base::BindOnce(&JsonConfigConverter::OnJsonParsed,
250 weak_ptr_factory_.GetWeakPtr(),
251 std::move(callback), model_identifier));
252 }
253
OnJsonParsed(OnConfigLoadedCallback callback,const std::string & model_identifier,data_decoder::DataDecoder::ValueOrError result)254 void JsonConfigConverter::OnJsonParsed(
255 OnConfigLoadedCallback callback,
256 const std::string& model_identifier,
257 data_decoder::DataDecoder::ValueOrError result) {
258 RecurrenceRankerConfigProto proto;
259 if (result.value && ConvertRecurrenceRanker(&result.value.value(), &proto)) {
260 LogJsonConfigConversionStatus(model_identifier,
261 JsonConfigConversionStatus::kSuccess);
262 std::move(callback).Run(std::move(proto));
263 } else {
264 LogJsonConfigConversionStatus(model_identifier,
265 JsonConfigConversionStatus::kFailure);
266 std::move(callback).Run(base::nullopt);
267 }
268 }
269
270 } // namespace app_list
271