1 /**
2 * @file methods/hoeffding_trees/hoeffding_tree_main.cpp
3 * @author Ryan Curtin
4 *
5 * A command-line executable that can build a streaming decision tree.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #include <mlpack/prereqs.hpp>
13 #include <mlpack/core/util/io.hpp>
14 #include <mlpack/core/util/mlpack_main.hpp>
15
16 #include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
17 #include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
18 #include <mlpack/methods/hoeffding_trees/information_gain.hpp>
19 #include <mlpack/methods/hoeffding_trees/hoeffding_tree_model.hpp>
20 #include <queue>
21
22 using namespace std;
23 using namespace mlpack;
24 using namespace mlpack::tree;
25 using namespace mlpack::data;
26 using namespace mlpack::util;
27
28 // Program Name.
29 BINDING_NAME("Hoeffding trees");
30
31 // Short description.
32 BINDING_SHORT_DESC(
33 "An implementation of Hoeffding trees, a form of streaming decision tree "
34 "for classification. Given labeled data, a Hoeffding tree can be trained "
35 "and saved for later use, or a pre-trained Hoeffding tree can be used for "
36 "predicting the classifications of new points.");
37
38 // Long description.
39 BINDING_LONG_DESC(
40 "This program implements Hoeffding trees, a form of streaming decision tree"
41 " suited best for large (or streaming) datasets. This program supports "
42 "both categorical and numeric data. Given an input dataset, this program "
43 "is able to train the tree with numerous training options, and save the "
44 "model to a file. The program is also able to use a trained model or a "
45 "model from file in order to predict classes for a given test set."
46 "\n\n"
47 "The training file and associated labels are specified with the " +
48 PRINT_PARAM_STRING("training") + " and " + PRINT_PARAM_STRING("labels") +
49 " parameters, respectively. Optionally, if " +
50 PRINT_PARAM_STRING("labels") + " is not specified, the labels are assumed "
51 "to be the last dimension of the training dataset."
52 "\n\n"
53 "The training may be performed in batch mode "
54 "(like a typical decision tree algorithm) by specifying the " +
55 PRINT_PARAM_STRING("batch_mode") + " option, but this may not be the best "
56 "option for large datasets."
57 "\n\n"
58 "When a model is trained, it may be saved via the " +
59 PRINT_PARAM_STRING("output_model") + " output parameter. A model may be "
60 "loaded from file for further training or testing with the " +
61 PRINT_PARAM_STRING("input_model") + " parameter."
62 "\n\n"
63 "Test data may be specified with the " + PRINT_PARAM_STRING("test") + " "
64 "parameter, and if performance statistics are desired for that test set, "
65 "labels may be specified with the " + PRINT_PARAM_STRING("test_labels") +
66 " parameter. Predictions for each test point may be saved with the " +
67 PRINT_PARAM_STRING("predictions") + " output parameter, and class "
68 "probabilities for each prediction may be saved with the " +
69 PRINT_PARAM_STRING("probabilities") + " output parameter.");
70
71 // Example.
72 BINDING_EXAMPLE(
73 "For example, to train a Hoeffding tree with confidence 0.99 with data " +
74 PRINT_DATASET("dataset") + ", saving the trained tree to " +
75 PRINT_MODEL("tree") + ", the following command may be used:"
76 "\n\n" +
77 PRINT_CALL("hoeffding_tree", "training", "dataset", "confidence", 0.99,
78 "output_model", "tree") +
79 "\n\n"
80 "Then, this tree may be used to make predictions on the test set " +
81 PRINT_DATASET("test_set") + ", saving the predictions into " +
82 PRINT_DATASET("predictions") + " and the class probabilities into " +
83 PRINT_DATASET("class_probs") + " with the following command: "
84 "\n\n" +
85 PRINT_CALL("hoeffding_tree", "input_model", "tree", "test", "test_set",
86 "predictions", "predictions", "probabilities", "class_probs"));
87
88 // See also...
89 BINDING_SEE_ALSO("@decision_tree", "#decision_tree");
90 BINDING_SEE_ALSO("@random_forest", "#random_forest");
91 BINDING_SEE_ALSO("Mining High-Speed Data Streams (pdf)",
92 "http://dm.cs.washington.edu/papers/vfdt-kdd00.pdf");
93 BINDING_SEE_ALSO("mlpack::tree::HoeffdingTree class documentation",
94 "@doxygen/classmlpack_1_1tree_1_1HoeffdingTree.html");
95
96 PARAM_MATRIX_AND_INFO_IN("training", "Training dataset (may be categorical).",
97 "t");
98 PARAM_UROW_IN("labels", "Labels for training dataset.", "l");
99
100 PARAM_DOUBLE_IN("confidence", "Confidence before splitting (between 0 and 1).",
101 "c", 0.95);
102 PARAM_INT_IN("max_samples", "Maximum number of samples before splitting.", "n",
103 5000);
104 PARAM_INT_IN("min_samples", "Minimum number of samples before splitting.", "I",
105 100);
106
107 PARAM_MODEL_IN(HoeffdingTreeModel, "input_model", "Input trained Hoeffding tree"
108 " model.", "m");
109 PARAM_MODEL_OUT(HoeffdingTreeModel, "output_model", "Output for trained "
110 "Hoeffding tree model.", "M");
111
112 PARAM_MATRIX_AND_INFO_IN("test", "Testing dataset (may be categorical).", "T");
113 PARAM_UROW_IN("test_labels", "Labels of test data.", "L");
114 PARAM_UROW_OUT("predictions", "Matrix to output label predictions for test "
115 "data into.", "p");
116 PARAM_MATRIX_OUT("probabilities", "In addition to predicting labels, provide "
117 "rediction probabilities in this matrix.", "P");
118
119 PARAM_STRING_IN("numeric_split_strategy", "The splitting strategy to use for "
120 "numeric features: 'domingos' or 'binary'.", "N", "binary");
121 PARAM_FLAG("batch_mode", "If true, samples will be considered in batch instead "
122 "of as a stream. This generally results in better trees but at the cost of"
123 " memory usage and runtime.", "b");
124 PARAM_FLAG("info_gain", "If set, information gain is used instead of Gini "
125 "impurity for calculating Hoeffding bounds.", "i");
126 PARAM_INT_IN("passes", "Number of passes to take over the dataset.", "s", 1);
127
128 PARAM_INT_IN("bins", "If the 'domingos' split strategy is used, this specifies "
129 "the number of bins for each numeric split.", "B", 10);
130 PARAM_INT_IN("observations_before_binning", "If the 'domingos' split strategy "
131 "is used, this specifies the number of samples observed before binning is "
132 "performed.", "o", 100);
133
134 // Convenience typedef.
135 typedef tuple<DatasetInfo, arma::mat> TupleType;
136
mlpackMain()137 static void mlpackMain()
138 {
139 // Check input parameters for validity.
140 const string numericSplitStrategy =
141 IO::GetParam<string>("numeric_split_strategy");
142
143 RequireAtLeastOnePassed({ "training", "input_model" }, true);
144
145 RequireAtLeastOnePassed({ "output_model", "predictions", "probabilities",
146 "test_labels" }, false, "no output will be given");
147
148 ReportIgnoredParam({{ "test", false }}, "probabilities");
149 ReportIgnoredParam({{ "test", false }}, "predictions");
150
151 ReportIgnoredParam({{ "training", false }}, "batch_mode");
152 ReportIgnoredParam({{ "training", false }}, "passes");
153
154 if (IO::HasParam("test"))
155 {
156 RequireAtLeastOnePassed({ "predictions", "probabilities", "test_labels" },
157 false, "no output will be given");
158 }
159
160 RequireParamInSet<string>("numeric_split_strategy", { "domingos", "binary" },
161 true, "unrecognized numeric split strategy");
162
163 // Do we need to load a model or do we already have one?
164 HoeffdingTreeModel* model;
165 DatasetInfo datasetInfo;
166 arma::mat trainingSet;
167 arma::Row<size_t> labels;
168 if (IO::HasParam("input_model"))
169 {
170 model = IO::GetParam<HoeffdingTreeModel*>("input_model");
171 }
172 else
173 {
174 // Initialize a model.
175 if (!IO::HasParam("info_gain") && (numericSplitStrategy == "domingos"))
176 model = new HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
177 else if (!IO::HasParam("info_gain") && (numericSplitStrategy == "binary"))
178 model = new HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
179 else if (IO::HasParam("info_gain") && (numericSplitStrategy == "domingos"))
180 model = new HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
181 else
182 model = new HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
183 }
184
185 // Now, do we need to train?
186 if (IO::HasParam("training"))
187 {
188 // Load necessary parameters for training.
189 const double confidence = IO::GetParam<double>("confidence");
190 const size_t maxSamples = (size_t) IO::GetParam<int>("max_samples");
191 const size_t minSamples = (size_t) IO::GetParam<int>("min_samples");
192 bool batchTraining = IO::HasParam("batch_mode");
193 const size_t bins = (size_t) IO::GetParam<int>("bins");
194 const size_t observationsBeforeBinning = (size_t)
195 IO::GetParam<int>("observations_before_binning");
196 size_t passes = (size_t) IO::GetParam<int>("passes");
197 if (passes > 1)
198 batchTraining = false; // We already warned about this earlier.
199
200 // We need to train the model. First, load the data.
201 datasetInfo = std::move(std::get<0>(IO::GetParam<TupleType>("training")));
202 trainingSet = std::move(std::get<1>(IO::GetParam<TupleType>("training")));
203 for (size_t i = 0; i < trainingSet.n_rows; ++i)
204 Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
205 << i << "." << endl;
206
207 if (IO::HasParam("labels"))
208 {
209 labels = std::move(IO::GetParam<arma::Row<size_t>>("labels"));
210 }
211 else
212 {
213 // Extract the labels from the last dimension of training set.
214 Log::Info << "Using the last dimension of training set as labels."
215 << endl;
216 labels = arma::conv_to<arma::Row<size_t>>::from(
217 trainingSet.row(trainingSet.n_rows - 1));
218 trainingSet.shed_row(trainingSet.n_rows - 1);
219 }
220
221 // Next, create the model with the right type. Then build the tree with the
222 // appropriate type of instantiated numeric split type. This is a little
223 // bit ugly. Maybe there is a nicer way to get this numeric split
224 // information to the trees, but this is ok for now.
225 Timer::Start("tree_training");
226
227 // Do we need to initialize a model?
228 if (!IO::HasParam("input_model"))
229 {
230 // Build the model.
231 model->BuildModel(trainingSet, datasetInfo, labels,
232 arma::max(labels) + 1, batchTraining, confidence, maxSamples,
233 100, minSamples, bins, observationsBeforeBinning);
234 --passes; // This model-building takes one pass.
235 }
236
237 // Now pass over the trees as many times as we need to.
238 if (batchTraining)
239 {
240 // We only need to do batch training if we've not already called
241 // BuildModel.
242 if (IO::HasParam("input_model"))
243 model->Train(trainingSet, labels, true);
244 }
245 else
246 {
247 for (size_t p = 0; p < passes; ++p)
248 model->Train(trainingSet, labels, false);
249 }
250
251 Timer::Stop("tree_training");
252 }
253
254 // Do we need to evaluate the training set error?
255 if (IO::HasParam("training"))
256 {
257 // Get training error.
258 arma::Row<size_t> predictions;
259 model->Classify(trainingSet, predictions);
260
261 size_t correct = 0;
262 for (size_t i = 0; i < labels.n_elem; ++i)
263 if (labels[i] == predictions[i])
264 ++correct;
265
266 Log::Info << correct << " out of " << labels.n_elem << " correct "
267 << "on training set (" << double(correct) / double(labels.n_elem) *
268 100.0 << ")." << endl;
269 }
270
271 // Get the number of nodes in the tree.
272 Log::Info << model->NumNodes() << " nodes in the tree." << endl;
273
274 // The tree is trained or loaded. Now do any testing if we need.
275 if (IO::HasParam("test"))
276 {
277 // Before loading, pre-set the dataset info by getting the raw parameter
278 // (that doesn't call data::Load()).
279 std::get<0>(IO::GetRawParam<TupleType>("test")) = datasetInfo;
280 arma::mat testSet = std::get<1>(IO::GetParam<TupleType>("test"));
281
282 arma::Row<size_t> predictions;
283 arma::rowvec probabilities;
284
285 Timer::Start("tree_testing");
286 model->Classify(testSet, predictions, probabilities);
287 Timer::Stop("tree_testing");
288
289 if (IO::HasParam("test_labels"))
290 {
291 arma::Row<size_t> testLabels =
292 std::move(IO::GetParam<arma::Row<size_t>>("test_labels"));
293
294 size_t correct = 0;
295 for (size_t i = 0; i < testLabels.n_elem; ++i)
296 {
297 if (predictions[i] == testLabels[i])
298 ++correct;
299 }
300 Log::Info << correct << " out of " << testLabels.n_elem << " correct "
301 << "on test set (" << double(correct) / double(testLabels.n_elem) *
302 100.0 << ")." << endl;
303 }
304
305 IO::GetParam<arma::Row<size_t>>("predictions") = std::move(predictions);
306 IO::GetParam<arma::mat>("probabilities") = std::move(probabilities);
307 }
308
309 // Check the accuracy on the training set.
310 IO::GetParam<HoeffdingTreeModel*>("output_model") = model;
311 }
312