1 /** 2 * @file tests/main_tests/decision_stump_test.cpp 3 * @author Manish Kumar 4 * 5 * Test mlpackMain() of decision_stump_main.cpp. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #define BINDING_TYPE BINDING_TYPE_TEST 13 14 #include <mlpack/core.hpp> 15 static const std::string testName = "DecisionStump"; 16 17 #include <mlpack/core/util/mlpack_main.hpp> 18 #include <mlpack/methods/decision_stump/decision_stump_main.cpp> 19 #include "test_helper.hpp" 20 21 #include "../test_catch_tools.hpp" 22 #include "../catch.hpp" 23 24 using namespace mlpack; 25 26 struct DecisionStumpTestFixture 27 { 28 public: DecisionStumpTestFixtureDecisionStumpTestFixture29 DecisionStumpTestFixture() 30 { 31 // Cache in the options for this program. 32 IO::RestoreSettings(testName); 33 } 34 ~DecisionStumpTestFixtureDecisionStumpTestFixture35 ~DecisionStumpTestFixture() 36 { 37 // Clear the settings. 38 bindings::tests::CleanMemory(); 39 IO::ClearSettings(); 40 } 41 }; 42 43 /** 44 * Ensure that we get desired dimensions when both training 45 * data and labels are passed. 46 */ 47 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpOutputDimensionTest", 48 "[DecisionStumpMainTest][BindingTests]") 49 { 50 arma::mat inputData; 51 if (!data::Load("trainSet.csv", inputData)) 52 FAIL("Cannot load train dataset trainSet.csv!"); 53 54 // Get the labels out. 55 arma::Row<size_t> labels(inputData.n_cols); 56 for (size_t i = 0; i < inputData.n_cols; ++i) 57 labels[i] = inputData(inputData.n_rows - 1, i); 58 59 // Delete the last row containing labels from input dataset. 60 inputData.shed_row(inputData.n_rows - 1); 61 62 arma::mat testData; 63 if (!data::Load("testSet.csv", testData)) 64 FAIL("Cannot load test dataset testSet.csv!"); 65 66 // Delete the last row containing labels from test dataset. 67 testData.shed_row(testData.n_rows - 1); 68 69 size_t testSize = testData.n_cols; 70 71 // Input training data. 72 SetInputParam("training", std::move(inputData)); 73 SetInputParam("labels", std::move(labels)); 74 75 // Input test data. 76 SetInputParam("test", std::move(testData)); 77 78 mlpackMain(); 79 80 // Check that number of output points are equal to number of input points. 81 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 82 83 // Check prediction have only single row. 84 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 85 } 86 87 /** 88 * Check that last row of input file is used as labels 89 * when labels are not passed specifically and results 90 * are same from both label and labeless models. 91 */ 92 TEST_CASE_METHOD(DecisionStumpTestFixture, 93 "DecisionStumpLabelsLessDimensionTest", 94 "[DecisionStumpMainTest][BindingTests]") 95 { 96 // Train DS without providing labels. 97 arma::mat inputData; 98 if (!data::Load("trainSet.csv", inputData)) 99 FAIL("Cannot load train dataset trainSet.csv!"); 100 101 // Get the labels out. 102 arma::Row<size_t> labels(inputData.n_cols); 103 for (size_t i = 0; i < inputData.n_cols; ++i) 104 labels[i] = inputData(inputData.n_rows - 1, i); 105 106 arma::mat testData; 107 if (!data::Load("testSet.csv", testData)) 108 FAIL("Cannot load test dataset testSet.csv!"); 109 110 // Delete the last row containing labels from test dataset. 111 testData.shed_row(testData.n_rows - 1); 112 113 size_t testSize = testData.n_cols; 114 115 // Input training data. 116 SetInputParam("training", inputData); 117 118 // Input test data. 119 SetInputParam("test", testData); 120 121 mlpackMain(); 122 123 // Check that number of output points are equal to number of input points. 124 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 125 126 // Check prediction have only single row. 127 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 128 129 // Reset data passed. 130 IO::GetSingleton().Parameters()["training"].wasPassed = false; 131 IO::GetSingleton().Parameters()["test"].wasPassed = false; 132 133 // Store outputs. 134 arma::Row<size_t> predictions; 135 predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions")); 136 137 // Delete the previous model. 138 bindings::tests::CleanMemory(); 139 140 // Now train DS with labels provided. 141 142 // Delete last row of inputData. 143 inputData.shed_row(inputData.n_rows - 1); 144 145 // Input training data. 146 SetInputParam("training", std::move(inputData)); 147 SetInputParam("test", std::move(testData)); 148 // Pass Labels. 149 SetInputParam("labels", std::move(labels)); 150 151 mlpackMain(); 152 153 // Check that number of output points are equal to number of input points. 154 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 155 156 // Check prediction have only single row. 157 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 158 159 // Check that initial output and final output matrix 160 // from two models are same. 161 CheckMatrices(predictions, IO::GetParam<arma::Row<size_t>>("predictions")); 162 } 163 164 /** 165 * Ensure that saved model can be used again. 166 */ 167 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpModelReuseTest", 168 "[DecisionStumpMainTest][BindingTests]") 169 { 170 arma::mat inputData; 171 if (!data::Load("trainSet.csv", inputData)) 172 FAIL("Cannot load train dataset trainSet.csv!"); 173 174 arma::mat testData; 175 if (!data::Load("testSet.csv", testData)) 176 FAIL("Cannot load test dataset testSet.csv!"); 177 178 // Delete the last row containing labels from test dataset. 179 testData.shed_row(testData.n_rows - 1); 180 181 size_t testSize = testData.n_cols; 182 183 // Input training data. 184 SetInputParam("training", std::move(inputData)); 185 186 // Input test data. 187 SetInputParam("test", testData); 188 189 mlpackMain(); 190 191 arma::Row<size_t> predictions; 192 predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions")); 193 194 // Reset passed parameters. 195 IO::GetSingleton().Parameters()["training"].wasPassed = false; 196 IO::GetSingleton().Parameters()["test"].wasPassed = false; 197 198 // Input trained model. 199 SetInputParam("test", std::move(testData)); 200 SetInputParam("input_model", 201 std::move(IO::GetParam<DSModel*>("output_model"))); 202 203 mlpackMain(); 204 205 // Check that number of output points are equal to number of input points. 206 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 207 208 // Check predictions have only single row. 209 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 210 211 // Check that initial predictions and final predicitons matrix 212 // using saved model are same. 213 CheckMatrices(predictions, IO::GetParam<arma::Row<size_t>>("predictions")); 214 } 215 216 /** 217 * Ensure that bucket_size is always positive. 218 */ 219 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpBucketSizeTest", 220 "[DecisionStumpMainTest][BindingTests]") 221 { 222 arma::mat inputData; 223 if (!data::Load("trainSet.csv", inputData)) 224 FAIL("Cannot load train dataset trainSet.csv!"); 225 226 // Input training data. 227 SetInputParam("training", std::move(inputData)); 228 SetInputParam("bucket_size", (int) 0); 229 230 Log::Fatal.ignoreInput = true; 231 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 232 Log::Fatal.ignoreInput = false; 233 } 234 235 /** 236 * Make sure only one of training data or pre-trained model is passed. 237 */ 238 TEST_CASE_METHOD(DecisionStumpTestFixture, "DecisionStumpTrainingVerTest", 239 "[DecisionStumpMainTest][BindingTests]") 240 { 241 arma::mat inputData; 242 if (!data::Load("trainSet.csv", inputData)) 243 FAIL("Cannot load train dataset trainSet.csv!"); 244 245 // Input training data. 246 SetInputParam("training", std::move(inputData)); 247 248 mlpackMain(); 249 250 // Input pre-trained model. 251 SetInputParam("input_model", 252 std::move(IO::GetParam<DSModel*>("output_model"))); 253 254 Log::Fatal.ignoreInput = true; 255 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 256 Log::Fatal.ignoreInput = false; 257 } 258