1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_
6 #define COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_
7 
8 #include <stdint.h>
9 #include <memory>
10 #include <string>
11 
12 #include "base/containers/flat_map.h"
13 #include "base/containers/flat_set.h"
14 #include "base/macros.h"
15 #include "base/sequence_checker.h"
16 #include "components/optimization_guide/optimization_guide_enums.h"
17 #include "components/optimization_guide/proto/models.pb.h"
18 
19 namespace optimization_guide {
20 
21 // A PredictionModel supported by the optimization guide that makes an
22 // OptimizationTargetDecision by evaluating a prediction model.
23 class PredictionModel {
24  public:
25   virtual ~PredictionModel();
26 
27   // Creates an Prediction model of the correct ModelType specified in
28   // |prediction_model|. The validation overhead of this factory can be high and
29   // should should be called in the background.
30   static std::unique_ptr<PredictionModel> Create(
31       std::unique_ptr<optimization_guide::proto::PredictionModel>
32           prediction_model);
33 
34   // Returns the OptimizationTargetDecision by evaluating the |model_|
35   // using the provided |model_features|. |prediction_score| will be populated
36   // with the score output by the model.
37   virtual optimization_guide::OptimizationTargetDecision Predict(
38       const base::flat_map<std::string, float>& model_features,
39       double* prediction_score) = 0;
40 
41   // Provide the version of the |model_| by |this|.
42   int64_t GetVersion() const;
43 
44   // Provide the model features required for evaluation of the |model_| by
45   // |this|.
46   base::flat_set<std::string> GetModelFeatures() const;
47 
48  protected:
49   PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel>
50                       prediction_model);
51 
52   // The in-memory model used for prediction.
53   std::unique_ptr<optimization_guide::proto::Model> model_;
54 
55  private:
56   // Determines if the |model_| is complete and can be successfully evaluated by
57   // |this|.
58   virtual bool ValidatePredictionModel() const = 0;
59 
60   // The information that describes the |model_|
61   std::unique_ptr<optimization_guide::proto::ModelInfo> model_info_;
62 
63   // The set of features required by the |model_| to be evaluated.
64   base::flat_set<std::string> model_features_;
65 
66   // The version of the |model_|.
67   int64_t version_;
68 
69   SEQUENCE_CHECKER(sequence_checker_);
70 
71   DISALLOW_COPY_AND_ASSIGN(PredictionModel);
72 };
73 
74 }  // namespace optimization_guide
75 
76 #endif  // COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_
77