1 /** 2 * @file tests/main_tests/hmm_generate_test.cpp 3 * @author Daivik Nema 4 * 5 * Test mlpackMain() of hmm_generate_main.cpp 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #include <string> 13 14 #define BINDING_TYPE BINDING_TYPE_TEST 15 static const std::string testName = "HMMGenerate"; 16 17 #include <mlpack/core.hpp> 18 #include <mlpack/core/util/mlpack_main.hpp> 19 #include "test_helper.hpp" 20 #include <mlpack/methods/hmm/hmm_model.hpp> 21 #include <mlpack/methods/hmm/hmm.hpp> 22 #include <mlpack/methods/hmm/hmm_generate_main.cpp> 23 24 #include "../catch.hpp" 25 #include "../test_catch_tools.hpp" 26 27 #include "hmm_test_utils.hpp" 28 29 using namespace mlpack; 30 31 struct HMMGenerateTestFixture 32 { 33 public: HMMGenerateTestFixtureHMMGenerateTestFixture34 HMMGenerateTestFixture() 35 { 36 // Cache in the options for this program. 37 IO::RestoreSettings(testName); 38 } 39 ~HMMGenerateTestFixtureHMMGenerateTestFixture40 ~HMMGenerateTestFixture() 41 { 42 // Clear the settings. 43 bindings::tests::CleanMemory(); 44 IO::ClearSettings(); 45 } 46 }; 47 48 TEST_CASE_METHOD(HMMGenerateTestFixture, 49 "HMMGenerateDiscreteHMMCheckDimensionsTest", 50 "[HMMGenerateMainTest][BindingTests]") 51 { 52 // Load data to train a discrete HMM model with. 53 arma::mat inp; 54 data::Load("obs1.csv", inp); 55 std::vector<arma::mat> trainSeq = {inp}; 56 57 // Initialize and train a discrete HMM model. 58 HMMModel* h = new HMMModel(DiscreteHMM); 59 h->PerformAction<InitHMMModel, std::vector<arma::mat>>(&trainSeq); 60 h->PerformAction<TrainHMMModel, std::vector<arma::mat>>(&trainSeq); 61 62 // Now that we have a trained HMM model, we can use it to generate a sequence 63 // of states and observations - using the hmm_generate utility. 64 // Load the input model to be used for inference and the length of sequence 65 // to be generated. 66 int length = 3; 67 SetInputParam("model", h); 68 SetInputParam("length", length); 69 70 // Call to hmm_generate_main. 71 mlpackMain(); 72 73 // Get the generated observation sequence. Ensure that the generated sequence 74 // has the correct length (as provided in the input). 75 arma::mat obsSeq = IO::GetParam<arma::mat>("output"); 76 REQUIRE(obsSeq.n_cols == (size_t)length); 77 REQUIRE(obsSeq.n_rows == (size_t)1); 78 REQUIRE(obsSeq.n_elem == (size_t)length); 79 80 // Get the generated state sequence. Ensure that the generated sequence 81 // has the correct length (as provided in the input). 82 arma::Mat<size_t> stateSeq = IO::GetParam<arma::Mat<size_t>>("state"); 83 REQUIRE(stateSeq.n_cols == (size_t)length); 84 REQUIRE(stateSeq.n_rows == (size_t)1); 85 REQUIRE(stateSeq.n_elem == (size_t)length); 86 } 87 88 TEST_CASE_METHOD(HMMGenerateTestFixture, 89 "HMMGenerateGaussianHMMCheckDimensionsTest", 90 "[HMMGenerateMainTest][BindingTests]") 91 { 92 // Load data to train a gaussian HMM model with. 93 arma::mat inp; 94 data::Load("obs1.csv", inp); 95 std::vector<arma::mat> trainSeq = {inp}; 96 97 // Initialize and train a gaussian HMM model. 98 HMMModel* h = new HMMModel(GaussianHMM); 99 h->PerformAction<InitHMMModel, std::vector<arma::mat>>(&trainSeq); 100 h->PerformAction<TrainHMMModel, std::vector<arma::mat>>(&trainSeq); 101 102 // Now that we have a trained HMM model, we can use it to generate a sequence 103 // of states and observations - using the hmm_generate utility. 104 // Load the input model to be used for inference and the length of sequence 105 // to be generated. 106 int length = 3; 107 SetInputParam("model", h); 108 SetInputParam("length", length); 109 110 // Call to hmm_generate_main. 111 mlpackMain(); 112 113 // Get the generated observation sequence. Ensure that the generated sequence 114 // has the correct length (as provided in the input). 115 arma::mat obsSeq = IO::GetParam<arma::mat>("output"); 116 REQUIRE(obsSeq.n_cols == (size_t)length); 117 REQUIRE(obsSeq.n_rows == (size_t)1); 118 REQUIRE(obsSeq.n_elem == (size_t)length); 119 120 // Get the generated state sequence. Ensure that the generated sequence 121 // has the correct length (as provided in the input). 122 arma::Mat<size_t> stateSeq = IO::GetParam<arma::Mat<size_t>>("state"); 123 REQUIRE(stateSeq.n_cols == (size_t)length); 124 REQUIRE(stateSeq.n_rows == (size_t)1); 125 REQUIRE(stateSeq.n_elem == (size_t)length); 126 } 127 128 TEST_CASE_METHOD(HMMGenerateTestFixture, 129 "HMMGenerateGMMHMMCheckDimensionsTest", 130 "[HMMGenerateMainTest][BindingTests]") 131 { 132 // Initialize and train a GMM HMM model. 133 HMMModel* h = new HMMModel(GaussianMixtureModelHMM); 134 *(h->GMMHMM()) = HMM<GMM>(2, GMM(2, 2)); 135 136 // Manually set the components. 137 h->GMMHMM()->Transition() = arma::mat("0.40 0.60; 0.60 0.40"); 138 h->GMMHMM()->Emission().resize(2); 139 h->GMMHMM()->Emission()[0] = GMM(2, 2); 140 h->GMMHMM()->Emission()[0].Weights() = arma::vec("0.3 0.7"); 141 h->GMMHMM()->Emission()[0].Component(0) = GaussianDistribution("4.25 3.10", 142 "1.00 0.20; 0.20 0.89"); 143 h->GMMHMM()->Emission()[0].Component(1) = GaussianDistribution("7.10 5.01", 144 "1.00 0.00; 0.00 1.01"); 145 h->GMMHMM()->Emission()[1] = GMM(2, 2); 146 h->GMMHMM()->Emission()[1].Weights() = arma::vec("0.20 0.80"); 147 h->GMMHMM()->Emission()[1].Component(0) = GaussianDistribution("-3.00 -6.12", 148 "1.00 0.00; 0.00 1.00"); 149 h->GMMHMM()->Emission()[1].Component(1) = GaussianDistribution("-4.25 -2.12", 150 "1.50 0.60; 0.60 1.20"); 151 152 // Now that we have a trained HMM model, we can use it to generate a sequence 153 // of states and observations - using the hmm_generate utility. 154 // Load the input model to be used for inference and the length of sequence 155 // to be generated. 156 int length = 3; 157 SetInputParam("model", h); 158 SetInputParam("length", length); 159 160 // Call to hmm_generate_main 161 mlpackMain(); 162 163 // Get the generated observation sequence. Ensure that the generated sequence 164 // has the correct length (as provided in the input). 165 arma::mat obsSeq = IO::GetParam<arma::mat>("output"); 166 REQUIRE(obsSeq.n_cols == (size_t) length); 167 REQUIRE(obsSeq.n_rows == (size_t) 2); 168 REQUIRE(obsSeq.n_elem == (size_t) length * 2); 169 170 // Get the generated state sequence. Ensure that the generated sequence 171 // has the correct length (as provided in the input). 172 arma::Mat<size_t> stateSeq = IO::GetParam<arma::Mat<size_t>>("state"); 173 REQUIRE(stateSeq.n_cols == (size_t) length); 174 REQUIRE(stateSeq.n_rows == (size_t) 1); 175 REQUIRE(stateSeq.n_elem == (size_t) length); 176 } 177 178 TEST_CASE_METHOD(HMMGenerateTestFixture, 179 "HMMGenerateDiagonalGMMHMMCheckDimensionsTest", 180 "[HMMGenerateMainTest][BindingTests]") 181 { 182 // Initialize and train a DiagonalGMM HMM model. 183 HMMModel* h = new HMMModel(DiagonalGaussianMixtureModelHMM); 184 *(h->DiagGMMHMM()) = HMM<DiagonalGMM>(2, DiagonalGMM(2, 2)); 185 186 // Manually set the components. 187 h->DiagGMMHMM()->Transition() = arma::mat("0.30 0.70; 0.70 0.30"); 188 h->DiagGMMHMM()->Emission().resize(2); 189 h->DiagGMMHMM()->Emission()[0] = DiagonalGMM(2, 2); 190 h->DiagGMMHMM()->Emission()[0].Weights() = arma::vec("0.2 0.8"); 191 h->DiagGMMHMM()->Emission()[0].Component(0) = DiagonalGaussianDistribution( 192 "2.75 1.60", "0.50 0.50"); 193 h->DiagGMMHMM()->Emission()[0].Component(1) = DiagonalGaussianDistribution( 194 "6.15 2.51", "1.00 1.50"); 195 h->DiagGMMHMM()->Emission()[1] = DiagonalGMM(2, 2); 196 h->DiagGMMHMM()->Emission()[1].Weights() = arma::vec("0.4 0.6"); 197 h->DiagGMMHMM()->Emission()[1].Component(0) = DiagonalGaussianDistribution( 198 "-1.00 -3.42", "0.20 1.00"); 199 h->DiagGMMHMM()->Emission()[1].Component(1) = DiagonalGaussianDistribution( 200 "-3.10 -5.05", "1.20 0.80"); 201 202 // Now that we have a trained HMM model, we can use it to generate a sequence 203 // of states and observations - using the hmm_generate utility. 204 // Load the input model to be used for inference and the length of sequence 205 // to be generated. 206 int length = 3; 207 SetInputParam("model", h); 208 SetInputParam("length", length); 209 210 // Call to hmm_generate_main. 211 mlpackMain(); 212 213 // Get the generated observation sequence. Ensure that the generated sequence 214 // has the correct length (as provided in the input). 215 arma::mat obsSeq = IO::GetParam<arma::mat>("output"); 216 REQUIRE(obsSeq.n_cols == (size_t) length); 217 REQUIRE(obsSeq.n_rows == (size_t) 2); 218 REQUIRE(obsSeq.n_elem == (size_t) length * 2); 219 220 // Get the generated state sequence. Ensure that the generated sequence 221 // has the correct length (as provided in the input). 222 arma::Mat<size_t> stateSeq = IO::GetParam<arma::Mat<size_t>>("state"); 223 REQUIRE(stateSeq.n_cols == (size_t) length); 224 REQUIRE(stateSeq.n_rows == (size_t) 1); 225 REQUIRE(stateSeq.n_elem == (size_t) length); 226 } 227 228 TEST_CASE_METHOD(HMMGenerateTestFixture, 229 "HMMGenerateLengthPositiveTest", 230 "[HMMGenerateMainTest][BindingTests]") 231 { 232 // Load data to train a Gaussian Mixture Model HMM model with. 233 arma::mat inp; 234 data::Load("obs1.csv", inp); 235 std::vector<arma::mat> trainSeq = {inp}; 236 237 // Initialize and train a HMM model. 238 HMMModel* h = new HMMModel(DiscreteHMM); 239 h->PerformAction<InitHMMModel, std::vector<arma::mat>>(&trainSeq); 240 h->PerformAction<TrainHMMModel, std::vector<arma::mat>>(&trainSeq); 241 242 // Set the params for the hmm_generate invocation 243 // Note that the length is negative - we expect that a runtime error will be 244 // raised in the call to hmm_generate_main 245 int length = -3; // Invalid 246 SetInputParam("model", h); 247 SetInputParam("length", length); 248 249 Log::Fatal.ignoreInput = true; 250 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 251 Log::Fatal.ignoreInput = false; 252 } 253 254 TEST_CASE_METHOD(HMMGenerateTestFixture, 255 "HMMGenerateValidStartStateTest", 256 "[HMMGenerateMainTest][BindingTests]") 257 { 258 // Load data to train a Gaussian Mixture Model HMM model with. 259 arma::mat inp; 260 data::Load("obs1.csv", inp); 261 std::vector<arma::mat> trainSeq = {inp}; 262 263 // Initialize and train a HMM model. 264 HMMModel* h = new HMMModel(DiscreteHMM); 265 h->PerformAction<InitHMMModel, std::vector<arma::mat>>(&trainSeq); 266 h->PerformAction<TrainHMMModel, std::vector<arma::mat>>(&trainSeq); 267 268 // Set the params for the hmm_generate invocation 269 // Note that the start state is invalid - we expect that a runtime error will 270 // be raised in the call to hmm_generate_main 271 int length = 3; 272 int startState = 2; // Invalid 273 SetInputParam("model", h); 274 SetInputParam("length", length); 275 SetInputParam("start_state", startState); 276 277 Log::Fatal.ignoreInput = true; 278 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 279 Log::Fatal.ignoreInput = false; 280 } 281