1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "components/assist_ranker/generic_logistic_regression_inference.h"

#include <array>
#include <string>

#include "components/assist_ranker/example_preprocessing.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/protobuf/src/google/protobuf/map.h"
10 
11 namespace assist_ranker {
12 using ::google::protobuf::Map;
13 
14 class GenericLogisticRegressionInferenceTest : public testing::Test {
15  protected:
GetProto()16   GenericLogisticRegressionModel GetProto() {
17     GenericLogisticRegressionModel proto;
18     proto.set_bias(bias_);
19     proto.set_threshold(threshold_);
20 
21     auto& weights = *proto.mutable_weights();
22     weights[scalar1_name_].set_scalar(scalar1_weight_);
23     weights[scalar2_name_].set_scalar(scalar2_weight_);
24     weights[scalar3_name_].set_scalar(scalar3_weight_);
25 
26     auto* one_hot_feat = weights[one_hot_name_].mutable_one_hot();
27     one_hot_feat->set_default_weight(one_hot_default_weight_);
28     (*one_hot_feat->mutable_weights())[one_hot_elem1_name_] =
29         one_hot_elem1_weight_;
30     (*one_hot_feat->mutable_weights())[one_hot_elem2_name_] =
31         one_hot_elem2_weight_;
32     (*one_hot_feat->mutable_weights())[one_hot_elem3_name_] =
33         one_hot_elem3_weight_;
34 
35     SparseWeights* sparse_feat = weights[sparse_name_].mutable_sparse();
36     sparse_feat->set_default_weight(sparse_default_weight_);
37     (*sparse_feat->mutable_weights())[sparse_elem1_name_] =
38         sparse_elem1_weight_;
39     (*sparse_feat->mutable_weights())[sparse_elem2_name_] =
40         sparse_elem2_weight_;
41 
42     BucketizedWeights* bucketized_feat =
43         weights[bucketized_name_].mutable_bucketized();
44     bucketized_feat->set_default_weight(bucketization_default_weight_);
45     for (const float boundary : bucketization_boundaries_) {
46       bucketized_feat->add_boundaries(boundary);
47     }
48     for (const float weight : bucketization_weights_) {
49       bucketized_feat->add_weights(weight);
50     }
51 
52     return proto;
53   }
54 
55   const std::string scalar1_name_ = "scalar_feature1";
56   const std::string scalar2_name_ = "scalar_feature2";
57   const std::string scalar3_name_ = "scalar_feature3";
58   const std::string one_hot_name_ = "one_hot_feature";
59   const std::string one_hot_elem1_name_ = "one_hot_elem1";
60   const std::string one_hot_elem2_name_ = "one_hot_elem2";
61   const std::string one_hot_elem3_name_ = "one_hot_elem3";
62   const float bias_ = 1.5f;
63   const float threshold_ = 0.6f;
64   const float scalar1_weight_ = 0.8f;
65   const float scalar2_weight_ = -2.4f;
66   const float scalar3_weight_ = 0.01f;
67   const float one_hot_elem1_weight_ = -1.0f;
68   const float one_hot_elem2_weight_ = 5.0f;
69   const float one_hot_elem3_weight_ = -1.5f;
70   const float one_hot_default_weight_ = 10.0f;
71   const float epsilon_ = 0.001f;
72 
73   const std::string sparse_name_ = "sparse_feature";
74   const std::string sparse_elem1_name_ = "sparse_elem1";
75   const std::string sparse_elem2_name_ = "sparse_elem2";
76   const float sparse_elem1_weight_ = -2.2f;
77   const float sparse_elem2_weight_ = 3.1f;
78   const float sparse_default_weight_ = 4.4f;
79 
80   const std::string bucketized_name_ = "bucketized_feature";
81   const float bucketization_boundaries_[2] = {0.3f, 0.7f};
82   const float bucketization_weights_[3] = {-1.0f, 1.0f, 3.0f};
83   const float bucketization_default_weight_ = -3.3f;
84 };
85 
TEST_F(GenericLogisticRegressionInferenceTest, BaseTest) {
  // Set every scalar feature plus a known one-hot value, then compare the
  // model's score with the logistic of the hand-computed weighted sum.
  RankerExample example;
  auto& feature_map = *example.mutable_features();
  feature_map[scalar1_name_].set_bool_value(true);  // Counted as 1.0 below.
  feature_map[scalar2_name_].set_int32_value(42);
  feature_map[scalar3_name_].set_float_value(0.666f);
  feature_map[one_hot_name_].set_string_value(one_hot_elem1_name_);

  GenericLogisticRegressionInference inference(GetProto());

  const float logit = bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
                      0.666f * scalar3_weight_ + one_hot_elem1_weight_;
  const float expected_score = Sigmoid(logit);
  const float actual_score = inference.PredictScore(example);

  EXPECT_NEAR(expected_score, actual_score, epsilon_);
  EXPECT_EQ(expected_score >= threshold_, inference.Predict(example));
}
103 
TEST_F(GenericLogisticRegressionInferenceTest, UnknownElement) {
  // A one-hot value that is not in the model's vocabulary is expected to be
  // scored with the one-hot default weight.
  GenericLogisticRegressionInference inference(GetProto());

  RankerExample example;
  (*example.mutable_features())[one_hot_name_].set_string_value(
      "Unknown element");

  const float expected_score = Sigmoid(bias_ + one_hot_default_weight_);
  EXPECT_NEAR(expected_score, inference.PredictScore(example), epsilon_);
}
114 
TEST_F(GenericLogisticRegressionInferenceTest, MissingFeatures) {
  // An example that sets no features at all.
  RankerExample example;

  auto predictor = GenericLogisticRegressionInference(GetProto());
  float score = predictor.PredictScore(example);
  // With every feature missing, scalar features contribute nothing and the
  // one-hot feature contributes its default weight. The expected value does
  // not include sparse_default_weight_ or bucketization_default_weight_.
  // NOTE(review): the logit here is ~11.5, where Sigmoid is saturated near 1,
  // so the epsilon_ comparison alone cannot tell whether the sparse and
  // bucketized defaults are also applied; it mainly checks that the one-hot
  // default is used.
  float expected_score = Sigmoid(bias_ + one_hot_default_weight_);
  EXPECT_NEAR(expected_score, score, epsilon_);
  // Check the binary decision too, as BaseTest does.
  EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
}
125 
TEST_F(GenericLogisticRegressionInferenceTest, UnknownFeatures) {
  RankerExample example;
  auto& feature_map = *example.mutable_features();
  // Features the model has no weights for, one of each value type. None of
  // them should affect the score.
  feature_map["foo1"].set_bool_value(true);
  feature_map["foo2"].set_int32_value(42);
  feature_map["foo3"].set_float_value(0.666f);
  feature_map["foo4"].set_string_value(one_hot_elem1_name_);
  // The only feature in the example that the model knows about.
  feature_map[one_hot_name_].set_string_value(one_hot_elem2_name_);

  GenericLogisticRegressionInference inference(GetProto());

  const float expected_score = Sigmoid(bias_ + one_hot_elem2_weight_);
  EXPECT_NEAR(expected_score, inference.PredictScore(example), epsilon_);
}
142 
TEST_F(GenericLogisticRegressionInferenceTest, Threshold) {
  // In this test, we calculate the score for a given example and set the model
  // threshold to that value. We then add a feature to the example that should
  // tip the score slightly to either side of the threshold, and verify that
  // the decision is as expected.

  auto proto = GetProto();
  GenericLogisticRegressionInference threshold_calculator(proto);

  RankerExample example;
  auto& features = *example.mutable_features();
  features[scalar1_name_].set_bool_value(true);
  features[scalar2_name_].set_int32_value(2);
  features[one_hot_name_].set_string_value(one_hot_elem1_name_);

  const float threshold = threshold_calculator.PredictScore(example);
  proto.set_threshold(threshold);

  // Build a second model that uses the calculated threshold.
  GenericLogisticRegressionInference predictor(proto);

  // Adding a small positive contribution from scalar3 (0.01 * scalar3_weight_)
  // should tip the decision to the positive side of the threshold.
  features[scalar3_name_].set_float_value(0.01f);
  float score = predictor.PredictScore(example);
  // The score is now greater than, but still near, the threshold. The
  // decision should be positive.
  EXPECT_LT(threshold, score);
  EXPECT_NEAR(threshold, score, epsilon_);
  EXPECT_TRUE(predictor.Predict(example));

  // A small negative contribution from scalar3 should tip the decision the
  // other way.
  features[scalar3_name_].set_float_value(-0.01f);
  score = predictor.PredictScore(example);
  EXPECT_GT(threshold, score);
  EXPECT_NEAR(threshold, score, epsilon_);
  EXPECT_FALSE(predictor.Predict(example));
}
182 
TEST_F(GenericLogisticRegressionInferenceTest, NoThreshold) {
  // With the threshold cleared, the model is expected to fall back to 0.5.
  const float kDefaultThreshold = 0.5f;

  GenericLogisticRegressionModel model = GetProto();
  model.clear_threshold();
  GenericLogisticRegressionInference inference(model);

  RankerExample example;
  auto& feature_map = *example.mutable_features();

  // one_hot_elem3_weight_ cancels bias_, so when this is the only active
  // feature the pre-sigmoid score is zero and the score is 0.5.
  feature_map[one_hot_name_].set_string_value(one_hot_elem3_name_);
  EXPECT_NEAR(kDefaultThreshold, inference.PredictScore(example), epsilon_);

  // A tiny positive scalar3 contribution lifts the score just above 0.5, so
  // the decision should be positive...
  feature_map[scalar3_name_].set_float_value(0.01f);
  float nudged_score = inference.PredictScore(example);
  EXPECT_LT(kDefaultThreshold, nudged_score);
  EXPECT_NEAR(kDefaultThreshold, nudged_score, epsilon_);
  EXPECT_TRUE(inference.Predict(example));

  // ...and a tiny negative one drops it just below 0.5, so the decision
  // should be negative.
  feature_map[scalar3_name_].set_float_value(-0.01f);
  nudged_score = inference.PredictScore(example);
  EXPECT_GT(kDefaultThreshold, nudged_score);
  EXPECT_NEAR(kDefaultThreshold, nudged_score, epsilon_);
  EXPECT_FALSE(inference.Predict(example));
}
216 
TEST_F(GenericLogisticRegressionInferenceTest,PreprossessedModel)217 TEST_F(GenericLogisticRegressionInferenceTest, PreprossessedModel) {
218   GenericLogisticRegressionModel proto = GetProto();
219   proto.set_is_preprocessed_model(true);
220   // Clear the weights to make sure the inference is done by fullname_weights.
221   proto.clear_weights();
222 
223   // Build fullname weights.
224   Map<std::string, float>& weights = *proto.mutable_fullname_weights();
225   weights[scalar1_name_] = scalar1_weight_;
226   weights[scalar2_name_] = scalar2_weight_;
227   weights[scalar3_name_] = scalar3_weight_;
228   weights[ExamplePreprocessor::FeatureFullname(
229       one_hot_name_, one_hot_elem1_name_)] = one_hot_elem1_weight_;
230   weights[ExamplePreprocessor::FeatureFullname(
231       one_hot_name_, one_hot_elem2_name_)] = one_hot_elem2_weight_;
232   weights[ExamplePreprocessor::FeatureFullname(
233       one_hot_name_, one_hot_elem3_name_)] = one_hot_elem3_weight_;
234   weights[ExamplePreprocessor::FeatureFullname(
235       sparse_name_, sparse_elem1_name_)] = sparse_elem1_weight_;
236   weights[ExamplePreprocessor::FeatureFullname(
237       sparse_name_, sparse_elem2_name_)] = sparse_elem2_weight_;
238   weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "0")] =
239       bucketization_weights_[0];
240   weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "1")] =
241       bucketization_weights_[1];
242   weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "2")] =
243       bucketization_weights_[2];
244   weights[ExamplePreprocessor::FeatureFullname(
245       ExamplePreprocessor::kMissingFeatureDefaultName, one_hot_name_)] =
246       one_hot_default_weight_;
247   weights[ExamplePreprocessor::FeatureFullname(
248       ExamplePreprocessor::kMissingFeatureDefaultName, sparse_name_)] =
249       sparse_default_weight_;
250   weights[ExamplePreprocessor::FeatureFullname(
251       ExamplePreprocessor::kMissingFeatureDefaultName, bucketized_name_)] =
252       bucketization_default_weight_;
253 
254   // Build preprocessor_config.
255   ExamplePreprocessorConfig& config = *proto.mutable_preprocessor_config();
256   config.add_missing_features(one_hot_name_);
257   config.add_missing_features(sparse_name_);
258   config.add_missing_features(bucketized_name_);
259   (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
260       bucketization_boundaries_[0]);
261   (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
262       bucketization_boundaries_[1]);
263 
264   auto predictor = GenericLogisticRegressionInference(proto);
265 
266   // Build example.
267   RankerExample example;
268   Map<std::string, Feature>& features = *example.mutable_features();
269   features[scalar1_name_].set_bool_value(true);
270   features[scalar2_name_].set_int32_value(42);
271   features[scalar3_name_].set_float_value(0.666f);
272   features[one_hot_name_].set_string_value(one_hot_elem1_name_);
273   features[sparse_name_].mutable_string_list()->add_string_value(
274       sparse_elem1_name_);
275   features[sparse_name_].mutable_string_list()->add_string_value(
276       sparse_elem2_name_);
277   features[bucketized_name_].set_float_value(0.98f);
278 
279   // Inference.
280   float score = predictor.PredictScore(example);
281   float expected_score = Sigmoid(
282       bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
283       0.666f * scalar3_weight_ + one_hot_elem1_weight_ + sparse_elem1_weight_ +
284       sparse_elem2_weight_ + bucketization_weights_[2]);
285 
286   EXPECT_NEAR(expected_score, score, epsilon_);
287   EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
288 }
289 
290 }  // namespace assist_ranker
291