1 // Copyright 2020 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef CHROME_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_DECISION_TREE_DECISION_TREE_PREDICTION_MODEL_H_ 6 #define CHROME_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_DECISION_TREE_DECISION_TREE_PREDICTION_MODEL_H_ 7 8 #include <memory> 9 #include <string> 10 11 #include "base/containers/flat_map.h" 12 #include "base/containers/flat_set.h" 13 #include "base/macros.h" 14 #include "base/sequence_checker.h" 15 #include "chrome/services/machine_learning/public/cpp/decision_tree/prediction_model.h" 16 #include "chrome/services/machine_learning/public/mojom/decision_tree.mojom.h" 17 #include "components/optimization_guide/proto/models.pb.h" 18 19 namespace machine_learning { 20 namespace decision_tree { 21 22 // A concrete PredictionModel capable of evaluating the decision tree model type 23 // supported by the optimization guide. 24 class DecisionTreePredictionModel : public PredictionModel { 25 public: 26 explicit DecisionTreePredictionModel( 27 std::unique_ptr<optimization_guide::proto::PredictionModel> 28 prediction_model); 29 30 ~DecisionTreePredictionModel() override; 31 32 DecisionTreePredictionModel(const DecisionTreePredictionModel&) = delete; 33 DecisionTreePredictionModel& operator=(const DecisionTreePredictionModel&) = 34 delete; 35 36 // PredictionModel implementation: 37 mojom::DecisionTreePredictionResult Predict( 38 const base::flat_map<std::string, float>& model_features, 39 double* prediction_score) override; 40 41 private: 42 // Evaluates the provided model, either an ensemble or decision tree model, 43 // with the |model_features| and stores the output in |result|. Returns false 44 // if evaluation fails. 45 bool EvaluateModel(const optimization_guide::proto::Model& model, 46 const base::flat_map<std::string, float>& model_features, 47 double* result); 48 49 // Evaluates the decision tree model with the |model_features| and 50 // stores the output in |result|. Returns false if evaluation fails. 51 bool EvaluateDecisionTree( 52 const optimization_guide::proto::DecisionTree& tree, 53 const base::flat_map<std::string, float>& model_features, 54 double* result); 55 56 // Evaluates an ensemble model with the |model_features| and 57 // stores the output in |result|. Returns false if evaluation fails. 58 bool EvaluateEnsembleModel( 59 const optimization_guide::proto::Ensemble& ensemble, 60 const base::flat_map<std::string, float>& model_features, 61 double* result); 62 63 // Performs a depth first traversal the |tree| based on |model_features| 64 // and stores the value of the leaf in |result|. Returns false if the 65 // traversal or node evaluation fails. 66 bool TraverseTree(const optimization_guide::proto::DecisionTree& tree, 67 const optimization_guide::proto::TreeNode& node, 68 const base::flat_map<std::string, float>& model_features, 69 double* result); 70 71 // PredictionModel implementation: 72 bool ValidatePredictionModel() const override; 73 74 // Validates a model or submodel of an ensemble. Returns 75 // false if the model is invalid. 76 bool ValidateModel(const optimization_guide::proto::Model& model) const; 77 78 // Validates an ensemble model. Returns false if the ensemble 79 // if invalid. 80 bool ValidateEnsembleModel( 81 const optimization_guide::proto::Ensemble& ensemble) const; 82 83 // Validates a decision tree model. Returns false if the 84 // decision tree model is invalid. 85 bool ValidateDecisionTree( 86 const optimization_guide::proto::DecisionTree& tree) const; 87 88 // Validates a leaf. Returns false if the leaf is invalid. 89 bool ValidateLeaf(const optimization_guide::proto::Leaf& leaf) const; 90 91 // Validates an inequality test. Returns false if the 92 // inequality test is invalid. 93 bool ValidateInequalityTest( 94 const optimization_guide::proto::InequalityTest& inequality_test) const; 95 96 // Validates each node of a decision tree by traversing every 97 // node of the |tree|. Returns false if any part of the tree is invalid. 98 bool ValidateTreeNode(const optimization_guide::proto::DecisionTree& tree, 99 const optimization_guide::proto::TreeNode& node, 100 const int& node_index) const; 101 102 SEQUENCE_CHECKER(sequence_checker_); 103 }; 104 105 } // namespace decision_tree 106 } // namespace machine_learning 107 108 #endif // CHROME_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_DECISION_TREE_DECISION_TREE_PREDICTION_MODEL_H_ 109