1 /**
2  * @file tests/main_tests/decision_stump_test.cpp
3  * @author Manish Kumar
4  *
5  * Test mlpackMain() of decision_stump_main.cpp.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #define BINDING_TYPE BINDING_TYPE_TEST
13 
14 #include <mlpack/core.hpp>
15 static const std::string testName = "DecisionStump";
16 
17 #include <mlpack/core/util/mlpack_main.hpp>
18 #include <mlpack/methods/decision_stump/decision_stump_main.cpp>
19 #include "test_helper.hpp"
20 
21 #include "../test_catch_tools.hpp"
22 #include "../catch.hpp"
23 
24 using namespace mlpack;
25 
26 struct DecisionStumpTestFixture
27 {
28  public:
DecisionStumpTestFixtureDecisionStumpTestFixture29   DecisionStumpTestFixture()
30   {
31     // Cache in the options for this program.
32     IO::RestoreSettings(testName);
33   }
34 
~DecisionStumpTestFixtureDecisionStumpTestFixture35   ~DecisionStumpTestFixture()
36   {
37     // Clear the settings.
38     bindings::tests::CleanMemory();
39     IO::ClearSettings();
40   }
41 };
42 
43 /**
44  * Ensure that we get desired dimensions when both training
45  * data and labels are passed.
46  */
47 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpOutputDimensionTest",
48                  "[DecisionStumpMainTest][BindingTests]")
49 {
50   arma::mat inputData;
51   if (!data::Load("trainSet.csv", inputData))
52     FAIL("Cannot load train dataset trainSet.csv!");
53 
54   // Get the labels out.
55   arma::Row<size_t> labels(inputData.n_cols);
56   for (size_t i = 0; i < inputData.n_cols; ++i)
57     labels[i] = inputData(inputData.n_rows - 1, i);
58 
59   // Delete the last row containing labels from input dataset.
60   inputData.shed_row(inputData.n_rows - 1);
61 
62   arma::mat testData;
63   if (!data::Load("testSet.csv", testData))
64     FAIL("Cannot load test dataset testSet.csv!");
65 
66   // Delete the last row containing labels from test dataset.
67   testData.shed_row(testData.n_rows - 1);
68 
69   size_t testSize = testData.n_cols;
70 
71   // Input training data.
72   SetInputParam("training", std::move(inputData));
73   SetInputParam("labels", std::move(labels));
74 
75   // Input test data.
76   SetInputParam("test", std::move(testData));
77 
78   mlpackMain();
79 
80   // Check that number of output points are equal to number of input points.
81   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
82 
83   // Check prediction have only single row.
84   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
85 }
86 
87 /**
88  * Check that last row of input file is used as labels
89  * when labels are not passed specifically and results
90  * are same from both label and labeless models.
91  */
92 TEST_CASE_METHOD(DecisionStumpTestFixture,
93                  "DecisionStumpLabelsLessDimensionTest",
94                  "[DecisionStumpMainTest][BindingTests]")
95 {
96   // Train DS without providing labels.
97   arma::mat inputData;
98   if (!data::Load("trainSet.csv", inputData))
99     FAIL("Cannot load train dataset trainSet.csv!");
100 
101   // Get the labels out.
102   arma::Row<size_t> labels(inputData.n_cols);
103   for (size_t i = 0; i < inputData.n_cols; ++i)
104     labels[i] = inputData(inputData.n_rows - 1, i);
105 
106   arma::mat testData;
107   if (!data::Load("testSet.csv", testData))
108     FAIL("Cannot load test dataset testSet.csv!");
109 
110   // Delete the last row containing labels from test dataset.
111   testData.shed_row(testData.n_rows - 1);
112 
113   size_t testSize = testData.n_cols;
114 
115   // Input training data.
116   SetInputParam("training", inputData);
117 
118   // Input test data.
119   SetInputParam("test", testData);
120 
121   mlpackMain();
122 
123   // Check that number of output points are equal to number of input points.
124   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
125 
126   // Check prediction have only single row.
127   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
128 
129   // Reset data passed.
130   IO::GetSingleton().Parameters()["training"].wasPassed = false;
131   IO::GetSingleton().Parameters()["test"].wasPassed = false;
132 
133   // Store outputs.
134   arma::Row<size_t> predictions;
135   predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions"));
136 
137   // Delete the previous model.
138   bindings::tests::CleanMemory();
139 
140   // Now train DS with labels provided.
141 
142   // Delete last row of inputData.
143   inputData.shed_row(inputData.n_rows - 1);
144 
145   // Input training data.
146   SetInputParam("training", std::move(inputData));
147   SetInputParam("test", std::move(testData));
148   // Pass Labels.
149   SetInputParam("labels", std::move(labels));
150 
151   mlpackMain();
152 
153   // Check that number of output points are equal to number of input points.
154   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
155 
156   // Check prediction have only single row.
157   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
158 
159   // Check that initial output and final output matrix
160   // from two models are same.
161   CheckMatrices(predictions, IO::GetParam<arma::Row<size_t>>("predictions"));
162 }
163 
164 /**
165  * Ensure that saved model can be used again.
166  */
167 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpModelReuseTest",
168                  "[DecisionStumpMainTest][BindingTests]")
169 {
170   arma::mat inputData;
171   if (!data::Load("trainSet.csv", inputData))
172     FAIL("Cannot load train dataset trainSet.csv!");
173 
174   arma::mat testData;
175   if (!data::Load("testSet.csv", testData))
176     FAIL("Cannot load test dataset testSet.csv!");
177 
178   // Delete the last row containing labels from test dataset.
179   testData.shed_row(testData.n_rows - 1);
180 
181   size_t testSize = testData.n_cols;
182 
183   // Input training data.
184   SetInputParam("training", std::move(inputData));
185 
186   // Input test data.
187   SetInputParam("test", testData);
188 
189   mlpackMain();
190 
191   arma::Row<size_t> predictions;
192   predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions"));
193 
194   // Reset passed parameters.
195   IO::GetSingleton().Parameters()["training"].wasPassed = false;
196   IO::GetSingleton().Parameters()["test"].wasPassed = false;
197 
198   // Input trained model.
199   SetInputParam("test", std::move(testData));
200   SetInputParam("input_model",
201                 std::move(IO::GetParam<DSModel*>("output_model")));
202 
203   mlpackMain();
204 
205   // Check that number of output points are equal to number of input points.
206   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
207 
208   // Check predictions have only single row.
209   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
210 
211   // Check that initial predictions and final predicitons matrix
212   // using saved model are same.
213   CheckMatrices(predictions, IO::GetParam<arma::Row<size_t>>("predictions"));
214 }
215 
216 /**
217  * Ensure that bucket_size is always positive.
218  */
219 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpBucketSizeTest",
220                  "[DecisionStumpMainTest][BindingTests]")
221 {
222   arma::mat inputData;
223   if (!data::Load("trainSet.csv", inputData))
224     FAIL("Cannot load train dataset trainSet.csv!");
225 
226   // Input training data.
227   SetInputParam("training", std::move(inputData));
228   SetInputParam("bucket_size", (int) 0);
229 
230   Log::Fatal.ignoreInput = true;
231   REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
232   Log::Fatal.ignoreInput = false;
233 }
234 
235 /**
236  * Make sure only one of training data or pre-trained model is passed.
237  */
238 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpTrainingVerTest",
239                  "[DecisionStumpMainTest][BindingTests]")
240 {
241   arma::mat inputData;
242   if (!data::Load("trainSet.csv", inputData))
243     FAIL("Cannot load train dataset trainSet.csv!");
244 
245   // Input training data.
246   SetInputParam("training", std::move(inputData));
247 
248   mlpackMain();
249 
250   // Input pre-trained model.
251   SetInputParam("input_model",
252                 std::move(IO::GetParam<DSModel*>("output_model")));
253 
254   Log::Fatal.ignoreInput = true;
255   REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
256   Log::Fatal.ignoreInput = false;
257 }
258