1 /**
2 * @file methods/decision_tree/decision_tree_main.cpp
3 * @author Ryan Curtin
4 *
5 * A command-line program to build a decision tree.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #include <mlpack/prereqs.hpp>
13 #include <mlpack/core/util/io.hpp>
14 #include <mlpack/core/util/mlpack_main.hpp>
15 #include "decision_tree.hpp"
16
17 using namespace std;
18 using namespace mlpack;
19 using namespace mlpack::tree;
20 using namespace mlpack::data;
21 using namespace mlpack::util;
22
23 // Program Name.
24 BINDING_NAME("Decision tree");
25
26 // Short description.
27 BINDING_SHORT_DESC(
28 "An implementation of an ID3-style decision tree for classification, which"
29 " supports categorical data. Given labeled data with numeric or "
30 "categorical features, a decision tree can be trained and saved; or, an "
31 "existing decision tree can be used for classification on new points.");
32
33 // Long description.
34 BINDING_LONG_DESC(
35 "Train and evaluate using a decision tree. Given a dataset containing "
36 "numeric or categorical features, and associated labels for each point in "
37 "the dataset, this program can train a decision tree on that data."
38 "\n\n"
39 "The training set and associated labels are specified with the " +
40 PRINT_PARAM_STRING("training") + " and " + PRINT_PARAM_STRING("labels") +
41 " parameters, respectively. The labels should be in the range [0, "
42 "num_classes - 1]. Optionally, if " +
43 PRINT_PARAM_STRING("labels") + " is not specified, the labels are assumed "
44 "to be the last dimension of the training dataset."
45 "\n\n"
46 "When a model is trained, the " + PRINT_PARAM_STRING("output_model") + " "
47 "output parameter may be used to save the trained model. A model may be "
48 "loaded for predictions with the " + PRINT_PARAM_STRING("input_model") +
49 " parameter. The " + PRINT_PARAM_STRING("input_model") + " parameter "
50 "may not be specified when the " + PRINT_PARAM_STRING("training") + " "
51 "parameter is specified. The " + PRINT_PARAM_STRING("minimum_leaf_size") +
52 " parameter specifies the minimum number of training points that must fall"
53 " into each leaf for it to be split. The " +
54 PRINT_PARAM_STRING("minimum_gain_split") + " parameter specifies "
55 "the minimum gain that is needed for the node to split. The " +
56 PRINT_PARAM_STRING("maximum_depth") + " parameter specifies "
57 "the maximum depth of the tree. If " +
58 PRINT_PARAM_STRING("print_training_error") + " is specified, the training "
59 "error will be printed."
60 "\n\n"
61 "Test data may be specified with the " + PRINT_PARAM_STRING("test") + " "
62 "parameter, and if performance numbers are desired for that test set, "
63 "labels may be specified with the " + PRINT_PARAM_STRING("test_labels") +
64 " parameter. Predictions for each test point may be saved via the " +
65 PRINT_PARAM_STRING("predictions") + " output parameter. Class "
66 "probabilities for each prediction may be saved with the " +
67 PRINT_PARAM_STRING("probabilities") + " output parameter.");
68
69 // Example.
70 BINDING_EXAMPLE(
71 "For example, to train a decision tree with a minimum leaf size of 20 on "
72 "the dataset contained in " + PRINT_DATASET("data") + " with labels " +
73 PRINT_DATASET("labels") + ", saving the output model to " +
74 PRINT_MODEL("tree") + " and printing the training error, one could "
75 "call"
76 "\n\n" +
77 PRINT_CALL("decision_tree", "training", "data", "labels", "labels",
78 "output_model", "tree", "minimum_leaf_size", 20, "minimum_gain_split",
79 1e-3, "print_training_accuracy", true) +
80 "\n\n"
81 "Then, to use that model to classify points in " +
82 PRINT_DATASET("test_set") + " and print the test error given the "
83 "labels " + PRINT_DATASET("test_labels") + " using that model, while "
84 "saving the predictions for each point to " +
85 PRINT_DATASET("predictions") + ", one could call "
86 "\n\n" +
87 PRINT_CALL("decision_tree", "input_model", "tree", "test", "test_set",
88 "test_labels", "test_labels", "predictions", "predictions"));
89
90 // See also...
91 BINDING_SEE_ALSO("Decision stump", "#decision_stump");
92 BINDING_SEE_ALSO("Random forest", "#random_forest");
93 BINDING_SEE_ALSO("Decision trees on Wikipedia",
94 "https://en.wikipedia.org/wiki/Decision_tree_learning");
95 BINDING_SEE_ALSO("Induction of Decision Trees (pdf)",
96 "https://link.springer.com/content/pdf/10.1007/BF00116251.pdf");
97 BINDING_SEE_ALSO("mlpack::tree::DecisionTree class documentation",
98 "@doxygen/classmlpack_1_1tree_1_1DecisionTree.html");
99
100 // Datasets.
101 PARAM_MATRIX_AND_INFO_IN("training", "Training dataset (may be categorical).",
102 "t");
103 PARAM_UROW_IN("labels", "Training labels.", "l");
104 PARAM_MATRIX_AND_INFO_IN("test", "Testing dataset (may be categorical).", "T");
105 PARAM_MATRIX_IN("weights", "The weight of labels", "w");
106 PARAM_UROW_IN("test_labels", "Test point labels, if accuracy calculation "
107 "is desired.", "L");
108
109 // Training parameters.
110 PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
111 20);
112 PARAM_DOUBLE_IN("minimum_gain_split", "Minimum gain for node splitting.", "g",
113 1e-7);
114 PARAM_INT_IN("maximum_depth", "Maximum depth of the tree (0 means no limit).",
115 "D", 0);
116 // This is deprecated and should be removed in mlpack 4.0.0.
117 PARAM_FLAG("print_training_error", "Print the training error (deprecated; will "
118 "be removed in mlpack 4.0.0).", "e");
119 PARAM_FLAG("print_training_accuracy", "Print the training accuracy.", "a");
120
121 // Output parameters.
122 PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
123 "P");
124 PARAM_UROW_OUT("predictions", "Class predictions for each test point.", "p");
125
126 /**
127 * This is the class that we will serialize. It is a pretty simple wrapper
128 * around DecisionTree<>.
129 */
130 class DecisionTreeModel
131 {
132 public:
133 // The tree itself, left public for direct access by this program.
134 DecisionTree<> tree;
135 DatasetInfo info;
136
137 // Create the model.
DecisionTreeModel()138 DecisionTreeModel() { /* Nothing to do. */ }
139
140 // Serialize the model.
141 template<typename Archive>
serialize(Archive & ar,const unsigned int)142 void serialize(Archive& ar, const unsigned int /* version */)
143 {
144 ar & BOOST_SERIALIZATION_NVP(tree);
145 ar & BOOST_SERIALIZATION_NVP(info);
146 }
147 };
148
149 // Models.
150 PARAM_MODEL_IN(DecisionTreeModel, "input_model", "Pre-trained decision tree, "
151 "to be used with test points.", "m");
152 PARAM_MODEL_OUT(DecisionTreeModel, "output_model", "Output for trained decision"
153 " tree.", "M");
154
155 // Convenience typedef.
156 typedef tuple<DatasetInfo, arma::mat> TupleType;
157
mlpackMain()158 static void mlpackMain()
159 {
160 // Check parameters.
161 RequireOnlyOnePassed({ "training", "input_model" }, true);
162 ReportIgnoredParam({{ "test", false }}, "test_labels");
163 RequireAtLeastOnePassed({ "output_model", "probabilities", "predictions" },
164 false, "no output will be saved");
165 ReportIgnoredParam({{ "training", false }}, "print_training_accuracy");
166
167 ReportIgnoredParam({{ "test", false }}, "predictions");
168 ReportIgnoredParam({{ "test", false }}, "predictions");
169
170 RequireParamValue<int>("minimum_leaf_size", [](int x) { return x > 0; }, true,
171 "leaf size must be positive");
172
173 RequireParamValue<int>("maximum_depth", [](int x) { return x >= 0; }, true,
174 "maximum depth must not be negative");
175
176 RequireParamValue<double>("minimum_gain_split", [](double x)
177 { return (x > 0.0 && x < 1.0); }, true,
178 "gain split must be a fraction in range [0,1]");
179
180 if (IO::HasParam("print_training_error"))
181 {
182 Log::Warn << "The option " << PRINT_PARAM_STRING("print_training_error")
183 << " is deprecated and will be removed in mlpack 4.0.0." << std::endl;
184 }
185
186 // Load the model or build the tree.
187 DecisionTreeModel* model;
188 arma::mat trainingSet;
189 arma::Row<size_t> labels;
190
191 if (IO::HasParam("training"))
192 {
193 model = new DecisionTreeModel();
194 model->info = std::move(std::get<0>(IO::GetParam<TupleType>("training")));
195 trainingSet = std::move(std::get<1>(IO::GetParam<TupleType>("training")));
196 if (IO::HasParam("labels"))
197 {
198 labels = std::move(IO::GetParam<arma::Row<size_t>>("labels"));
199 }
200 else
201 {
202 // Extract the labels as the last
203 Log::Info << "Using the last dimension of training set as labels."
204 << endl;
205 labels = arma::conv_to<arma::Row<size_t>>::from(
206 trainingSet.row(trainingSet.n_rows - 1));
207 trainingSet.shed_row(trainingSet.n_rows - 1);
208 }
209
210 const size_t numClasses = arma::max(arma::max(labels)) + 1;
211
212 // Now build the tree.
213 const size_t minLeafSize = (size_t) IO::GetParam<int>("minimum_leaf_size");
214 const size_t maxDepth = (size_t) IO::GetParam<int>("maximum_depth");
215 const double minimumGainSplit =
216 (double) IO::GetParam<double>("minimum_gain_split");
217
218 // Create decision tree with weighted labels.
219 if (IO::HasParam("weights"))
220 {
221 arma::Row<double> weights =
222 std::move(IO::GetParam<arma::Mat<double>>("weights"));
223 if (IO::HasParam("print_training_error") ||
224 IO::HasParam("print_training_accuracy"))
225 {
226 model->tree = DecisionTree<>(trainingSet, model->info, labels,
227 numClasses, std::move(weights), minLeafSize, minimumGainSplit,
228 maxDepth);
229 }
230 else
231 {
232 model->tree = DecisionTree<>(std::move(trainingSet), model->info,
233 std::move(labels), numClasses, std::move(weights), minLeafSize,
234 minimumGainSplit, maxDepth);
235 }
236 }
237 else
238 {
239 if (IO::HasParam("print_training_error"))
240 {
241 model->tree = DecisionTree<>(trainingSet, model->info, labels,
242 numClasses, minLeafSize, minimumGainSplit, maxDepth);
243 }
244 else
245 {
246 model->tree = DecisionTree<>(std::move(trainingSet), model->info,
247 std::move(labels), numClasses, minLeafSize, minimumGainSplit,
248 maxDepth);
249 }
250 }
251
252 // Do we need to print training error?
253 if (IO::HasParam("print_training_error") ||
254 IO::HasParam("print_training_accuracy"))
255 {
256 arma::Row<size_t> predictions;
257 arma::mat probabilities;
258
259 model->tree.Classify(trainingSet, predictions, probabilities);
260
261 size_t correct = 0;
262 for (size_t i = 0; i < trainingSet.n_cols; ++i)
263 if (predictions[i] == labels[i])
264 ++correct;
265
266 // Print number of correct points.
267 Log::Info << double(correct) / double(trainingSet.n_cols) * 100 << "% "
268 << "correct on training set (" << correct << " / "
269 << trainingSet.n_cols << ")." << endl;
270 }
271 }
272 else
273 {
274 model = IO::GetParam<DecisionTreeModel*>("input_model");
275 }
276
277 // Do we need to get predictions?
278 if (IO::HasParam("test"))
279 {
280 std::get<0>(IO::GetRawParam<TupleType>("test")) = model->info;
281 arma::mat testPoints = std::get<1>(IO::GetParam<TupleType>("test"));
282
283 arma::Row<size_t> predictions;
284 arma::mat probabilities;
285
286 model->tree.Classify(testPoints, predictions, probabilities);
287
288 // Do we need to calculate accuracy?
289 if (IO::HasParam("test_labels"))
290 {
291 arma::Row<size_t> testLabels =
292 std::move(IO::GetParam<arma::Row<size_t>>("test_labels"));
293
294 size_t correct = 0;
295 for (size_t i = 0; i < testPoints.n_cols; ++i)
296 if (predictions[i] == testLabels[i])
297 ++correct;
298
299 // Print number of correct points.
300 Log::Info << double(correct) / double(testPoints.n_cols) * 100 << "% "
301 << "correct on test set (" << correct << " / " << testPoints.n_cols
302 << ")." << endl;
303 }
304
305 // Do we need to save outputs?
306 IO::GetParam<arma::Row<size_t>>("predictions") = predictions;
307 IO::GetParam<arma::mat>("probabilities") = probabilities;
308 }
309
310 // Do we need to save the model?
311 IO::GetParam<DecisionTreeModel*>("output_model") = model;
312 }
313