1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "components/optimization_guide/prediction_model.h"
6 
7 #include <utility>
8 
9 #include "components/optimization_guide/decision_tree_prediction_model.h"
10 
11 namespace optimization_guide {
12 
13 // static
Create(std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model)14 std::unique_ptr<PredictionModel> PredictionModel::Create(
15     std::unique_ptr<optimization_guide::proto::PredictionModel>
16         prediction_model) {
17   // TODO(crbug/1009123): Add a histogram to record if the provided model is
18   // constructed successfully or not.
19   // TODO(crbug/1009123): Adding timing metrics around initialization due to
20   // potential validation overhead.
21   if (!prediction_model->has_model())
22     return nullptr;
23 
24   if (!prediction_model->has_model_info())
25     return nullptr;
26 
27   if (!prediction_model->model_info().has_version())
28     return nullptr;
29 
30   // Enforce that only one ModelType is specified for the PredictionModel.
31   if (prediction_model->model_info().supported_model_types_size() != 1) {
32     return nullptr;
33   }
34 
35   // Check that the client supports this type of model and is not an unknown
36   // type.
37   if (!optimization_guide::proto::ModelType_IsValid(
38           prediction_model->model_info().supported_model_types(0)) ||
39       prediction_model->model_info().supported_model_types(0) ==
40           optimization_guide::proto::ModelType::MODEL_TYPE_UNKNOWN) {
41     return nullptr;
42   }
43 
44   // Check that the client supports the model features for |prediction model|.
45   for (const auto& model_feature :
46        prediction_model->model_info().supported_model_features()) {
47     if (!optimization_guide::proto::ClientModelFeature_IsValid(model_feature) ||
48         model_feature == optimization_guide::proto::ClientModelFeature::
49                              CLIENT_MODEL_FEATURE_UNKNOWN)
50       return nullptr;
51   }
52 
53   std::unique_ptr<PredictionModel> model;
54   // The Decision Tree model type is currently the only supported model type.
55   if (prediction_model->model_info().supported_model_types(0) !=
56       optimization_guide::proto::ModelType::MODEL_TYPE_DECISION_TREE) {
57     return nullptr;
58   }
59   model = std::make_unique<DecisionTreePredictionModel>(
60       std::move(prediction_model));
61 
62   // Any constructed model must be validated for correctness according to its
63   // model type before being returned.
64   if (!model->ValidatePredictionModel())
65     return nullptr;
66 
67   return model;
68 }
69 
PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model)70 PredictionModel::PredictionModel(
71     std::unique_ptr<optimization_guide::proto::PredictionModel>
72         prediction_model) {
73   version_ = prediction_model->model_info().version();
74   model_features_.reserve(
75       prediction_model->model_info().supported_model_features_size() +
76       prediction_model->model_info().supported_host_model_features_size());
77   // Insert all the client model features for the owned |model_|.
78   for (const auto& client_model_feature :
79        prediction_model->model_info().supported_model_features()) {
80     model_features_.emplace(optimization_guide::proto::ClientModelFeature_Name(
81         client_model_feature));
82   }
83   // Insert all the host model features for the owned |model_|.
84   for (const auto& host_model_feature :
85        prediction_model->model_info().supported_host_model_features()) {
86     model_features_.emplace(host_model_feature);
87   }
88   model_ = std::make_unique<optimization_guide::proto::Model>(
89       prediction_model->model());
90 }
91 
GetVersion() const92 int64_t PredictionModel::GetVersion() const {
93   SEQUENCE_CHECKER(sequence_checker_);
94   return version_;
95 }
96 
GetModelFeatures() const97 base::flat_set<std::string> PredictionModel::GetModelFeatures() const {
98   SEQUENCE_CHECKER(sequence_checker_);
99   return model_features_;
100 }
101 
102 PredictionModel::~PredictionModel() = default;
103 
104 }  // namespace optimization_guide
105