1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef CHROME_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_DECISION_TREE_DECISION_TREE_PREDICTION_MODEL_H_
6 #define CHROME_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_DECISION_TREE_DECISION_TREE_PREDICTION_MODEL_H_
7 
8 #include <memory>
9 #include <string>
10 
11 #include "base/containers/flat_map.h"
12 #include "base/containers/flat_set.h"
13 #include "base/macros.h"
14 #include "base/sequence_checker.h"
15 #include "chrome/services/machine_learning/public/cpp/decision_tree/prediction_model.h"
16 #include "chrome/services/machine_learning/public/mojom/decision_tree.mojom.h"
17 #include "components/optimization_guide/proto/models.pb.h"
18 
19 namespace machine_learning {
20 namespace decision_tree {
21 
22 // A concrete PredictionModel capable of evaluating the decision tree model type
23 // supported by the optimization guide.
24 class DecisionTreePredictionModel : public PredictionModel {
25  public:
26   explicit DecisionTreePredictionModel(
27       std::unique_ptr<optimization_guide::proto::PredictionModel>
28           prediction_model);
29 
30   ~DecisionTreePredictionModel() override;
31 
32   DecisionTreePredictionModel(const DecisionTreePredictionModel&) = delete;
33   DecisionTreePredictionModel& operator=(const DecisionTreePredictionModel&) =
34       delete;
35 
36   // PredictionModel implementation:
37   mojom::DecisionTreePredictionResult Predict(
38       const base::flat_map<std::string, float>& model_features,
39       double* prediction_score) override;
40 
41  private:
42   // Evaluates the provided model, either an ensemble or decision tree model,
43   // with the |model_features| and stores the output in |result|. Returns false
44   // if evaluation fails.
45   bool EvaluateModel(const optimization_guide::proto::Model& model,
46                      const base::flat_map<std::string, float>& model_features,
47                      double* result);
48 
49   // Evaluates the decision tree model with the |model_features| and
50   // stores the output in |result|. Returns false if evaluation fails.
51   bool EvaluateDecisionTree(
52       const optimization_guide::proto::DecisionTree& tree,
53       const base::flat_map<std::string, float>& model_features,
54       double* result);
55 
56   // Evaluates an ensemble model with the |model_features| and
57   // stores the output in |result|. Returns false if evaluation fails.
58   bool EvaluateEnsembleModel(
59       const optimization_guide::proto::Ensemble& ensemble,
60       const base::flat_map<std::string, float>& model_features,
61       double* result);
62 
63   // Performs a depth first traversal the |tree| based on |model_features|
64   // and stores the value of the leaf in |result|. Returns false if the
65   // traversal or node evaluation fails.
66   bool TraverseTree(const optimization_guide::proto::DecisionTree& tree,
67                     const optimization_guide::proto::TreeNode& node,
68                     const base::flat_map<std::string, float>& model_features,
69                     double* result);
70 
71   // PredictionModel implementation:
72   bool ValidatePredictionModel() const override;
73 
74   // Validates a model or submodel of an ensemble. Returns
75   // false if the model is invalid.
76   bool ValidateModel(const optimization_guide::proto::Model& model) const;
77 
78   // Validates an ensemble model. Returns false if the ensemble
79   // if invalid.
80   bool ValidateEnsembleModel(
81       const optimization_guide::proto::Ensemble& ensemble) const;
82 
83   // Validates a decision tree model. Returns false if the
84   // decision tree model is invalid.
85   bool ValidateDecisionTree(
86       const optimization_guide::proto::DecisionTree& tree) const;
87 
88   // Validates a leaf. Returns false if the leaf is invalid.
89   bool ValidateLeaf(const optimization_guide::proto::Leaf& leaf) const;
90 
91   // Validates an inequality test. Returns false if the
92   // inequality test is invalid.
93   bool ValidateInequalityTest(
94       const optimization_guide::proto::InequalityTest& inequality_test) const;
95 
96   // Validates each node of a decision tree by traversing every
97   // node of the |tree|. Returns false if any part of the tree is invalid.
98   bool ValidateTreeNode(const optimization_guide::proto::DecisionTree& tree,
99                         const optimization_guide::proto::TreeNode& node,
100                         const int& node_index) const;
101 
102   SEQUENCE_CHECKER(sequence_checker_);
103 };
104 
105 }  // namespace decision_tree
106 }  // namespace machine_learning
107 
108 #endif  // CHROME_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_DECISION_TREE_DECISION_TREE_PREDICTION_MODEL_H_
109