1 /**
2 * @file methods/softmax_regression/softmax_regression_main.cpp
3 *
4 * Main program for softmax regression.
5 *
6 * mlpack is free software; you may redistribute it and/or modify it under the
7 * terms of the 3-clause BSD license. You should have received a copy of the
8 * 3-clause BSD license along with mlpack. If not, see
9 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10 */
11 #include <mlpack/prereqs.hpp>
12 #include <mlpack/core/util/io.hpp>
13 #include <mlpack/core/util/mlpack_main.hpp>
14
15 #include <mlpack/methods/softmax_regression/softmax_regression.hpp>
16 #include <ensmallen.hpp>
17
18 #include <memory>
19 #include <set>
20
21 using namespace std;
22 using namespace mlpack;
23 using namespace mlpack::regression;
24 using namespace mlpack::util;
25
// Program name, as shown in documentation and --help output.
BINDING_NAME("Softmax Regression");

// Short description, used in binding documentation listings.
BINDING_SHORT_DESC(
    "An implementation of softmax regression for classification, which is a "
    "multiclass generalization of logistic regression. Given labeled data, a "
    "softmax regression model can be trained and saved for future use, or, a "
    "pre-trained softmax regression model can be used for classification of "
    "new points.");

// Long description: full usage documentation for the binding.
BINDING_LONG_DESC(
    "This program performs softmax regression, a generalization of logistic "
    "regression to the multiclass case, and has support for L2 regularization. "
    " The program is able to train a model, load an existing model, and give "
    "predictions (and optionally their accuracy) for test data."
    "\n\n"
    "Training a softmax regression model is done by giving a file of training "
    "points with the " + PRINT_PARAM_STRING("training") + " parameter and their"
    " corresponding labels with the " + PRINT_PARAM_STRING("labels") +
    " parameter. The number of classes can be manually specified with the " +
    PRINT_PARAM_STRING("number_of_classes") + " parameter, and the maximum " +
    "number of iterations of the L-BFGS optimizer can be specified with the " +
    PRINT_PARAM_STRING("max_iterations") + " parameter. The L2 regularization "
    "constant can be specified with the " + PRINT_PARAM_STRING("lambda") +
    " parameter and if an intercept term is not desired in the model, the " +
    PRINT_PARAM_STRING("no_intercept") + " parameter can be specified."
    "\n\n"
    "The trained model can be saved with the " +
    PRINT_PARAM_STRING("output_model") + " output parameter. If training is not"
    " desired, but only testing is, a model can be loaded with the " +
    PRINT_PARAM_STRING("input_model") + " parameter. At the current time, a "
    "loaded model cannot be trained further, so specifying both " +
    PRINT_PARAM_STRING("input_model") + " and " +
    PRINT_PARAM_STRING("training") + " is not allowed."
    "\n\n"
    "The program is also able to evaluate a model on test data. A test dataset"
    " can be specified with the " + PRINT_PARAM_STRING("test") + " parameter. "
    "Class predictions can be saved with the " +
    PRINT_PARAM_STRING("predictions") + " output parameter. If labels are "
    "specified for the test data with the " +
    PRINT_PARAM_STRING("test_labels") + " parameter, then the program will "
    "print the accuracy of the predictions on the given test set and its "
    "corresponding labels.");

// Example invocations, shown in the binding documentation.
BINDING_EXAMPLE(
    "For example, to train a softmax regression model on the data " +
    PRINT_DATASET("dataset") + " with labels " + PRINT_DATASET("labels") +
    " with a maximum of 1000 iterations for training, saving the trained model "
    "to " + PRINT_MODEL("sr_model") + ", the following command can be used: "
    "\n\n" +
    PRINT_CALL("softmax_regression", "training", "dataset", "labels", "labels",
        "output_model", "sr_model") +
    "\n\n"
    "Then, to use " + PRINT_MODEL("sr_model") + " to classify the test points "
    "in " + PRINT_DATASET("test_points") + ", saving the output predictions to"
    " " + PRINT_DATASET("predictions") + ", the following command can be used:"
    "\n\n" +
    PRINT_CALL("softmax_regression", "input_model", "sr_model", "test",
        "test_points", "predictions", "predictions"));

