// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_

#include <limits>
#include <map>
#include <memory>
#include <set>

#include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/random_number_generator.h"
#include "media/learning/impl/training_algorithm.h"

namespace media {
namespace learning {

// Trains a RandomTree decision tree classifier / regressor.
//
// Decision trees, including RandomTree, classify instances as follows.  Each
// non-leaf node is marked with a feature number |i|.  The value of the |i|-th
// feature of the instance is then used to select which outgoing edge is
// traversed.  This repeats until arriving at a leaf, which has a distribution
// over target values that is the prediction.  The tree structure, including the
// feature index at each node and the distribution at each leaf, is chosen once
// when the tree is trained.
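//
// As a sketch, the prediction walk looks roughly like this (|TreeNode| and its
// members are hypothetical, not this implementation's actual types):
//
//   const TreeNode* node = root;
//   while (!node->is_leaf) {
//     FeatureValue v = instance.features[node->feature_index];
//     node = node->ChildForValue(v);  // Follow the outgoing edge for |v|.
//   }
//   return node->target_distribution;  // Distribution over target values.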
//
// Training involves starting with a set of training examples, each of which has
// features and a target value.  The tree is constructed recursively, starting
// with the root.  For the node being constructed, the training algorithm is
// given the portion of the training set that would reach the node, if it were
// sent down the tree as described above.  It then considers assigning each
// (unused) feature index as the index to split the training examples at this
// node.  For each index |i|, it groups the training set into subsets, each of
// which consists of those examples with the same value of the |i|-th feature.
// It then computes a score for the split using the target values that ended up
// in each group.  The index with the best score is chosen for the split.
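//
// As a sketch (hypothetical names; the real work happens in ConstructSplit(),
// declared below), the grouping step for feature |i| looks roughly like:
//
//   std::map<FeatureValue, std::vector<size_t>> groups;
//   for (size_t idx : training_idx)
//     groups[examples[idx].features[i]].push_back(idx);
//   // Score each group's target values, then combine the per-group scores
//   // (weighted by group size) into one score for feature |i|.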
//
// For nominal features, we split the feature into all of its nominal values.
// This is somewhat nonstandard; one would normally convert to one-hot numeric
// features first.  See OneHotConverter if you'd like to do this.
//
// For numeric features, we choose a split point uniformly at random between its
// min and max values in the training data.  We do this because it's suitable
// for extra trees.  RandomForest trees want to select the best split point for
// each feature, rather than choosing one uniformly at random.  Either way, of
// course, we choose the best split among the (feature, split point) pairs we're
// considering.
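//
// A sketch of that choice (|uniform_double| is a hypothetical helper that
// draws uniformly from [lo, hi)):
//
//   double lo = /* min of feature |i| over the training examples */;
//   double hi = /* max of feature |i| over the training examples */;
//   FeatureValue split_point(uniform_double(lo, hi));
//   // Branch 0 gets examples with value <= |split_point|; branch 1 gets the
//   // rest.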
//
// Also note that for one-hot features, these are the same thing.  So, this
// implementation is suitable for extra trees with numeric (possibly one-hot)
// features, or RF with one-hot nominal features.  Note that non-one-hot nominal
// features probably work fine with RF too.  Numeric, non-binary features don't
// work with RF, unless one changes the split point selection.
//
// The training algorithm then recurses to build child nodes.  One child node is
// created for each observed value of the |i|-th feature in the training set.
// The child node is trained using the subset of the training set that shares
// that node's value for feature |i|.
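//
// Sketched in terms of Build() and the Split struct declared below
// (|children| and |used_set_plus_i| are hypothetical names):
//
//   for (auto& branch : split.branch_infos) {
//     children[branch.first] =
//         Build(task, training_data, branch.second.training_idx,
//               used_set_plus_i);
//   }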
//
// The above is a generic decision tree training algorithm.  A RandomTree
// differs from that mostly in how it selects the feature to split at each node
// during training.  Rather than computing a score for each feature, a
// RandomTree chooses a random subset of the features and only compares those.
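//
// A sketch of that selection (|features_per_split| and the picker are
// hypothetical names):
//
//   FeatureSet candidates;
//   while (candidates.size() < features_per_split)
//     candidates.insert(pick_random_feature_index());  // Uniform draw.
//   // Only the features in |candidates| are scored for this node's split.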
//
// See https://en.wikipedia.org/wiki/Random_forest for information.  Note that
// this is just a single tree, not the whole forest.
//
// Note that this variant chooses split points randomly, as described by the
// ExtraTrees algorithm.  This is slightly different from RandomForest, which
// chooses split points to improve the split's score.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
    : public TrainingAlgorithm,
      public HasRandomNumberGenerator {
 public:
  explicit RandomTreeTrainer(RandomNumberGenerator* rng = nullptr);
  ~RandomTreeTrainer() override;

  // Train on all examples.  Calls |model_cb| with the trained model, which
  // won't happen before this returns.
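  //
  // A sketch of expected use (assuming TrainedModelCB is a base::OnceCallback
  // that takes the trained std::unique_ptr<Model>):
  //
  //   RandomTreeTrainer trainer;
  //   trainer.Train(task, examples,
  //                 base::BindOnce([](std::unique_ptr<Model> model) {
  //                   // Use |model| to predict target distributions.
  //                 }));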
  void Train(const LearningTask& task,
             const TrainingData& examples,
             TrainedModelCB model_cb) override;

 private:
  // Train on the subset |training_idx|.
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& examples,
                               const std::vector<size_t>& training_idx);

  // Set of feature indices.
  using FeatureSet = std::set<int>;

  // Information about a proposed split, and the training sets that would result
  // from that split.
  struct Split {
    Split();
    explicit Split(int index);
    Split(Split&& rhs);
    ~Split();

    Split& operator=(Split&& rhs);

    // Feature index to split on.
    size_t split_index = 0;

    // For numeric splits, branch 0 is <= |split_point|, and branch 1 is
    // > |split_point|.
    FeatureValue split_point;

    // Expected nats needed to compute the class, given that we're at this
    // node in the tree.
    // "nat" == entropy measured with natural log rather than base-2.
    double nats_remaining = std::numeric_limits<double>::infinity();
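    // For example, an even two-class split leaves
    // -(0.5 * ln(0.5) + 0.5 * ln(0.5)) = ln(2) ≈ 0.693 nats remaining.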

    // Per-branch (i.e. per-child node) information about this split.
    struct BranchInfo {
      explicit BranchInfo();
      BranchInfo(const BranchInfo& rhs) = delete;
      BranchInfo(BranchInfo&& rhs);
      ~BranchInfo();

      BranchInfo& operator=(const BranchInfo& rhs) = delete;
      BranchInfo& operator=(BranchInfo&& rhs) = delete;

      // Training set for this branch of the split.  |training_idx| holds the
      // indices that we're using out of our training data.
      std::vector<size_t> training_idx;

      // Number of occurrences of each target value in |training_data| along
      // this branch of the split.
      // This is a flat_map since we're likely to have a very small (e.g.,
      // "true" / "false") number of targets.
      TargetHistogram target_histogram;
    };

    // [feature value at this split] = info about which examples take this
    // branch of the split.
    std::map<FeatureValue, BranchInfo> branch_infos;

    DISALLOW_COPY_AND_ASSIGN(Split);
  };

  // Build this node from |training_data|.  |used_set| is the set of features
  // that we already used higher in the tree.
  std::unique_ptr<Model> Build(const LearningTask& task,
                               const TrainingData& training_data,
                               const std::vector<size_t>& training_idx,
                               const FeatureSet& used_set);

  // Compute and return a split of |training_data| on the |index|-th feature.
  Split ConstructSplit(const LearningTask& task,
                       const TrainingData& training_data,
                       const std::vector<size_t>& training_idx,
                       int index);

  // Fill in |nats_remaining| for |split| for a nominal target.
  // |total_incoming_weight| is the total weight of all instances coming into
  // the node that we're splitting.
  void ComputeSplitScore_Nominal(Split* split, double total_incoming_weight);

  // Fill in |nats_remaining| for |split| for a numeric target.
  void ComputeSplitScore_Numeric(Split* split, double total_incoming_weight);

  // Compute the split point for |training_data| for a nominal feature.
  FeatureValue FindSplitPoint_Nominal(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);

  // Compute the split point for |training_data| for a numeric feature.
  FeatureValue FindSplitPoint_Numeric(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);

  DISALLOW_COPY_AND_ASSIGN(RandomTreeTrainer);
};

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_