1 // Copyright 2020 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef COMPONENTS_OPTIMIZATION_GUIDE_DECISION_TREE_PREDICTION_MODEL_H_ 6 #define COMPONENTS_OPTIMIZATION_GUIDE_DECISION_TREE_PREDICTION_MODEL_H_ 7 8 #include <memory> 9 #include <string> 10 11 #include "base/containers/flat_map.h" 12 #include "base/containers/flat_set.h" 13 #include "base/macros.h" 14 #include "base/sequence_checker.h" 15 #include "components/optimization_guide/prediction_model.h" 16 #include "components/optimization_guide/proto/models.pb.h" 17 18 namespace optimization_guide { 19 20 // A concrete PredictionModel capable of evaluating the decision tree model type 21 // supported by the optimization guide. 22 class DecisionTreePredictionModel : public PredictionModel { 23 public: 24 explicit DecisionTreePredictionModel( 25 std::unique_ptr<optimization_guide::proto::PredictionModel> 26 prediction_model); 27 28 ~DecisionTreePredictionModel() override; 29 30 // PredictionModel implementation: 31 optimization_guide::OptimizationTargetDecision Predict( 32 const base::flat_map<std::string, float>& model_features, 33 double* prediction_score) override; 34 35 private: 36 // Evaluates the provided model, either an ensemble or decision tree model, 37 // with the |model_features| and stores the output in |result|. Returns false 38 // if evaluation fails. 39 bool EvaluateModel(const proto::Model& model, 40 const base::flat_map<std::string, float>& model_features, 41 double* result); 42 43 // Evaluates the decision tree model with the |model_features| and 44 // stores the output in |result|. Returns false if the evaluation fails. 45 bool EvaluateDecisionTree( 46 const proto::DecisionTree& tree, 47 const base::flat_map<std::string, float>& model_features, 48 double* result); 49 50 // Evaluates an ensemble model with the |model_features| and 51 // stores the output in |result|. Returns false if the evaluation fails. 52 bool EvaluateEnsembleModel( 53 const proto::Ensemble& ensemble, 54 const base::flat_map<std::string, float>& model_features, 55 double* result); 56 57 // Performs a depth first traversal the |tree| based on |model_features| 58 // and stores the value of the leaf in |result|. Returns false if the 59 // traversal or node evaluation fails. 60 bool TraverseTree(const proto::DecisionTree& tree, 61 const proto::TreeNode& node, 62 const base::flat_map<std::string, float>& model_features, 63 double* result); 64 65 // PredictionModel implementation: 66 bool ValidatePredictionModel() const override; 67 68 // Validates a model or submodel of an ensemble. Returns 69 // false if the model is invalid. 70 bool ValidateModel(const proto::Model& model) const; 71 72 // Validates an ensemble model. Returns false if the ensemble 73 // if invalid. 74 bool ValidateEnsembleModel(const proto::Ensemble& ensemble) const; 75 76 // Validates a decision tree model. Returns false if the 77 // decision tree model is invalid. 78 bool ValidateDecisionTree(const proto::DecisionTree& tree) const; 79 80 // Validates a leaf. Returns false if the leaf is invalid. 81 bool ValidateLeaf(const proto::Leaf& leaf) const; 82 83 // Validates an inequality test. Returns false if the 84 // inequality test is invalid. 85 bool ValidateInequalityTest( 86 const proto::InequalityTest& inequality_test) const; 87 88 // Validates each node of a decision tree by traversing every 89 // node of the |tree|. Returns false if any part of the tree is invalid. 90 bool ValidateTreeNode(const proto::DecisionTree& tree, 91 const proto::TreeNode& node, 92 const int& node_index) const; 93 94 SEQUENCE_CHECKER(sequence_checker_); 95 96 DISALLOW_COPY_AND_ASSIGN(DecisionTreePredictionModel); 97 }; 98 99 } // namespace optimization_guide 100 101 #endif // COMPONENTS_OPTIMIZATION_GUIDE_DECISION_TREE_PREDICTION_MODEL_H_ 102