// Cross-references to related bindings and external documentation.
BINDING_SEE_ALSO("@logistic_regression", "#logistic_regression");
BINDING_SEE_ALSO("@random_forest", "#random_forest");
BINDING_SEE_ALSO("Multinomial logistic regression (softmax regression) on "
    "Wikipedia",
    "https://en.wikipedia.org/wiki/Multinomial_logistic_regression");
BINDING_SEE_ALSO("mlpack::regression::SoftmaxRegression C++ class "
    "documentation",
    "@doxygen/classmlpack_1_1regression_1_1SoftmaxRegression.html");
98
99 // Required options.
100 PARAM_MATRIX_IN("training", "A matrix containing the training set (the matrix "
101 "of predictors, X).", "t");
102 PARAM_UROW_IN("labels", "A matrix containing labels (0 or 1) for the points "
103 "in the training set (y). The labels must order as a row.", "l");
104
105 // Model loading/saving.
106 PARAM_MODEL_IN(SoftmaxRegression, "input_model", "File containing existing "
107 "model (parameters).", "m");
108 PARAM_MODEL_OUT(SoftmaxRegression, "output_model", "File to save trained "
109 "softmax regression model to.", "M");
110
111 // Testing.
112 PARAM_MATRIX_IN("test", "Matrix containing test dataset.", "T");
113 PARAM_UROW_OUT("predictions", "Matrix to save predictions for test dataset "
114 "into.", "p");
115 PARAM_UROW_IN("test_labels", "Matrix containing test labels.", "L");
116
117 // Softmax configuration options.
118 PARAM_INT_IN("max_iterations", "Maximum number of iterations before "
119 "termination.", "n", 400);
120
121 PARAM_INT_IN("number_of_classes", "Number of classes for classification; if "
122 "unspecified (or 0), the number of classes found in the labels will be "
123 "used.", "c", 0);
124
125 PARAM_DOUBLE_IN("lambda", "L2-regularization constant", "r", 0.0001);
126
127 PARAM_FLAG("no_intercept", "Do not add the intercept term to the model.", "N");
128
129 // Count the number of classes in the given labels (if numClasses == 0).
130 size_t CalculateNumberOfClasses(const size_t numClasses,
131 const arma::Row<size_t>& trainLabels);
132
133 // Test the accuracy of the model.
134 template<typename Model>
135 void TestClassifyAcc(const size_t numClasses, const Model& model);
136
137 // Build the softmax model given the parameters.
138 template<typename Model>
139 Model* TrainSoftmax(const size_t maxIterations);
140
mlpackMain()141 static void mlpackMain()
142 {
143 const int maxIterations = IO::GetParam<int>("max_iterations");
144
145 // One of inputFile and modelFile must be specified.
146 RequireOnlyOnePassed({ "input_model", "training" }, true);
147 if (IO::HasParam("training"))
148 {
149 RequireAtLeastOnePassed({ "labels" }, true, "if training data is specified,"
150 " labels must also be specified");
151 }
152 ReportIgnoredParam({{ "training", false }}, "labels");
153 ReportIgnoredParam({{ "training", false }}, "max_iterations");
154 ReportIgnoredParam({{ "training", false }}, "number_of_classes");
155 ReportIgnoredParam({{ "training", false }}, "lambda");
156 ReportIgnoredParam({{ "training", false }}, "no_intercept");
157
158 RequireParamValue<int>("max_iterations", [](int x) { return x >= 0; }, true,
159 "maximum number of iterations must be greater than or equal to 0");
160 RequireParamValue<double>("lambda", [](double x) { return x >= 0.0; }, true,
161 "lambda penalty parameter must be greater than or equal to 0");
162 RequireParamValue<int>("number_of_classes", [](int x) { return x >= 0; },
163 true, "number of classes must be greater than or "
164 "equal to 0 (equal to 0 in case of unspecified.)");
165
166 // Make sure we have an output file of some sort.
167 RequireAtLeastOnePassed({ "output_model", "predictions" }, false, "no results"
168 " will be saved");
169
170 SoftmaxRegression* sm = TrainSoftmax<SoftmaxRegression>(maxIterations);
171
172 TestClassifyAcc(sm->NumClasses(), *sm);
173
174 IO::GetParam<SoftmaxRegression*>("output_model") = sm;
175 }
176
// Return the number of classes to use: the user-specified count if nonzero,
// otherwise the number of distinct label values found in trainLabels.
size_t CalculateNumberOfClasses(const size_t numClasses,
                                const arma::Row<size_t>& trainLabels)
{
  // A user-supplied (nonzero) class count takes precedence.
  if (numClasses != 0)
    return numClasses;

  // Otherwise, count the distinct labels that actually appear.
  const set<size_t> distinctLabels(begin(trainLabels), end(trainLabels));
  return distinctLabels.size();
}
191
192 template<typename Model>
TestClassifyAcc(size_t numClasses,const Model & model)193 void TestClassifyAcc(size_t numClasses, const Model& model)
194 {
195 using namespace mlpack;
196
197 // If there is no test set, there is nothing to test on.
198 if (!IO::HasParam("test"))
199 {
200 ReportIgnoredParam({{ "test", false }}, "test_labels");
201 ReportIgnoredParam({{ "test", false }}, "predictions");
202
203 return;
204 }
205
206 // Get the test dataset, and get predictions.
207 arma::mat testData = std::move(IO::GetParam<arma::mat>("test"));
208
209 arma::Row<size_t> predictLabels;
210 model.Classify(testData, predictLabels);
211
212 // Calculate accuracy, if desired.
213 if (IO::HasParam("test_labels"))
214 {
215 arma::Row<size_t> testLabels =
216 std::move(IO::GetParam<arma::Row<size_t>>("test_labels"));
217
218 if (testData.n_cols != testLabels.n_elem)
219 {
220 Log::Fatal << "Test data given with " << PRINT_PARAM_STRING("test")
221 << " has " << testData.n_cols << " points, but labels in "
222 << PRINT_PARAM_STRING("test_labels") << " have " << testLabels.n_elem
223 << " labels!" << endl;
224 }
225
226 vector<size_t> bingoLabels(numClasses, 0);
227 vector<size_t> labelSize(numClasses, 0);
228 for (arma::uword i = 0; i != predictLabels.n_elem; ++i)
229 {
230 if (predictLabels(i) == testLabels(i))
231 {
232 ++bingoLabels[testLabels(i)];
233 }
234 ++labelSize[testLabels(i)];
235 }
236
237 size_t totalBingo = 0;
238 for (size_t i = 0; i != bingoLabels.size(); ++i)
239 {
240 Log::Info << "Accuracy for points with label " << i << " is "
241 << (bingoLabels[i] / static_cast<double>(labelSize[i])) << " ("
242 << bingoLabels[i] << " of " << labelSize[i] << ")." << endl;
243 totalBingo += bingoLabels[i];
244 }
245
246 Log::Info << "Total accuracy for all points is "
247 << (totalBingo) / static_cast<double>(predictLabels.n_elem) << " ("
248 << totalBingo << " of " << predictLabels.n_elem << ")." << endl;
249 }
250 // Save predictions, if desired.
251 if (IO::HasParam("predictions"))
252 IO::GetParam<arma::Row<size_t>>("predictions") = std::move(predictLabels);
253 }
254
255 template<typename Model>
TrainSoftmax(const size_t maxIterations)256 Model* TrainSoftmax(const size_t maxIterations)
257 {
258 using namespace mlpack;
259
260 Model* sm;
261 if (IO::HasParam("input_model"))
262 {
263 sm = IO::GetParam<Model*>("input_model");
264 }
265 else
266 {
267 arma::mat trainData = std::move(IO::GetParam<arma::mat>("training"));
268 arma::Row<size_t> trainLabels =
269 std::move(IO::GetParam<arma::Row<size_t>>("labels"));
270
271 if (trainData.n_cols != trainLabels.n_elem)
272 Log::Fatal << "Samples of input_data should same as the size of "
273 << "input_label." << endl;
274
275 const size_t numClasses = CalculateNumberOfClasses(
276 (size_t) IO::GetParam<int>("number_of_classes"), trainLabels);
277
278 const bool intercept = IO::HasParam("no_intercept") ? false : true;
279
280 const size_t numBasis = 5;
281 ens::L_BFGS optimizer(numBasis, maxIterations);
282 sm = new Model(trainData, trainLabels, numClasses,
283 IO::GetParam<double>("lambda"), intercept, std::move(optimizer));
284 }
285 return sm;
286 }
287