1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 #include <gumtest/AgrumTestSuite.h> 23 #include <gumtest/testsuite_utils.h> 24 #include <iostream> 25 26 #include <agrum/BN/inference/lazyPropagation.h> 27 #include <agrum/tools/database/DBRowGeneratorEM.h> 28 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 29 #include <agrum/tools/database/DBTranslatorSet.h> 30 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h> 31 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 32 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h> 33 34 namespace gum_tests { 35 36 class ParamEstimatorMLTestSuite: public CxxTest::TestSuite { 37 private: _normalize_(const std::vector<double> & vin)38 std::vector< double > _normalize_(const std::vector< double >& vin) { 39 double sum = 0; 40 for (const auto& val: vin) 41 sum += val; 42 std::vector< double > vout(vin); 43 for (auto& val: vout) 44 val /= sum; 45 return vout; 46 } 47 _xnormalize_(const std::vector<double> & vin)48 std::vector< double > _xnormalize_(const std::vector< double >& vin) { 49 std::vector< double > vout(vin); 50 for (std::size_t i = 0; i < vin.size(); i += 3) { 51 double sum = 0; 52 for (std::size_t j = std::size_t(0); j < 3; ++j) 53 sum += vin[i + j]; 54 for (std::size_t j = std::size_t(0); j < 3; ++j) 55 vout[i + j] /= sum; 56 } 57 return vout; 58 } 59 60 gum::Potential< double > _infer_(const gum::BayesNet<double> & bn,const std::vector<std::size_t> & targets,const gum::learning::DBRow<gum::learning::DBTranslatedValue> & row)61 _infer_(const gum::BayesNet< double >& bn, 62 const std::vector< std::size_t >& targets, 63 const gum::learning::DBRow< gum::learning::DBTranslatedValue >& row) { 64 gum::LazyPropagation< double > ve(&bn); 65 66 gum::NodeSet target_set; 67 for (auto target: targets) 68 target_set.insert(gum::NodeId(target)); 69 ve.addJointTarget(target_set); 70 71 const auto xrow = row.row(); 72 const auto row_size = xrow.size(); 73 for (std::size_t col = std::size_t(0); col < row_size; ++col) { 74 if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) { 75 ve.addEvidence(gum::NodeId(col), xrow[col].discr_val); 76 } 77 } 78 79 gum::Potential< double > prob = ve.jointPosterior(target_set); 80 return prob; 81 } 82 83 84 public: test1()85 void test1() { 86 // create the translator set 87 gum::LabelizedVariable var("X1", "", 0); 88 var.addLabel("0"); 89 var.addLabel("1"); 90 var.addLabel("2"); 91 92 gum::learning::DBTranslatorSet<> trans_set; 93 { 94 const std::vector< std::string > miss; 95 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 96 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 97 98 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 99 translator.setVariableName(names[i]); 100 trans_set.insertTranslator(translator, i); 101 } 102 } 103 104 // create the database 105 gum::learning::DatabaseTable<> database(trans_set); 106 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 107 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 108 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 109 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 110 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 111 for (int i = 0; i < 1000; ++i) 112 database.insertRow(row0); 113 for (int i = 0; i < 50; ++i) 114 database.insertRow(row1); 115 for (int i = 0; i < 75; ++i) 116 database.insertRow(row2); 117 for (int i = 0; i < 75; ++i) 118 database.insertRow(row3); 119 for (int i = 0; i < 200; ++i) 120 database.insertRow(row4); 121 122 // create the parser 123 gum::learning::DBRowGeneratorSet<> genset; 124 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 125 gum::learning::AprioriSmoothing<> extern_apriori(database); 126 gum::learning::AprioriNoApriori<> intern_apriori(database); 127 128 gum::learning::ParamEstimatorML<> param_estimator(parser, extern_apriori, intern_apriori); 129 130 std::vector< double > v0 = param_estimator.parameters(gum::NodeId(0)); 131 std::vector< double > xv0 = _normalize_({1201, 126, 76}); 132 TS_ASSERT_EQUALS(v0, xv0) 133 134 std::vector< double > v1 = param_estimator.parameters(gum::NodeId(1)); 135 std::vector< double > xv1 = _normalize_({276, 1076, 51}); 136 TS_ASSERT_EQUALS(v1, xv1) 137 138 std::vector< double > v2 = param_estimator.parameters(gum::NodeId(2)); 139 std::vector< double > xv2 = _normalize_({1401, 1, 1}); 140 TS_ASSERT_EQUALS(v2, xv2) 141 142 std::vector< double > v02 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(2)}); 143 std::vector< double > xv02 = _xnormalize_({1201, 126, 76, 1, 1, 1, 1, 1, 1}); 144 TS_ASSERT_EQUALS(v02, xv02) 145 146 std::vector< double > v01 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(1)}); 147 std::vector< double > xv01 = _xnormalize_({201, 76, 1, 1001, 1, 76, 1, 51, 1}); 148 TS_ASSERT_EQUALS(v01, xv01) 149 150 std::vector< double > v21 = param_estimator.parameters(gum::NodeId(2), {gum::NodeId(1)}); 151 std::vector< double > xv21 = _xnormalize_({276, 1, 1, 1076, 1, 1, 51, 1, 1}); 152 TS_ASSERT_EQUALS(v21, xv21) 153 154 155 gum::learning::ParamEstimatorML<> param_estimator2(param_estimator); 156 std::vector< double > w0 = param_estimator2.parameters(gum::NodeId(0)); 157 TS_ASSERT_EQUALS(w0, xv0) 158 159 std::vector< double > w1 = param_estimator2.parameters(gum::NodeId(1)); 160 TS_ASSERT_EQUALS(w1, xv1) 161 162 std::vector< double > w2 = param_estimator2.parameters(gum::NodeId(2)); 163 TS_ASSERT_EQUALS(w2, xv2) 164 165 std::vector< double > w02 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 166 TS_ASSERT_EQUALS(w02, xv02) 167 168 std::vector< double > w01 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(1)}); 169 TS_ASSERT_EQUALS(w01, xv01) 170 171 std::vector< double > w21 = param_estimator2.parameters(gum::NodeId(2), {gum::NodeId(1)}); 172 TS_ASSERT_EQUALS(w21, xv21) 173 174 175 gum::learning::ParamEstimatorML<> param_estimator3(std::move(param_estimator2)); 176 std::vector< double > x0 = param_estimator3.parameters(gum::NodeId(0)); 177 TS_ASSERT_EQUALS(x0, xv0) 178 179 std::vector< double > x1 = param_estimator3.parameters(gum::NodeId(1)); 180 TS_ASSERT_EQUALS(x1, xv1) 181 182 std::vector< double > x2 = param_estimator3.parameters(gum::NodeId(2)); 183 TS_ASSERT_EQUALS(x2, xv2) 184 185 std::vector< double > x02 = param_estimator3.parameters(gum::NodeId(0), {gum::NodeId(2)}); 186 TS_ASSERT_EQUALS(x02, xv02) 187 188 std::vector< double > x01 = param_estimator3.parameters(gum::NodeId(0), {gum::NodeId(1)}); 189 TS_ASSERT_EQUALS(x01, xv01) 190 191 std::vector< double > x21 = param_estimator3.parameters(gum::NodeId(2), {gum::NodeId(1)}); 192 TS_ASSERT_EQUALS(x21, xv21) 193 194 195 gum::learning::ParamEstimatorML<>* param_estimator4 = param_estimator.clone(); 196 std::vector< double > y0 = param_estimator4->parameters(gum::NodeId(0)); 197 TS_ASSERT_EQUALS(y0, xv0) 198 199 std::vector< double > y1 = param_estimator4->parameters(gum::NodeId(1)); 200 TS_ASSERT_EQUALS(y1, xv1) 201 202 std::vector< double > y2 = param_estimator4->parameters(gum::NodeId(2)); 203 TS_ASSERT_EQUALS(y2, xv2) 204 205 std::vector< double > y02 = param_estimator4->parameters(gum::NodeId(0), {gum::NodeId(2)}); 206 TS_ASSERT_EQUALS(y02, xv02) 207 208 std::vector< double > y01 = param_estimator4->parameters(gum::NodeId(0), {gum::NodeId(1)}); 209 TS_ASSERT_EQUALS(y01, xv01) 210 211 std::vector< double > y21 = param_estimator4->parameters(gum::NodeId(2), {gum::NodeId(1)}); 212 TS_ASSERT_EQUALS(y21, xv21) 213 214 delete param_estimator4; 215 } 216 217 test2()218 void test2() { 219 // create the translator set 220 gum::LabelizedVariable var("X1", "", 0); 221 var.addLabel("0"); 222 var.addLabel("1"); 223 var.addLabel("2"); 224 225 gum::learning::DBTranslatorSet<> trans_set; 226 { 227 const std::vector< std::string > miss; 228 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 229 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 230 231 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 232 translator.setVariableName(names[i]); 233 trans_set.insertTranslator(translator, i); 234 } 235 } 236 237 // create the database 238 gum::learning::DatabaseTable<> database(trans_set); 239 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 240 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 241 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 242 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 243 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 244 for (int i = 0; i < 1000; ++i) 245 database.insertRow(row0); 246 for (int i = 0; i < 50; ++i) 247 database.insertRow(row1); 248 for (int i = 0; i < 75; ++i) 249 database.insertRow(row2); 250 for (int i = 0; i < 75; ++i) 251 database.insertRow(row3); 252 for (int i = 0; i < 200; ++i) 253 database.insertRow(row4); 254 255 // create the parser 256 gum::learning::DBRowGeneratorSet<> genset; 257 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 258 gum::learning::AprioriSmoothing<> extern_apriori(database); 259 gum::learning::AprioriNoApriori<> intern_apriori(database); 260 261 std::vector< std::pair< std::size_t, std::size_t > > ranges{{800, 1000}, {1050, 1400}}; 262 263 gum::learning::ParamEstimatorML<> param_estimator(parser, 264 extern_apriori, 265 intern_apriori, 266 ranges); 267 268 std::vector< double > r0 = param_estimator.parameters(gum::NodeId(0)); 269 std::vector< double > xr0 = _normalize_({401, 76, 76}); 270 TS_ASSERT_EQUALS(r0, xr0) 271 272 std::vector< double > r1 = param_estimator.parameters(gum::NodeId(1)); 273 std::vector< double > xr1 = _normalize_({276, 276, 1}); 274 TS_ASSERT_EQUALS(r1, xr1) 275 276 std::vector< double > r2 = param_estimator.parameters(gum::NodeId(2)); 277 std::vector< double > xr2 = _normalize_({551, 1, 1}); 278 TS_ASSERT_EQUALS(r2, xr2) 279 280 std::vector< double > r02 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(2)}); 281 std::vector< double > xr02 = _xnormalize_({401, 76, 76, 1, 1, 1, 1, 1, 1}); 282 TS_ASSERT_EQUALS(r02, xr02) 283 284 std::vector< double > r01 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(1)}); 285 std::vector< double > xr01 = _xnormalize_({201, 76, 1, 201, 1, 76, 1, 1, 1}); 286 TS_ASSERT_EQUALS(r01, xr01) 287 288 std::vector< double > r21 = param_estimator.parameters(gum::NodeId(2), {gum::NodeId(1)}); 289 std::vector< double > xr21 = _xnormalize_({276, 1, 1, 276, 1, 1, 1, 1, 1}); 290 TS_ASSERT_EQUALS(r21, xr21) 291 292 293 gum::learning::ParamEstimatorML<> param_estimator2(param_estimator); 294 std::vector< double > v0 = param_estimator2.parameters(gum::NodeId(0)); 295 TS_ASSERT_EQUALS(v0, xr0) 296 297 std::vector< double > v1 = param_estimator2.parameters(gum::NodeId(1)); 298 TS_ASSERT_EQUALS(v1, xr1) 299 300 std::vector< double > v2 = param_estimator2.parameters(gum::NodeId(2)); 301 TS_ASSERT_EQUALS(v2, xr2) 302 303 std::vector< double > v02 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 304 TS_ASSERT_EQUALS(v02, xr02) 305 306 std::vector< double > v01 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(1)}); 307 TS_ASSERT_EQUALS(v01, xr01) 308 309 std::vector< double > v21 = param_estimator2.parameters(gum::NodeId(2), {gum::NodeId(1)}); 310 TS_ASSERT_EQUALS(v21, xr21) 311 312 313 gum::learning::ParamEstimatorML<> param_estimator3(std::move(param_estimator2)); 314 std::vector< double > w0 = param_estimator3.parameters(gum::NodeId(0)); 315 TS_ASSERT_EQUALS(w0, xr0) 316 317 std::vector< double > w1 = param_estimator3.parameters(gum::NodeId(1)); 318 TS_ASSERT_EQUALS(w1, xr1) 319 320 std::vector< double > w2 = param_estimator3.parameters(gum::NodeId(2)); 321 TS_ASSERT_EQUALS(w2, xr2) 322 323 std::vector< double > w02 = param_estimator3.parameters(gum::NodeId(0), {gum::NodeId(2)}); 324 TS_ASSERT_EQUALS(w02, xr02) 325 326 std::vector< double > w01 = param_estimator3.parameters(gum::NodeId(0), {gum::NodeId(1)}); 327 TS_ASSERT_EQUALS(w01, xr01) 328 329 std::vector< double > w21 = param_estimator3.parameters(gum::NodeId(2), {gum::NodeId(1)}); 330 TS_ASSERT_EQUALS(w21, xr21) 331 332 333 gum::learning::ParamEstimatorML<>* param_estimator4 = param_estimator.clone(); 334 std::vector< double > x0 = param_estimator4->parameters(gum::NodeId(0)); 335 TS_ASSERT_EQUALS(x0, xr0) 336 337 std::vector< double > x1 = param_estimator4->parameters(gum::NodeId(1)); 338 TS_ASSERT_EQUALS(x1, xr1) 339 340 std::vector< double > x2 = param_estimator4->parameters(gum::NodeId(2)); 341 TS_ASSERT_EQUALS(x2, xr2) 342 343 std::vector< double > x02 = param_estimator4->parameters(gum::NodeId(0), {gum::NodeId(2)}); 344 TS_ASSERT_EQUALS(x02, xr02) 345 346 std::vector< double > x01 = param_estimator4->parameters(gum::NodeId(0), {gum::NodeId(1)}); 347 TS_ASSERT_EQUALS(x01, xr01) 348 349 std::vector< double > x21 = param_estimator4->parameters(gum::NodeId(2), {gum::NodeId(1)}); 350 TS_ASSERT_EQUALS(x21, xr21) 351 352 delete param_estimator4; 353 } 354 355 test3()356 void test3() { 357 // create the translator set 358 gum::LabelizedVariable var("X1", "", 0); 359 var.addLabel("0"); 360 var.addLabel("1"); 361 var.addLabel("2"); 362 363 gum::learning::DBTranslatorSet<> trans_set; 364 { 365 const std::vector< std::string > miss; 366 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 367 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 368 369 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 370 translator.setVariableName(names[i]); 371 trans_set.insertTranslator(translator, i); 372 } 373 } 374 375 // create the database 376 gum::learning::DatabaseTable<> database(trans_set); 377 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 378 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 379 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 380 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 381 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 382 for (int i = 0; i < 1000; ++i) 383 database.insertRow(row0); 384 for (int i = 0; i < 50; ++i) 385 database.insertRow(row1); 386 for (int i = 0; i < 75; ++i) 387 database.insertRow(row2); 388 for (int i = 0; i < 75; ++i) 389 database.insertRow(row3); 390 for (int i = 0; i < 200; ++i) 391 database.insertRow(row4); 392 393 // create the parser 394 gum::learning::DBRowGeneratorSet<> genset; 395 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 396 gum::learning::AprioriSmoothing<> extern_apriori(database); 397 gum::learning::AprioriNoApriori<> intern_apriori(database); 398 399 gum::Bijection< gum::NodeId, std::size_t > nodeId2columns; 400 nodeId2columns.insert(0, 2); 401 nodeId2columns.insert(1, 0); 402 nodeId2columns.insert(2, 1); 403 404 gum::learning::ParamEstimatorML<> param_estimator(parser, 405 extern_apriori, 406 intern_apriori, 407 nodeId2columns); 408 409 std::vector< double > v0 = param_estimator.parameters(gum::NodeId(1)); 410 std::vector< double > xv0 = _normalize_({1201, 126, 76}); 411 TS_ASSERT_EQUALS(v0, xv0) 412 413 std::vector< double > v1 = param_estimator.parameters(gum::NodeId(2)); 414 std::vector< double > xv1 = _normalize_({276, 1076, 51}); 415 TS_ASSERT_EQUALS(v1, xv1) 416 417 std::vector< double > v2 = param_estimator.parameters(gum::NodeId(0)); 418 std::vector< double > xv2 = _normalize_({1401, 1, 1}); 419 TS_ASSERT_EQUALS(v2, xv2) 420 421 std::vector< double > v02 = param_estimator.parameters(gum::NodeId(1), {gum::NodeId(0)}); 422 std::vector< double > xv02 = _xnormalize_({1201, 126, 76, 1, 1, 1, 1, 1, 1}); 423 TS_ASSERT_EQUALS(v02, xv02) 424 425 std::vector< double > v01 = param_estimator.parameters(gum::NodeId(1), {gum::NodeId(2)}); 426 std::vector< double > xv01 = _xnormalize_({201, 76, 1, 1001, 1, 76, 1, 51, 1}); 427 TS_ASSERT_EQUALS(v01, xv01) 428 429 std::vector< double > v21 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(2)}); 430 std::vector< double > xv21 = _xnormalize_({276, 1, 1, 1076, 1, 1, 51, 1, 1}); 431 TS_ASSERT_EQUALS(v21, xv21) 432 433 434 gum::learning::ParamEstimatorML<> param_estimator2(param_estimator); 435 std::vector< double > w0 = param_estimator2.parameters(gum::NodeId(1)); 436 TS_ASSERT_EQUALS(w0, xv0) 437 438 std::vector< double > w1 = param_estimator2.parameters(gum::NodeId(2)); 439 TS_ASSERT_EQUALS(w1, xv1) 440 441 std::vector< double > w2 = param_estimator2.parameters(gum::NodeId(0)); 442 TS_ASSERT_EQUALS(w2, xv2) 443 444 std::vector< double > w02 = param_estimator2.parameters(gum::NodeId(1), {gum::NodeId(0)}); 445 TS_ASSERT_EQUALS(w02, xv02) 446 447 std::vector< double > w01 = param_estimator2.parameters(gum::NodeId(1), {gum::NodeId(2)}); 448 TS_ASSERT_EQUALS(w01, xv01) 449 450 std::vector< double > w21 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 451 TS_ASSERT_EQUALS(w21, xv21) 452 453 454 gum::learning::ParamEstimatorML<> param_estimator3(std::move(param_estimator2)); 455 std::vector< double > x0 = param_estimator3.parameters(gum::NodeId(1)); 456 TS_ASSERT_EQUALS(x0, xv0) 457 458 std::vector< double > x1 = param_estimator3.parameters(gum::NodeId(2)); 459 TS_ASSERT_EQUALS(x1, xv1) 460 461 std::vector< double > x2 = param_estimator3.parameters(gum::NodeId(0)); 462 TS_ASSERT_EQUALS(x2, xv2) 463 464 std::vector< double > x02 = param_estimator3.parameters(gum::NodeId(1), {gum::NodeId(0)}); 465 TS_ASSERT_EQUALS(x02, xv02) 466 467 std::vector< double > x01 = param_estimator3.parameters(gum::NodeId(1), {gum::NodeId(2)}); 468 TS_ASSERT_EQUALS(x01, xv01) 469 470 std::vector< double > x21 = param_estimator3.parameters(gum::NodeId(0), {gum::NodeId(2)}); 471 TS_ASSERT_EQUALS(x21, xv21) 472 473 474 gum::learning::ParamEstimatorML<>* param_estimator4 = param_estimator.clone(); 475 std::vector< double > y0 = param_estimator4->parameters(gum::NodeId(1)); 476 TS_ASSERT_EQUALS(y0, xv0) 477 478 std::vector< double > y1 = param_estimator4->parameters(gum::NodeId(2)); 479 TS_ASSERT_EQUALS(y1, xv1) 480 481 std::vector< double > y2 = param_estimator4->parameters(gum::NodeId(0)); 482 TS_ASSERT_EQUALS(y2, xv2) 483 484 std::vector< double > y02 = param_estimator4->parameters(gum::NodeId(1), {gum::NodeId(0)}); 485 TS_ASSERT_EQUALS(y02, xv02) 486 487 std::vector< double > y01 = param_estimator4->parameters(gum::NodeId(1), {gum::NodeId(2)}); 488 TS_ASSERT_EQUALS(y01, xv01) 489 490 std::vector< double > y21 = param_estimator4->parameters(gum::NodeId(0), {gum::NodeId(2)}); 491 TS_ASSERT_EQUALS(y21, xv21) 492 493 delete param_estimator4; 494 } 495 496 test4()497 void test4() { 498 // create the translator set 499 gum::LabelizedVariable var("X1", "", 0); 500 var.addLabel("0"); 501 var.addLabel("1"); 502 var.addLabel("2"); 503 504 gum::learning::DBTranslatorSet<> trans_set; 505 { 506 const std::vector< std::string > miss; 507 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 508 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 509 510 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 511 translator.setVariableName(names[i]); 512 trans_set.insertTranslator(translator, i); 513 } 514 } 515 516 // create the database 517 gum::learning::DatabaseTable<> database(trans_set); 518 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 519 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 520 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 521 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 522 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 523 for (int i = 0; i < 1000; ++i) 524 database.insertRow(row0); 525 for (int i = 0; i < 50; ++i) 526 database.insertRow(row1); 527 for (int i = 0; i < 75; ++i) 528 database.insertRow(row2); 529 for (int i = 0; i < 75; ++i) 530 database.insertRow(row3); 531 for (int i = 0; i < 200; ++i) 532 database.insertRow(row4); 533 534 // create the parser 535 gum::learning::DBRowGeneratorSet<> genset; 536 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 537 gum::learning::AprioriSmoothing<> extern_apriori(database); 538 gum::learning::AprioriNoApriori<> intern_apriori(database); 539 540 std::vector< std::pair< std::size_t, std::size_t > > ranges{{800, 1000}, {1050, 1400}}; 541 542 gum::Bijection< gum::NodeId, std::size_t > nodeId2columns; 543 nodeId2columns.insert(0, 2); 544 nodeId2columns.insert(1, 0); 545 nodeId2columns.insert(2, 1); 546 547 gum::learning::ParamEstimatorML<> param_estimator(parser, 548 extern_apriori, 549 intern_apriori, 550 ranges, 551 nodeId2columns); 552 553 std::vector< double > r0 = param_estimator.parameters(gum::NodeId(1)); 554 std::vector< double > xr0 = _normalize_({401, 76, 76}); 555 TS_ASSERT_EQUALS(r0, xr0) 556 557 std::vector< double > r1 = param_estimator.parameters(gum::NodeId(2)); 558 std::vector< double > xr1 = _normalize_({276, 276, 1}); 559 TS_ASSERT_EQUALS(r1, xr1) 560 561 std::vector< double > r2 = param_estimator.parameters(gum::NodeId(0)); 562 std::vector< double > xr2 = _normalize_({551, 1, 1}); 563 TS_ASSERT_EQUALS(r2, xr2) 564 565 std::vector< double > r02 = param_estimator.parameters(gum::NodeId(1), {gum::NodeId(0)}); 566 std::vector< double > xr02 = _xnormalize_({401, 76, 76, 1, 1, 1, 1, 1, 1}); 567 TS_ASSERT_EQUALS(r02, xr02) 568 569 std::vector< double > r01 = param_estimator.parameters(gum::NodeId(1), {gum::NodeId(2)}); 570 std::vector< double > xr01 = _xnormalize_({201, 76, 1, 201, 1, 76, 1, 1, 1}); 571 TS_ASSERT_EQUALS(r01, xr01) 572 573 std::vector< double > r21 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(2)}); 574 std::vector< double > xr21 = _xnormalize_({276, 1, 1, 276, 1, 1, 1, 1, 1}); 575 TS_ASSERT_EQUALS(r21, xr21) 576 577 578 gum::learning::ParamEstimatorML<> param_estimator2(param_estimator); 579 std::vector< double > v0 = param_estimator2.parameters(gum::NodeId(1)); 580 TS_ASSERT_EQUALS(v0, xr0) 581 582 std::vector< double > v1 = param_estimator2.parameters(gum::NodeId(2)); 583 TS_ASSERT_EQUALS(v1, xr1) 584 585 std::vector< double > v2 = param_estimator2.parameters(gum::NodeId(0)); 586 TS_ASSERT_EQUALS(v2, xr2) 587 588 std::vector< double > v02 = param_estimator2.parameters(gum::NodeId(1), {gum::NodeId(0)}); 589 TS_ASSERT_EQUALS(v02, xr02) 590 591 std::vector< double > v01 = param_estimator2.parameters(gum::NodeId(1), {gum::NodeId(2)}); 592 TS_ASSERT_EQUALS(v01, xr01) 593 594 std::vector< double > v21 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 595 TS_ASSERT_EQUALS(v21, xr21) 596 597 598 gum::learning::ParamEstimatorML<> param_estimator3(std::move(param_estimator2)); 599 std::vector< double > w0 = param_estimator3.parameters(gum::NodeId(1)); 600 TS_ASSERT_EQUALS(w0, xr0) 601 602 std::vector< double > w1 = param_estimator3.parameters(gum::NodeId(2)); 603 TS_ASSERT_EQUALS(w1, xr1) 604 605 std::vector< double > w2 = param_estimator3.parameters(gum::NodeId(0)); 606 TS_ASSERT_EQUALS(w2, xr2) 607 608 std::vector< double > w02 = param_estimator3.parameters(gum::NodeId(1), {gum::NodeId(0)}); 609 TS_ASSERT_EQUALS(w02, xr02) 610 611 std::vector< double > w01 = param_estimator3.parameters(gum::NodeId(1), {gum::NodeId(2)}); 612 TS_ASSERT_EQUALS(w01, xr01) 613 614 std::vector< double > w21 = param_estimator3.parameters(gum::NodeId(0), {gum::NodeId(2)}); 615 TS_ASSERT_EQUALS(w21, xr21) 616 617 618 gum::learning::ParamEstimatorML<>* param_estimator4 = param_estimator.clone(); 619 std::vector< double > x0 = param_estimator4->parameters(gum::NodeId(1)); 620 TS_ASSERT_EQUALS(x0, xr0) 621 622 std::vector< double > x1 = param_estimator4->parameters(gum::NodeId(2)); 623 TS_ASSERT_EQUALS(x1, xr1) 624 625 std::vector< double > x2 = param_estimator4->parameters(gum::NodeId(0)); 626 TS_ASSERT_EQUALS(x2, xr2) 627 628 std::vector< double > x02 = param_estimator4->parameters(gum::NodeId(1), {gum::NodeId(0)}); 629 TS_ASSERT_EQUALS(x02, xr02) 630 631 std::vector< double > x01 = param_estimator4->parameters(gum::NodeId(1), {gum::NodeId(2)}); 632 TS_ASSERT_EQUALS(x01, xr01) 633 634 std::vector< double > x21 = param_estimator4->parameters(gum::NodeId(0), {gum::NodeId(2)}); 635 TS_ASSERT_EQUALS(x21, xr21) 636 637 delete param_estimator4; 638 } 639 testChangeRanges()640 void testChangeRanges() { 641 // create the translator set 642 gum::LabelizedVariable var("X1", "", 0); 643 var.addLabel("0"); 644 var.addLabel("1"); 645 var.addLabel("2"); 646 647 gum::learning::DBTranslatorSet<> trans_set; 648 { 649 const std::vector< std::string > miss; 650 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 651 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 652 653 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 654 translator.setVariableName(names[i]); 655 trans_set.insertTranslator(translator, i); 656 } 657 } 658 659 // create the database 660 gum::learning::DatabaseTable<> database(trans_set); 661 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 662 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 663 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 664 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 665 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 666 for (int i = 0; i < 1000; ++i) 667 database.insertRow(row0); 668 for (int i = 0; i < 50; ++i) 669 database.insertRow(row1); 670 for (int i = 0; i < 75; ++i) 671 database.insertRow(row2); 672 for (int i = 0; i < 75; ++i) 673 database.insertRow(row3); 674 for (int i = 0; i < 200; ++i) 675 database.insertRow(row4); 676 677 // create the parser 678 gum::learning::DBRowGeneratorSet<> genset; 679 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 680 gum::learning::AprioriSmoothing<> extern_apriori(database); 681 gum::learning::AprioriNoApriori<> intern_apriori(database); 682 683 gum::learning::ParamEstimatorML<> param_estimator(parser, extern_apriori, intern_apriori); 684 685 std::vector< double > v0 = param_estimator.parameters(gum::NodeId(0)); 686 std::vector< double > xv0 = _normalize_({1201, 126, 76}); 687 TS_ASSERT_EQUALS(v0, xv0) 688 689 std::vector< double > v1 = param_estimator.parameters(gum::NodeId(1)); 690 std::vector< double > xv1 = _normalize_({276, 1076, 51}); 691 TS_ASSERT_EQUALS(v1, xv1) 692 693 std::vector< double > v2 = param_estimator.parameters(gum::NodeId(2)); 694 std::vector< double > xv2 = _normalize_({1401, 1, 1}); 695 TS_ASSERT_EQUALS(v2, xv2) 696 697 std::vector< double > v02 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(2)}); 698 std::vector< double > xv02 = _xnormalize_({1201, 126, 76, 1, 1, 1, 1, 1, 1}); 699 TS_ASSERT_EQUALS(v02, xv02) 700 701 std::vector< double > v01 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(1)}); 702 std::vector< double > xv01 = _xnormalize_({201, 76, 1, 1001, 1, 76, 1, 51, 1}); 703 TS_ASSERT_EQUALS(v01, xv01) 704 705 std::vector< double > v21 = param_estimator.parameters(gum::NodeId(2), {gum::NodeId(1)}); 706 std::vector< double > xv21 = _xnormalize_({276, 1, 1, 1076, 1, 1, 51, 1, 1}); 707 TS_ASSERT_EQUALS(v21, xv21) 708 709 710 gum::learning::ParamEstimatorML<> param_estimator2(param_estimator); 711 std::vector< double > w0 = param_estimator2.parameters(gum::NodeId(0)); 712 TS_ASSERT_EQUALS(w0, xv0) 713 714 std::vector< double > w1 = param_estimator2.parameters(gum::NodeId(1)); 715 TS_ASSERT_EQUALS(w1, xv1) 716 717 std::vector< double > w2 = param_estimator2.parameters(gum::NodeId(2)); 718 TS_ASSERT_EQUALS(w2, xv2) 719 720 std::vector< double > w02 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 721 TS_ASSERT_EQUALS(w02, xv02) 722 723 std::vector< double > w01 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(1)}); 724 TS_ASSERT_EQUALS(w01, xv01) 725 726 std::vector< double > w21 = param_estimator2.parameters(gum::NodeId(2), {gum::NodeId(1)}); 727 TS_ASSERT_EQUALS(w21, xv21) 728 729 std::vector< std::pair< std::size_t, std::size_t > > ranges{{800, 1000}, {1050, 1400}}; 730 param_estimator.setRanges(ranges); 731 732 std::vector< double > ar0 = param_estimator.parameters(gum::NodeId(0)); 733 std::vector< double > axr0 = _normalize_({401, 76, 76}); 734 TS_ASSERT_EQUALS(ar0, axr0) 735 736 std::vector< double > ar1 = param_estimator.parameters(gum::NodeId(1)); 737 std::vector< double > axr1 = _normalize_({276, 276, 1}); 738 TS_ASSERT_EQUALS(ar1, axr1) 739 740 std::vector< double > ar2 = param_estimator.parameters(gum::NodeId(2)); 741 std::vector< double > axr2 = _normalize_({551, 1, 1}); 742 TS_ASSERT_EQUALS(ar2, axr2) 743 744 std::vector< double > ar02 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(2)}); 745 std::vector< double > axr02 = _xnormalize_({401, 76, 76, 1, 1, 1, 1, 1, 1}); 746 TS_ASSERT_EQUALS(ar02, axr02) 747 748 std::vector< double > ar01 = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(1)}); 749 std::vector< double > axr01 = _xnormalize_({201, 76, 1, 201, 1, 76, 1, 1, 1}); 750 TS_ASSERT_EQUALS(ar01, axr01) 751 752 std::vector< double > ar21 = param_estimator.parameters(gum::NodeId(2), {gum::NodeId(1)}); 753 std::vector< double > axr21 = _xnormalize_({276, 1, 1, 276, 1, 1, 1, 1, 1}); 754 TS_ASSERT_EQUALS(ar21, axr21) 755 756 param_estimator2.setRanges(ranges); 757 758 std::vector< double > av0 = param_estimator2.parameters(gum::NodeId(0)); 759 TS_ASSERT_EQUALS(av0, axr0) 760 761 std::vector< double > av1 = param_estimator2.parameters(gum::NodeId(1)); 762 TS_ASSERT_EQUALS(av1, axr1) 763 764 std::vector< double > av2 = param_estimator2.parameters(gum::NodeId(2)); 765 TS_ASSERT_EQUALS(av2, axr2) 766 767 std::vector< double > av02 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 768 TS_ASSERT_EQUALS(av02, axr02) 769 770 std::vector< double > av01 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(1)}); 771 TS_ASSERT_EQUALS(av01, axr01) 772 773 std::vector< double > av21 = param_estimator2.parameters(gum::NodeId(2), {gum::NodeId(1)}); 774 TS_ASSERT_EQUALS(av21, axr21) 775 776 TS_ASSERT_EQUALS(param_estimator2.ranges(), ranges) 777 778 param_estimator2.clearRanges(); 779 780 std::vector< double > bv0 = param_estimator2.parameters(gum::NodeId(0)); 781 TS_ASSERT_EQUALS(bv0, xv0) 782 783 std::vector< double > bv1 = param_estimator2.parameters(gum::NodeId(1)); 784 TS_ASSERT_EQUALS(bv1, xv1) 785 786 std::vector< double > bv2 = param_estimator2.parameters(gum::NodeId(2)); 787 TS_ASSERT_EQUALS(bv2, xv2) 788 789 std::vector< double > bv02 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(2)}); 790 TS_ASSERT_EQUALS(bv02, xv02) 791 792 std::vector< double > bv01 = param_estimator2.parameters(gum::NodeId(0), {gum::NodeId(1)}); 793 TS_ASSERT_EQUALS(bv01, xv01) 794 795 std::vector< double > bv21 = param_estimator2.parameters(gum::NodeId(2), {gum::NodeId(1)}); 796 TS_ASSERT_EQUALS(bv21, xv21) 797 } 798 799 testEM()800 void testEM() { 801 gum::LabelizedVariable var("x", "", 0); 802 var.addLabel("0"); 803 var.addLabel("1"); 804 const std::vector< std::string > miss{"N/A", "?"}; 805 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 806 gum::learning::DBTranslatorSet<> set; 807 for (std::size_t i = std::size_t(0); i < std::size_t(4); ++i) 808 set.insertTranslator(translator, i); 809 810 set[0].setVariableName("A"); 811 set[1].setVariableName("B"); 812 set[2].setVariableName("C"); 813 set[3].setVariableName("D"); 814 815 gum::learning::DatabaseTable<> database(set); 816 std::vector< std::string > row1{"0", "1", "1", "0"}; 817 std::vector< std::string > row2{"0", "?", "1", "0"}; 818 std::vector< std::string > row3{"0", "?", "?", "0"}; 819 std::vector< std::string > row4{"?", "?", "1", "0"}; 820 std::vector< std::string > row5{"?", "?", "?", "?"}; 821 for (int i = 0; i < 100; ++i) { 822 database.insertRow(row1); 823 database.insertRow(row2); 824 database.insertRow(row3); 825 database.insertRow(row4); 826 database.insertRow(row5); 827 } 828 829 const std::vector< gum::learning::DBTranslatedValueType > col_types{ 830 gum::learning::DBTranslatedValueType::DISCRETE, 831 gum::learning::DBTranslatedValueType::DISCRETE, 832 gum::learning::DBTranslatedValueType::DISCRETE, 833 gum::learning::DBTranslatedValueType::DISCRETE}; 834 835 auto bn0 = gum::BayesNet< double >::fastPrototype("A;B;C;D"); 836 bn0.cpt("A").fillWith({0.3, 0.7}); 837 bn0.cpt("B").fillWith({0.3, 0.7}); 838 bn0.cpt("C").fillWith({0.3, 0.7}); 839 bn0.cpt("D").fillWith({0.3, 0.7}); 840 841 gum::learning::DBRowGeneratorIdentity<> generator1(col_types); 842 gum::learning::DBRowGeneratorEM<> generator2(col_types, bn0); 843 gum::learning::DBRowGenerator<>& gen2 = generator2; // fix for g++-4.8 844 gum::learning::DBRowGeneratorIdentity<> generator3(col_types); 845 gum::learning::DBRowGeneratorEM<> generator4(col_types, bn0); 846 gum::learning::DBRowGenerator<>& gen4 = generator4; // fix for g++-4.8 847 848 gum::learning::DBRowGeneratorSet<> genset; 849 genset.insertGenerator(generator1); 850 genset.insertGenerator(gen2); 851 genset.insertGenerator(generator3); 852 genset.insertGenerator(gen4); 853 854 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 855 gum::learning::AprioriNoApriori<> extern_apriori(database); 856 gum::learning::AprioriNoApriori<> intern_apriori(database); 857 858 gum::learning::ParamEstimatorML<> param_estimator(parser, extern_apriori, intern_apriori); 859 860 auto bn = gum::BayesNet< double >::fastPrototype("A->B->C<-D"); 861 bn.cpt("A").fillWith({0.3, 0.7}); 862 bn.cpt("B").fillWith({0.4, 0.6, 0.7, 0.3}); 863 bn.cpt("C").fillWith({0.2, 0.8, 0.3, 0.7, 0.4, 0.6, 0.5, 0.5}); 864 bn.cpt("D").fillWith({0.8, 0.2}); 865 866 // bugfix for parallel execution of VariableElimination 867 const gum::DAG& dag = bn.dag(); 868 for (const auto node: dag) { 869 dag.parents(node); 870 dag.children(node); 871 } 872 873 param_estimator.setBayesNet(bn); 874 875 gum::learning::IdCondSet<> ids(0, std::vector< gum::NodeId >{1}, true); 876 // gum::learning::IdCondSet<> ids(0, {}, true); 877 std::vector< double > counts = param_estimator.parameters(gum::NodeId(0), {gum::NodeId(1)}); 878 879 std::vector< double > xcounts(4, 0.0); 880 std::vector< double > sum(4, 0.0); 881 int nb_row = 0; 882 for (const auto& row: database) { 883 gum::Potential< double > proba = _infer_(bn, {std::size_t(0), std::size_t(1)}, row); 884 885 std::size_t idx; 886 for (gum::Instantiation inst(proba); !inst.end(); ++inst) { 887 if (proba.variablesSequence()[0]->name() == "A") 888 idx = inst.val(0) + std::size_t(2) * inst.val(1); 889 else 890 idx = inst.val(1) + std::size_t(2) * inst.val(0); 891 xcounts[idx] += proba.get(inst); 892 } 893 894 ++nb_row; 895 if (nb_row == 5) break; 896 } 897 898 sum[0] = xcounts[0] + xcounts[1]; 899 sum[1] = sum[0]; 900 sum[2] = xcounts[2] + xcounts[3]; 901 sum[3] = sum[2]; 902 903 for (std::size_t i = std::size_t(0); i < std::size_t(4); ++i) 904 xcounts[i] /= sum[i]; 905 906 for (std::size_t i = std::size_t(0); i < std::size_t(4); ++i) { 907 TS_ASSERT_DELTA(counts[i], xcounts[i], 0.001) 908 } 909 } 910 testZeroInPseudoCounts()911 void testZeroInPseudoCounts() { 912 gum::learning::DBTranslatorSet<> trans_set; 913 { 914 // create the translator set 915 gum::LabelizedVariable var("X1", "", 0); 916 var.addLabel("0"); 917 var.addLabel("1"); 918 var.addLabel("2"); 919 920 const std::vector< std::string > miss; 921 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 922 std::vector< std::string > names{"A", "B"}; 923 924 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 925 translator.setVariableName(names[i]); 926 trans_set.insertTranslator(translator, i); 927 } 928 } 929 930 // create the database 931 gum::learning::DatabaseTable<> database(trans_set); 932 database.insertRow({"0", "1"}); 933 database.insertRow({"1", "0"}); 934 database.insertRow({"1", "1"}); 935 database.insertRow({"0", "1"}); 936 database.insertRow({"0", "0"}); 937 938 // create the parser 939 gum::learning::DBRowGeneratorSet<> genset; 940 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 941 gum::learning::AprioriNoApriori<> extern_apriori(database); 942 gum::learning::AprioriNoApriori<> intern_apriori(database); 943 944 gum::learning::ParamEstimatorML<> param_estimator(parser, extern_apriori, intern_apriori); 945 946 TS_GUM_ASSERT_THROWS_NOTHING(param_estimator.parameters(gum::NodeId(0))) 947 TS_GUM_ASSERT_THROWS_NOTHING(param_estimator.parameters(gum::NodeId(1))) 948 TS_ASSERT_THROWS(param_estimator.parameters(gum::NodeId(1), {gum::NodeId(0)}), 949 gum::DatabaseError) 950 } 951 }; 952 953 } // namespace gum_tests 954