1 // Copyright 2020 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_ 6 #define COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_ 7 8 #include <stdint.h> 9 #include <memory> 10 #include <string> 11 12 #include "base/containers/flat_map.h" 13 #include "base/containers/flat_set.h" 14 #include "base/macros.h" 15 #include "base/sequence_checker.h" 16 #include "components/optimization_guide/optimization_guide_enums.h" 17 #include "components/optimization_guide/proto/models.pb.h" 18 19 namespace optimization_guide { 20 21 // A PredictionModel supported by the optimization guide that makes an 22 // OptimizationTargetDecision by evaluating a prediction model. 23 class PredictionModel { 24 public: 25 virtual ~PredictionModel(); 26 27 // Creates an Prediction model of the correct ModelType specified in 28 // |prediction_model|. The validation overhead of this factory can be high and 29 // should should be called in the background. 30 static std::unique_ptr<PredictionModel> Create( 31 std::unique_ptr<optimization_guide::proto::PredictionModel> 32 prediction_model); 33 34 // Returns the OptimizationTargetDecision by evaluating the |model_| 35 // using the provided |model_features|. |prediction_score| will be populated 36 // with the score output by the model. 37 virtual optimization_guide::OptimizationTargetDecision Predict( 38 const base::flat_map<std::string, float>& model_features, 39 double* prediction_score) = 0; 40 41 // Provide the version of the |model_| by |this|. 42 int64_t GetVersion() const; 43 44 // Provide the model features required for evaluation of the |model_| by 45 // |this|. 46 base::flat_set<std::string> GetModelFeatures() const; 47 48 protected: 49 PredictionModel(std::unique_ptr<optimization_guide::proto::PredictionModel> 50 prediction_model); 51 52 // The in-memory model used for prediction. 53 std::unique_ptr<optimization_guide::proto::Model> model_; 54 55 private: 56 // Determines if the |model_| is complete and can be successfully evaluated by 57 // |this|. 58 virtual bool ValidatePredictionModel() const = 0; 59 60 // The information that describes the |model_| 61 std::unique_ptr<optimization_guide::proto::ModelInfo> model_info_; 62 63 // The set of features required by the |model_| to be evaluated. 64 base::flat_set<std::string> model_features_; 65 66 // The version of the |model_|. 67 int64_t version_; 68 69 SEQUENCE_CHECKER(sequence_checker_); 70 71 DISALLOW_COPY_AND_ASSIGN(PredictionModel); 72 }; 73 74 } // namespace optimization_guide 75 76 #endif // COMPONENTS_OPTIMIZATION_GUIDE_PREDICTION_MODEL_H_ 77