/**
 * @file methods/decision_tree/decision_tree_main.cpp
 * @author Ryan Curtin
 *
 * A command-line program to build a decision tree.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#include <mlpack/prereqs.hpp>
#include <mlpack/core/util/io.hpp>
#include <mlpack/core/util/mlpack_main.hpp>
#include "decision_tree.hpp"

using namespace std;
using namespace mlpack;
using namespace mlpack::tree;
using namespace mlpack::data;
using namespace mlpack::util;

// Program Name.
BINDING_NAME("Decision tree");

// Short description.
BINDING_SHORT_DESC(
    "An implementation of an ID3-style decision tree for classification, which"
    " supports categorical data.  Given labeled data with numeric or "
    "categorical features, a decision tree can be trained and saved; or, an "
    "existing decision tree can be used for classification on new points.");

// Long description.
BINDING_LONG_DESC(
    "Train and evaluate using a decision tree.  Given a dataset containing "
    "numeric or categorical features, and associated labels for each point in "
    "the dataset, this program can train a decision tree on that data."
    "\n\n"
    "The training set and associated labels are specified with the " +
    PRINT_PARAM_STRING("training") + " and " + PRINT_PARAM_STRING("labels") +
    " parameters, respectively.  The labels should be in the range [0, "
    "num_classes - 1].  Optionally, if " +
    PRINT_PARAM_STRING("labels") + " is not specified, the labels are assumed "
    "to be the last dimension of the training dataset."
    "\n\n"
    "When a model is trained, the " + PRINT_PARAM_STRING("output_model") + " "
    "output parameter may be used to save the trained model.  A model may be "
    "loaded for predictions with the " + PRINT_PARAM_STRING("input_model") +
    " parameter.  The " + PRINT_PARAM_STRING("input_model") + " parameter "
    "may not be specified when the " + PRINT_PARAM_STRING("training") + " "
    "parameter is specified.  The " + PRINT_PARAM_STRING("minimum_leaf_size") +
    " parameter specifies the minimum number of training points that must fall"
    " into each leaf for it to be split.  The " +
    PRINT_PARAM_STRING("minimum_gain_split") + " parameter specifies the "
    "minimum gain that is needed for a node to split.  The " +
    PRINT_PARAM_STRING("maximum_depth") + " parameter specifies the maximum "
    "depth of the tree.  If " +
    PRINT_PARAM_STRING("print_training_accuracy") + " is specified, the "
    "training accuracy will be printed."
    "\n\n"
    "Test data may be specified with the " + PRINT_PARAM_STRING("test") + " "
    "parameter, and if performance numbers are desired for that test set, "
    "labels may be specified with the " + PRINT_PARAM_STRING("test_labels") +
    " parameter.  Predictions for each test point may be saved via the " +
    PRINT_PARAM_STRING("predictions") + " output parameter.  Class "
    "probabilities for each prediction may be saved with the " +
    PRINT_PARAM_STRING("probabilities") + " output parameter.");

// Example.
BINDING_EXAMPLE(
    "For example, to train a decision tree with a minimum leaf size of 20 on "
    "the dataset contained in " + PRINT_DATASET("data") + " with labels " +
    PRINT_DATASET("labels") + ", saving the output model to " +
    PRINT_MODEL("tree") + " and printing the training accuracy, one could "
    "call"
    "\n\n" +
    PRINT_CALL("decision_tree", "training", "data", "labels", "labels",
        "output_model", "tree", "minimum_leaf_size", 20, "minimum_gain_split",
        1e-3, "print_training_accuracy", true) +
    "\n\n"
    "Then, to use that model to classify points in " +
    PRINT_DATASET("test_set") + " and print the test accuracy given the "
    "labels " + PRINT_DATASET("test_labels") + ", while saving the "
    "predictions for each point to " + PRINT_DATASET("predictions") + ", one "
    "could call"
    "\n\n" +
    PRINT_CALL("decision_tree", "input_model", "tree", "test", "test_set",
        "test_labels", "test_labels", "predictions", "predictions"));
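
// For reference, the first call above corresponds roughly to a command-line
// invocation like the one below (a sketch assuming the standard flag names
// that mlpack's CLI bindings generate; exact names vary per binding):
//
//   mlpack_decision_tree --training_file data.csv --labels_file labels.csv \
//       --output_model_file tree.bin --minimum_leaf_size 20 \
//       --minimum_gain_split 1e-3 --print_training_accuracy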

// See also...
BINDING_SEE_ALSO("Decision stump", "#decision_stump");
BINDING_SEE_ALSO("Random forest", "#random_forest");
BINDING_SEE_ALSO("Decision trees on Wikipedia",
        "https://en.wikipedia.org/wiki/Decision_tree_learning");
BINDING_SEE_ALSO("Induction of Decision Trees (pdf)",
        "https://link.springer.com/content/pdf/10.1007/BF00116251.pdf");
BINDING_SEE_ALSO("mlpack::tree::DecisionTree class documentation",
        "@doxygen/classmlpack_1_1tree_1_1DecisionTree.html");

// Datasets.
PARAM_MATRIX_AND_INFO_IN("training", "Training dataset (may be categorical).",
    "t");
PARAM_UROW_IN("labels", "Training labels.", "l");
PARAM_MATRIX_AND_INFO_IN("test", "Testing dataset (may be categorical).", "T");
PARAM_MATRIX_IN("weights", "The weight of each point in the training set.",
    "w");
PARAM_UROW_IN("test_labels", "Test point labels, if accuracy calculation "
    "is desired.", "L");

// Training parameters.
PARAM_INT_IN("minimum_leaf_size", "Minimum number of points in a leaf.", "n",
    20);
PARAM_DOUBLE_IN("minimum_gain_split", "Minimum gain for node splitting.", "g",
    1e-7);
PARAM_INT_IN("maximum_depth", "Maximum depth of the tree (0 means no limit).",
    "D", 0);
// This is deprecated and should be removed in mlpack 4.0.0.
PARAM_FLAG("print_training_error", "Print the training error (deprecated; "
    "will be removed in mlpack 4.0.0).", "e");
PARAM_FLAG("print_training_accuracy", "Print the training accuracy.", "a");

// Output parameters.
PARAM_MATRIX_OUT("probabilities", "Class probabilities for each test point.",
    "P");
PARAM_UROW_OUT("predictions", "Class predictions for each test point.", "p");

/**
 * This is the class that we will serialize.  It is a pretty simple wrapper
 * around DecisionTree<>.
 */
class DecisionTreeModel
{
 public:
  // The tree itself, left public for direct access by this program.
  DecisionTree<> tree;
  DatasetInfo info;

  // Create the model.
  DecisionTreeModel() { /* Nothing to do. */ }

  // Serialize the model.
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */)
  {
    ar & BOOST_SERIALIZATION_NVP(tree);
    ar & BOOST_SERIALIZATION_NVP(info);
  }
};
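
// A minimal usage sketch (illustrative only; this program does the
// equivalent through PARAM_MODEL_IN/PARAM_MODEL_OUT below): the wrapper
// can be saved and restored with mlpack's serialization support, e.g.
//
//   DecisionTreeModel m;
//   mlpack::data::Save("tree.xml", "model", m);  // invokes serialize()
//   mlpack::data::Load("tree.xml", "model", m);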

// Models.
PARAM_MODEL_IN(DecisionTreeModel, "input_model", "Pre-trained decision tree, "
    "to be used with test points.", "m");
PARAM_MODEL_OUT(DecisionTreeModel, "output_model", "Output for trained decision"
    " tree.", "M");

// Convenience typedef.
typedef tuple<DatasetInfo, arma::mat> TupleType;

static void mlpackMain()
{
  // Check parameters.
  RequireOnlyOnePassed({ "training", "input_model" }, true);
  ReportIgnoredParam({{ "test", false }}, "test_labels");
  RequireAtLeastOnePassed({ "output_model", "probabilities", "predictions" },
      false, "no output will be saved");
  ReportIgnoredParam({{ "training", false }}, "print_training_accuracy");

  ReportIgnoredParam({{ "test", false }}, "predictions");
  ReportIgnoredParam({{ "test", false }}, "probabilities");

  RequireParamValue<int>("minimum_leaf_size", [](int x) { return x > 0; }, true,
      "leaf size must be positive");

  RequireParamValue<int>("maximum_depth", [](int x) { return x >= 0; }, true,
      "maximum depth must not be negative");

  RequireParamValue<double>("minimum_gain_split", [](double x)
      { return (x > 0.0 && x < 1.0); }, true,
      "gain split must be a fraction in range (0,1)");

  if (IO::HasParam("print_training_error"))
  {
    Log::Warn << "The option " << PRINT_PARAM_STRING("print_training_error")
        << " is deprecated and will be removed in mlpack 4.0.0; use "
        << PRINT_PARAM_STRING("print_training_accuracy") << " instead."
        << std::endl;
  }

  // Load the model or build the tree.
  DecisionTreeModel* model;
  arma::mat trainingSet;
  arma::Row<size_t> labels;

  if (IO::HasParam("training"))
  {
    model = new DecisionTreeModel();
    model->info = std::move(std::get<0>(IO::GetParam<TupleType>("training")));
    trainingSet = std::move(std::get<1>(IO::GetParam<TupleType>("training")));
    if (IO::HasParam("labels"))
    {
      labels = std::move(IO::GetParam<arma::Row<size_t>>("labels"));
    }
    else
    {
      // Extract the labels from the last dimension of the training set.
      Log::Info << "Using the last dimension of training set as labels."
          << endl;
      labels = arma::conv_to<arma::Row<size_t>>::from(
          trainingSet.row(trainingSet.n_rows - 1));
      trainingSet.shed_row(trainingSet.n_rows - 1);
    }
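
    // Illustration (hypothetical shapes): for a column-major 4 x N training
    // matrix, the first three rows are kept as features and the fourth row,
    // one entry per column (point), becomes the label vector.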

    const size_t numClasses = arma::max(labels) + 1;

    // Now build the tree.
    const size_t minLeafSize = (size_t) IO::GetParam<int>("minimum_leaf_size");
    const size_t maxDepth = (size_t) IO::GetParam<int>("maximum_depth");
    const double minimumGainSplit = IO::GetParam<double>("minimum_gain_split");

    // Create decision tree with weighted labels.
    if (IO::HasParam("weights"))
    {
      arma::Row<double> weights =
          std::move(IO::GetParam<arma::Mat<double>>("weights"));
      if (IO::HasParam("print_training_error") ||
          IO::HasParam("print_training_accuracy"))
      {
        // Copy the training data and labels: they are still needed below to
        // compute the training error or accuracy.
        model->tree = DecisionTree<>(trainingSet, model->info, labels,
            numClasses, std::move(weights), minLeafSize, minimumGainSplit,
            maxDepth);
      }
      else
      {
        model->tree = DecisionTree<>(std::move(trainingSet), model->info,
            std::move(labels), numClasses, std::move(weights), minLeafSize,
            minimumGainSplit, maxDepth);
      }
    }
    else
    {
      if (IO::HasParam("print_training_error") ||
          IO::HasParam("print_training_accuracy"))
      {
        model->tree = DecisionTree<>(trainingSet, model->info, labels,
            numClasses, minLeafSize, minimumGainSplit, maxDepth);
      }
      else
      {
        model->tree = DecisionTree<>(std::move(trainingSet), model->info,
            std::move(labels), numClasses, minLeafSize, minimumGainSplit,
            maxDepth);
      }
    }

    // Do we need to print the training error or accuracy?
    if (IO::HasParam("print_training_error") ||
        IO::HasParam("print_training_accuracy"))
    {
      arma::Row<size_t> predictions;
      arma::mat probabilities;

      model->tree.Classify(trainingSet, predictions, probabilities);

      size_t correct = 0;
      for (size_t i = 0; i < trainingSet.n_cols; ++i)
        if (predictions[i] == labels[i])
          ++correct;

      // Print the training accuracy (percentage of correctly classified
      // training points).
      Log::Info << double(correct) / double(trainingSet.n_cols) * 100 << "% "
          << "correct on training set (" << correct << " / "
          << trainingSet.n_cols << ")." << endl;
    }
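
    // Note: the counting loop above could equivalently be written with
    // Armadillo as `arma::accu(predictions == labels)`; the explicit loop
    // is kept for clarity.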
  }
  else
  {
    model = IO::GetParam<DecisionTreeModel*>("input_model");
  }

  // Do we need to get predictions?
  if (IO::HasParam("test"))
  {
    std::get<0>(IO::GetRawParam<TupleType>("test")) = model->info;
    arma::mat testPoints = std::get<1>(IO::GetParam<TupleType>("test"));

    arma::Row<size_t> predictions;
    arma::mat probabilities;

    model->tree.Classify(testPoints, predictions, probabilities);

    // Do we need to calculate accuracy?
    if (IO::HasParam("test_labels"))
    {
      arma::Row<size_t> testLabels =
          std::move(IO::GetParam<arma::Row<size_t>>("test_labels"));

      size_t correct = 0;
      for (size_t i = 0; i < testPoints.n_cols; ++i)
        if (predictions[i] == testLabels[i])
          ++correct;

      // Print the test set accuracy.
      Log::Info << double(correct) / double(testPoints.n_cols) * 100 << "% "
          << "correct on test set (" << correct << " / " << testPoints.n_cols
          << ")." << endl;
    }

    // Save the outputs, if requested.
    IO::GetParam<arma::Row<size_t>>("predictions") = predictions;
    IO::GetParam<arma::mat>("probabilities") = probabilities;
  }

  // Save the model, if an output parameter was given.
  IO::GetParam<DecisionTreeModel*>("output_model") = model;
}