// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_

#include <limits>
#include <map>
#include <memory>
#include <set>

#include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/random_number_generator.h"
#include "media/learning/impl/training_algorithm.h"

namespace media {
namespace learning {
// Trains RandomTree decision tree classifier / regressor.
//
// Decision trees, including RandomTree, classify instances as follows. Each
// non-leaf node is marked with a feature number |i|. The value of the |i|-th
// feature of the instance is then used to select which outgoing edge is
// traversed. This repeats until arriving at a leaf, which has a distribution
// over target values that is the prediction. The tree structure, including
// the feature index at each node and distribution at each leaf, is chosen
// once when the tree is trained.
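//
// For instance (with made-up feature and target names), a trained tree might
// split on feature 0 and store a target distribution at each leaf:
//
//               [feature 0]
//              /           \
//          "vp9"          "h264"
//            |               |
//    {smooth: 0.9,    {smooth: 0.4,
//     not: 0.1}        not: 0.6}
//
// An instance whose 0-th feature is "vp9" follows the left edge, and the
// prediction is that leaf's distribution.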
//
// Training involves starting with a set of training examples, each of which
// has features and a target value. The tree is constructed recursively,
// starting with the root. For the node being constructed, the training
// algorithm is given the portion of the training set that would reach the
// node, if it were sent down the tree in a similar fashion as described
// above. It then considers assigning each (unused) feature index as the index
// to split the training examples at this node. For each index |i|, it groups
// the training set into subsets, each of which consists of those examples
// with the same value of the |i|-th feature. It then computes a score for the
// split using the target values that ended up in each group. The index with
// the best score is chosen for the split.
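//
// A sketch of the scoring idea (illustrative only; the real computation lives
// in ComputeSplitScore_Nominal() / ComputeSplitScore_Numeric() in the .cc
// file, and Entropy() here is a hypothetical helper). The score is the
// weighted average entropy of the per-branch target histograms:
//
//   double nats = 0;
//   for (const auto& entry : split.branch_infos) {
//     const Split::BranchInfo& branch = entry.second;
//     double branch_weight = branch.target_histogram.total_counts();
//     nats += (branch_weight / total_incoming_weight) *
//             Entropy(branch.target_histogram);
//   }
//   split.nats_remaining = nats;  // Lower is better.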
//
// For nominal features, we split the feature into all of its nominal values.
// This is somewhat nonstandard; one would normally convert to one-hot numeric
// features first. See OneHotConverter if you'd like to do this.
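//
// For example (feature names illustrative only), a nominal feature with
// values {"vp8", "vp9", "h264"} splits three ways here, whereas
// OneHotConverter would first replace it with three binary features
// ("is vp8?", "is vp9?", "is h264?"), each of which splits two ways.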
//
// For numeric features, we choose a split point uniformly at random between
// its min and max values in the training data. We do this because it's
// suitable for extra trees. RandomForest trees want to select the best split
// point for each feature, rather than uniformly. Either way, of course, we
// choose the best split among the (feature, split point) pairs we're
// considering.
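//
// A minimal sketch of that random choice (assuming RandomNumberGenerator
// provides a GenerateDouble(range) helper; see random_number_generator.h for
// the actual interface):
//
//   double lo = /* min of feature |i| over the incoming training set */;
//   double hi = /* max of feature |i| over the incoming training set */;
//   FeatureValue split_point(lo + rng()->GenerateDouble(hi - lo));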
//
// Also note that for one-hot (binary) features, the two strategies coincide:
// with only two values, every split point between them produces the same
// partition. So, this implementation is suitable for extra trees with numeric
// (possibly one-hot) features, or RF with one-hot nominal features. Note that
// non-one-hot nominal features probably work fine with RF too. Numeric,
// non-binary features don't work with RF, unless one changes the split point
// selection.
//
// The training algorithm then recurses to build child nodes. One child node
// is created for each observed value of the |i|-th feature in the training
// set. The child node is trained using the subset of the training set that
// shares that node's value for feature |i|.
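//
// Roughly (pseudocode only; the real logic is in Build(), declared below):
//
//   Build(task, training_idx, used_set):
//     if all targets agree, or every feature index is in |used_set|:
//       return a leaf holding the target distribution of |training_idx|
//     pick the best Split over a random subset of unused feature indices
//     for each branch of that split:
//       child = Build(task, branch.training_idx, used_set + split_index)
//     return an interior node with those children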
//
// The above is a generic decision tree training algorithm. A RandomTree
// differs from that mostly in how it selects the feature to split at each
// node during training. Rather than computing a score for each feature, a
// RandomTree chooses a random subset of the features and only compares those.
//
// See https://en.wikipedia.org/wiki/Random_forest for information. Note that
// this is just a single tree, not the whole forest.
//
// Note that this variant chooses split points randomly, as described by the
// ExtraTrees algorithm. This is slightly different than RandomForest, which
// chooses split points to improve the split's score.
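//
// Example use (a minimal sketch; it assumes TrainedModelCB is a
// base::OnceCallback taking a std::unique_ptr<Model>, as declared in
// training_algorithm.h):
//
//   RandomTreeTrainer trainer;
//   trainer.Train(task, examples,
//                 base::BindOnce([](std::unique_ptr<Model> model) {
//                   // |model| can now make predictions.
//                 }));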
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
    : public TrainingAlgorithm,
      public HasRandomNumberGenerator {
 public:
  explicit RandomTreeTrainer(RandomNumberGenerator* rng = nullptr);
  ~RandomTreeTrainer() override;

  // Train on all examples. Calls |model_cb| with the trained model, which
  // won't happen before this returns.
  void Train(const LearningTask& task,
             const TrainingData& examples,
             TrainedModelCB model_cb) override;

 private:
  // Train on the subset |training_idx|.
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& examples,
                               const std::vector<size_t>& training_idx);

  // Set of feature indices.
  using FeatureSet = std::set<int>;
  // Information about a proposed split, and the training sets that would
  // result from that split.
  struct Split {
    Split();
    explicit Split(int index);
    Split(Split&& rhs);
    ~Split();

    Split& operator=(Split&& rhs);

    // Feature index to split on.
    size_t split_index = 0;

    // For numeric splits, branch 0 is <= |split_point|, and branch 1 is
    // > |split_point|.
    FeatureValue split_point;

    // Expected nats needed to compute the class, given that we're at this
    // node in the tree.
    // "nat" == entropy measured with natural log rather than base-2.
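    // E.g., a uniform two-class distribution costs ln(2) ~= 0.693 nats to
    // resolve, while a pure single-class distribution costs 0 nats.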
    double nats_remaining = std::numeric_limits<double>::infinity();

    // Per-branch (i.e., per-child node) information about this split.
    struct BranchInfo {
      explicit BranchInfo();
      BranchInfo(const BranchInfo& rhs) = delete;
      BranchInfo(BranchInfo&& rhs);
      ~BranchInfo();

      BranchInfo& operator=(const BranchInfo& rhs) = delete;
      BranchInfo& operator=(BranchInfo&& rhs) = delete;

      // Training set for this branch of the split. |training_idx| holds the
      // indices that we're using out of our training data.
      std::vector<size_t> training_idx;

      // Number of occurrences of each target value in |training_data| along
      // this branch of the split.
      // This is a flat_map since we're likely to have a very small (e.g.,
      // "true" / "false") number of targets.
      TargetHistogram target_histogram;
    };

    // [feature value at this split] = info about which examples take this
    // branch of the split.
    std::map<FeatureValue, BranchInfo> branch_infos;

    DISALLOW_COPY_AND_ASSIGN(Split);
  };

  // Build this node from |training_data|. |used_set| is the set of features
  // that we already used higher in the tree.
  std::unique_ptr<Model> Build(const LearningTask& task,
                               const TrainingData& training_data,
                               const std::vector<size_t>& training_idx,
                               const FeatureSet& used_set);

  // Compute and return a split of |training_data| on the |index|-th feature.
  Split ConstructSplit(const LearningTask& task,
                       const TrainingData& training_data,
                       const std::vector<size_t>& training_idx,
                       int index);

  // Fill in |nats_remaining| for |split| for a nominal target.
  // |total_incoming_weight| is the total weight of all instances coming into
  // the node that we're splitting.
  void ComputeSplitScore_Nominal(Split* split, double total_incoming_weight);

  // Fill in |nats_remaining| for |split| for a numeric target.
  void ComputeSplitScore_Numeric(Split* split, double total_incoming_weight);

  // Compute the split point for |training_data| for a nominal feature.
  FeatureValue FindSplitPoint_Nominal(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);

  // Compute the split point for |training_data| for a numeric feature.
  FeatureValue FindSplitPoint_Numeric(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);

  DISALLOW_COPY_AND_ASSIGN(RandomTreeTrainer);
};

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_