1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef COMPONENTS_OPTIMIZATION_GUIDE_DECISION_TREE_PREDICTION_MODEL_H_
6 #define COMPONENTS_OPTIMIZATION_GUIDE_DECISION_TREE_PREDICTION_MODEL_H_
7 
8 #include <memory>
9 #include <string>
10 
11 #include "base/containers/flat_map.h"
12 #include "base/containers/flat_set.h"
13 #include "base/macros.h"
14 #include "base/sequence_checker.h"
15 #include "components/optimization_guide/prediction_model.h"
16 #include "components/optimization_guide/proto/models.pb.h"
17 
18 namespace optimization_guide {
19 
20 // A concrete PredictionModel capable of evaluating the decision tree model type
21 // supported by the optimization guide.
22 class DecisionTreePredictionModel : public PredictionModel {
23  public:
24   explicit DecisionTreePredictionModel(
25       std::unique_ptr<optimization_guide::proto::PredictionModel>
26           prediction_model);
27 
28   ~DecisionTreePredictionModel() override;
29 
30   // PredictionModel implementation:
31   optimization_guide::OptimizationTargetDecision Predict(
32       const base::flat_map<std::string, float>& model_features,
33       double* prediction_score) override;
34 
35  private:
36   // Evaluates the provided model, either an ensemble or decision tree model,
37   // with the |model_features| and stores the output in |result|. Returns false
38   // if evaluation fails.
39   bool EvaluateModel(const proto::Model& model,
40                      const base::flat_map<std::string, float>& model_features,
41                      double* result);
42 
43   // Evaluates the decision tree model with the |model_features| and
44   // stores the output in |result|. Returns false if the evaluation fails.
45   bool EvaluateDecisionTree(
46       const proto::DecisionTree& tree,
47       const base::flat_map<std::string, float>& model_features,
48       double* result);
49 
50   // Evaluates an ensemble model with the |model_features| and
51   // stores the output in |result|. Returns false if the evaluation fails.
52   bool EvaluateEnsembleModel(
53       const proto::Ensemble& ensemble,
54       const base::flat_map<std::string, float>& model_features,
55       double* result);
56 
57   // Performs a depth first traversal the  |tree| based on |model_features|
58   // and stores the value of the leaf in |result|. Returns false if the
59   // traversal or node evaluation fails.
60   bool TraverseTree(const proto::DecisionTree& tree,
61                     const proto::TreeNode& node,
62                     const base::flat_map<std::string, float>& model_features,
63                     double* result);
64 
65   // PredictionModel implementation:
66   bool ValidatePredictionModel() const override;
67 
68   // Validates a model or submodel of an ensemble. Returns
69   // false if the model is invalid.
70   bool ValidateModel(const proto::Model& model) const;
71 
72   // Validates an ensemble model. Returns false if the ensemble
73   // if invalid.
74   bool ValidateEnsembleModel(const proto::Ensemble& ensemble) const;
75 
76   // Validates a decision tree model. Returns false if the
77   // decision tree model is invalid.
78   bool ValidateDecisionTree(const proto::DecisionTree& tree) const;
79 
80   // Validates a leaf. Returns false if the leaf is invalid.
81   bool ValidateLeaf(const proto::Leaf& leaf) const;
82 
83   // Validates an inequality test. Returns false if the
84   // inequality test is invalid.
85   bool ValidateInequalityTest(
86       const proto::InequalityTest& inequality_test) const;
87 
88   // Validates each node of a decision tree by traversing every
89   // node of the |tree|. Returns false if any part of the tree is invalid.
90   bool ValidateTreeNode(const proto::DecisionTree& tree,
91                         const proto::TreeNode& node,
92                         const int& node_index) const;
93 
94   SEQUENCE_CHECKER(sequence_checker_);
95 
96   DISALLOW_COPY_AND_ASSIGN(DecisionTreePredictionModel);
97 };
98 
99 }  // namespace optimization_guide
100 
101 #endif  // COMPONENTS_OPTIMIZATION_GUIDE_DECISION_TREE_PREDICTION_MODEL_H_
102