1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 // floating point env 23 #include <cfenv> 24 #include <vector> 25 #include <string> 26 #include <iostream> 27 28 #include <gumtest/AgrumTestSuite.h> 29 #include <gumtest/testsuite_utils.h> 30 31 #include <agrum/BN/learning/BNLearner.h> 32 33 #include <agrum/tools/core/approximations/approximationSchemeListener.h> 34 35 #include <agrum/BN/database/BNDatabaseGenerator.h> 36 37 namespace gum_tests { 38 39 class aSimpleBNLeanerListener: public gum::ApproximationSchemeListener { 40 private: 41 gum::Size _nbr_; 42 std::string _mess_; 43 44 public: aSimpleBNLeanerListener(gum::IApproximationSchemeConfiguration & sch)45 aSimpleBNLeanerListener(gum::IApproximationSchemeConfiguration& sch) : 46 gum::ApproximationSchemeListener(sch), _nbr_(0), _mess_(""){}; 47 whenProgress(const void * buffer,const gum::Size a,const double b,const double c)48 void whenProgress(const void* buffer, const gum::Size a, const double b, const double c) { 49 _nbr_++; 50 } 51 whenStop(const void * buffer,const std::string s)52 void whenStop(const void* buffer, const std::string s) { _mess_ = s; } 53 getNbr()54 gum::Size getNbr() { return _nbr_; } 55 getMess()56 std::string getMess() { return _mess_; } 57 }; 58 59 class BNLearnerTestSuite: public CxxTest::TestSuite { 60 public: test_asia()61 void test_asia() { 62 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 63 64 learner.useLocalSearchWithTabuList(100, 1); 65 learner.setMaxIndegree(10); 66 learner.useScoreLog2Likelihood(); 67 68 TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBD()) 69 TS_ASSERT_DIFFERS("", learner.checkScoreAprioriCompatibility()) 70 TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBDeu()) 71 TS_ASSERT_EQUALS("", learner.checkScoreAprioriCompatibility()) 72 learner.useScoreLog2Likelihood(); 73 74 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 75 // learner.addForbiddenArc ( gum::Arc (4,3) ); 76 // learner.addForbiddenArc ( gum::Arc (5,1) ); 77 // learner.addForbiddenArc ( gum::Arc (5,7) ); 78 79 // learner.addMandatoryArc ( gum::Arc ( learner.nodeId ( "bronchitis" ), 80 // learner.nodeId ( "lung_cancer" ) 81 // ) ); 82 83 learner.addMandatoryArc("bronchitis", "lung_cancer"); 84 85 learner.useAprioriSmoothing(); 86 // learner.useAprioriDirichlet ( GET_RESSOURCES_PATH( "asia.csv" ) ); 87 88 gum::NodeProperty< gum::Size > slice_order{std::make_pair(gum::NodeId(0), (gum::Size)1), 89 std::make_pair(gum::NodeId(3), (gum::Size)0), 90 std::make_pair(gum::NodeId(1), (gum::Size)0)}; 91 learner.setSliceOrder(slice_order); 92 93 const std::vector< std::string >& names = learner.names(); 94 TS_ASSERT(!names.empty()) 95 96 try { 97 gum::BayesNet< double > bn = learner.learnBN(); 98 TS_ASSERT_EQUALS(bn.dag().arcs().size(), (gum::Size)9) 99 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 100 101 learner.setDatabaseWeight(10.0); 102 const auto& db = learner.database(); 103 const double weight = 10.0 / double(db.nbRows()); 104 for (const auto& row: db) { 105 TS_ASSERT_EQUALS(row.weight(), weight) 106 } 107 TS_ASSERT_DELTA(learner.databaseWeight(), 10.0, 1e-4) 108 109 const std::size_t nbr = db.nbRows(); 110 for (std::size_t i = std::size_t(0); i < nbr; ++i) { 111 if (i % 2) learner.setRecordWeight(i, 2.0); 112 } 113 114 std::size_t index = std::size_t(0); 115 for (const auto& row: db) { 116 if (index % 2) { 117 TS_ASSERT_EQUALS(row.weight(), 2.0) 118 TS_ASSERT_EQUALS(learner.recordWeight(index), 2.0) 119 } else { 120 TS_ASSERT_EQUALS(row.weight(), 10.0) 121 TS_ASSERT_EQUALS(learner.recordWeight(index), 10.0) 122 } 123 ++index; 124 } 125 } 126 test_induceTypes()127 void test_induceTypes() { 128 { 129 gum::learning::BNLearner< double > learner1(GET_RESSOURCES_PATH("csv/asia.csv")); 130 learner1.useScoreBDeu(); 131 learner1.useNoApriori(); 132 gum::BayesNet< double > bn1 = learner1.learnBN(); 133 134 gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/asia.csv"), 135 true, 136 {"?"}); 137 for (const auto trans2: learner2.database().translatorSet().translators()) { 138 const auto& var2 = trans2->variable(); 139 TS_ASSERT((var2->varType() == gum::VarType::Range) 140 || (var2->varType() == gum::VarType::Integer)); 141 } 142 143 learner2.useScoreBDeu(); 144 learner2.useNoApriori(); 145 gum::BayesNet< double > bn2 = learner2.learnBN(); 146 147 TS_ASSERT_EQUALS(bn1.dag(), bn2.dag()) 148 } 149 150 { 151 gum::learning::BNLearner< double > learner1(GET_RESSOURCES_PATH("csv/alarm.csv")); 152 learner1.useScoreBDeu(); 153 learner1.useNoApriori(); 154 gum::BayesNet< double > bn1 = learner1.learnBN(); 155 156 gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/alarm.csv"), 157 true, 158 {"?"}); 159 160 for (const auto trans2: learner2.database().translatorSet().translators()) { 161 const auto& var2 = trans2->variable(); 162 TS_ASSERT((var2->varType() == gum::VarType::Range) 163 || (var2->varType() == gum::VarType::Integer)); 164 } 165 166 learner2.useScoreBDeu(); 167 learner2.useNoApriori(); 168 gum::BayesNet< double > bn2 = learner2.learnBN(); 169 170 TS_ASSERT_EQUALS(bn1.dag(), bn2.dag()) 171 } 172 173 { 174 auto bn = gum::BayesNet< double >::fastPrototype("A->B<-C->D->E<-B"); 175 gum::learning::BNDatabaseGenerator< double > genere(bn); 176 genere.setRandomVarOrder(); 177 genere.drawSamples(2000); 178 genere.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv")); 179 180 auto bn2 = gum::BayesNet< double >::fastPrototype("A->B->C->D->E"); 181 gum::learning::BNDatabaseGenerator< double > genere2(bn2); 182 genere2.drawSamples(100); 183 genere2.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_database.csv")); 184 185 gum::learning::BNLearner< double > learner1( 186 GET_RESSOURCES_PATH("outputs/bnlearner_database.csv")); 187 learner1.useAprioriDirichlet(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"), 10); 188 learner1.useScoreAIC(); 189 gum::BayesNet< double > xbn1 = learner1.learnBN(); 190 191 gum::learning::BNLearner< double > learner2( 192 GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"), 193 true, 194 {"?"}); 195 196 for (const auto trans2: learner2.database().translatorSet().translators()) { 197 const auto& var2 = trans2->variable(); 198 TS_ASSERT((var2->varType() == gum::VarType::Range) 199 || (var2->varType() == gum::VarType::Integer)); 200 } 201 202 learner2.useAprioriDirichlet(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"), 10); 203 learner2.useScoreAIC(); 204 gum::BayesNet< double > xbn2 = learner2.learnBN(); 205 206 TS_ASSERT_EQUALS(xbn1.dag(), xbn2.dag()) 207 } 208 } 209 210 test_guill()211 void test_guill() { 212 try { 213 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3_withguill.csv")); 214 TS_FAIL("asia3_withguill.csv contains syntax error (with \")."); 215 } catch (gum::SyntaxError& e) {}; 216 } 217 test_ranges()218 void test_ranges() { 219 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 220 221 learner.useGreedyHillClimbing(); 222 learner.useScoreBIC(); 223 learner.useAprioriSmoothing(); 224 225 const std::size_t k = 5; 226 const auto& database = learner.database(); 227 const std::size_t dbsize = database.nbRows(); 228 std::size_t foldSize = dbsize / k; 229 230 gum::learning::DBRowGeneratorSet<> genset; 231 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 232 gum::learning::AprioriSmoothing<> apriori(database); 233 apriori.setWeight(1); 234 235 gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintDAG > 236 struct_constraint; 237 238 gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) > op_set( 239 struct_constraint); 240 241 gum::learning::GreedyHillClimbing search; 242 243 gum::learning::ScoreBIC<> score(parser, apriori); 244 gum::learning::ParamEstimatorML<> estimator(parser, apriori, score.internalApriori()); 245 for (std::size_t fold = 0; fold < k; fold++) { 246 // create the ranges of rows over which we perform the learning 247 const std::size_t unfold_deb = fold * foldSize; 248 const std::size_t unfold_end = unfold_deb + foldSize; 249 250 std::vector< std::pair< std::size_t, std::size_t > > ranges; 251 if (fold == std::size_t(0)) { 252 ranges.push_back(std::pair< std::size_t, std::size_t >(unfold_end, dbsize)); 253 } else { 254 ranges.push_back(std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb)); 255 256 if (fold != k - 1) { 257 ranges.push_back(std::pair< std::size_t, std::size_t >(unfold_end, dbsize)); 258 } 259 } 260 261 learner.useDatabaseRanges(ranges); 262 TS_ASSERT_EQUALS(learner.databaseRanges(), ranges) 263 264 learner.clearDatabaseRanges(); 265 TS_ASSERT_DIFFERS(learner.databaseRanges(), ranges) 266 267 learner.useCrossValidationFold(fold, k); 268 TS_ASSERT_EQUALS(learner.databaseRanges(), ranges) 269 270 gum::BayesNet< double > bn1 = learner.learnBN(); 271 272 273 score.setRanges(ranges); 274 estimator.setRanges(ranges); 275 gum::learning::GraphChangesSelector4DiGraph< decltype(struct_constraint), decltype(op_set) > 276 selector(score, struct_constraint, op_set); 277 gum::BayesNet< double > bn2 = search.learnBN< double >(selector, estimator); 278 279 TS_ASSERT_EQUALS(bn1.dag(), bn2.dag()) 280 281 gum::Instantiation I1, I2; 282 283 for (auto& name: database.variableNames()) { 284 I1.add(bn1.variableFromName(name)); 285 I2.add(bn2.variableFromName(name)); 286 } 287 288 double LL1 = 0.0, LL2 = 0.0; 289 const std::size_t nbCol = database.nbVariables(); 290 parser.setRange(unfold_deb, unfold_end); 291 while (parser.hasRows()) { 292 const gum::learning::DBRow< gum::learning::DBTranslatedValue >& row = parser.row(); 293 for (std::size_t i = 0; i < nbCol; ++i) { 294 I1.chgVal(i, row[i].discr_val); 295 I2.chgVal(i, row[i].discr_val); 296 } 297 298 LL1 += bn1.log2JointProbability(I1) * row.weight(); 299 LL2 += bn2.log2JointProbability(I2) * row.weight(); 300 } 301 302 TS_ASSERT_EQUALS(LL1, LL2) 303 } 304 } 305 306 test_asia_3off2()307 void test_asia_3off2() { 308 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv")); 309 310 aSimpleBNLeanerListener listen(learner); 311 312 learner.useGreedyHillClimbing(); 313 314 learner.use3off2(); 315 learner.useNMLCorrection(); 316 learner.addForbiddenArc(gum::Arc(4, 1)); 317 // learner.addForbiddenArc ( gum::Arc (5,1) ); 318 // learner.addForbiddenArc ( gum::Arc (5,7) ); 319 320 learner.addMandatoryArc(gum::Arc(7, 5)); 321 gum::DAG i_dag; 322 for (gum::NodeId i = 0; i < 8; ++i) { 323 i_dag.addNodeWithId(i); 324 } 325 learner.setInitialDAG(i_dag); 326 // learner.addMandatoryArc( "bronchitis", "lung_cancer" ); 327 328 const std::vector< std::string >& names = learner.names(); 329 TS_ASSERT(!names.empty()) 330 331 try { 332 gum::BayesNet< double > bn = learner.learnBN(); 333 TS_ASSERT_EQUALS(bn.dag().arcs().size(), (gum::Size)9) 334 // TS_ASSERT_EQUALS(listen.getNbr(), (gum::Size)86) 335 TS_ASSERT(!bn.dag().existsArc(4, 1)) 336 TS_ASSERT(bn.dag().existsArc(7, 5)) 337 338 gum::MixedGraph mg = learner.learnMixedStructure(); 339 TS_ASSERT_EQUALS(mg.arcs().size(), (gum::Size)8) 340 TS_ASSERT_EQUALS(mg.edges().size(), (gum::Size)1) 341 TS_ASSERT(!mg.existsArc(4, 1)) 342 TS_ASSERT(mg.existsArc(7, 5)) 343 std::vector< gum::Arc > latents = learner.latentVariables(); 344 TS_ASSERT_EQUALS(latents.size(), (gum::Size)2) 345 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 346 } 347 348 // WARNING: this test is commented on purpose: you need a running database 349 // with a table filled with the content of the asia.csv file. You will also 350 // need a proper odbc configuration (under linux and macos you'll need 351 // unixodbc and specific database odbc drivers). 352 // void test_asia_db() { 353 // try { 354 // auto db = gum::learning::DatabaseFromSQL( 355 // "PostgreSQL", 356 // "lto", 357 // "Password2Change", 358 // "select smoking , lung_cancer , bronchitis , visit_to_asia , " 359 // "tuberculosis , tuberculos_or_cancer , dyspnoea , positive_xray " 360 // "from asia;" ); 361 // gum::learning::BNLearner<double> learner( db ); 362 363 // learner.useLocalSearchWithTabuList( 100, 1 ); 364 // learner.setMaxIndegree( 10 ); 365 // learner.useScoreLog2Likelihood(); 366 367 // TS_ASSERT_THROWS( learner.useScoreBD(), gum::IncompatibleScoreApriori 368 // ); 369 // TS_GUM_ASSERT_THROWS_NOTHING( learner.useScoreBDeu() ); 370 // learner.useScoreLog2Likelihood(); 371 372 // learner.useK2( std::vector<gum::NodeId>{1, 5, 2, 6, 0, 3, 4, 7} ); 373 // learner.addMandatoryArc( "bronchitis", "lung_cancer" ); 374 // learner.useAprioriSmoothing(); 375 376 // gum::NodeProperty<unsigned int> slice_order{ 377 // std::make_pair( gum::NodeId( 0 ), 1 ), 378 // std::make_pair( gum::NodeId( 3 ), 0 ), 379 // std::make_pair( gum::NodeId( 1 ), 0 )}; 380 381 // const std::vector<std::string>& names = learner.names(); 382 // TS_ASSERT( !names.empty() ) 383 384 // try { 385 // gum::BayesNet<double> bn = learner.learnBN(); 386 // TS_ASSERT_EQUALS( bn.dag().arcs().size() , 9 ) 387 // } catch ( gum::Exception& e ) { 388 // GUM_SHOWERROR( e ); 389 // } 390 // } catch ( gum::Exception& e ) { 391 // GUM_TRACE( e.errorType() ); 392 // GUM_TRACE( e.errorContent() ); 393 // GUM_TRACE( e.errorCallStack() ); 394 // TS_FAIL("plop"); 395 // } 396 // } 397 test_asia_with_domain_sizes()398 void test_asia_with_domain_sizes() { 399 gum::learning::BNLearner< double > learn(GET_RESSOURCES_PATH("csv/asia3.csv")); 400 const auto& database = learn.database(); 401 402 gum::BayesNet< double > bn; 403 for (auto& name: database.variableNames()) { 404 gum::LabelizedVariable var(name, name, {"false", "true", "big"}); 405 bn.add(var); 406 } 407 408 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn); 409 learner.useScoreBIC(); 410 learner.useAprioriSmoothing(); 411 412 gum::BayesNet< double > bn2 = learner.learnBN(); 413 for (auto& name: database.variableNames()) { 414 TS_ASSERT_EQUALS(bn2.variableFromName(name).domainSize(), (gum::Size)3) 415 } 416 } 417 xtest_asia_with_user_modalities_string_min()418 void xtest_asia_with_user_modalities_string_min() { 419 gum::NodeProperty< gum::Sequence< std::string > > modals; 420 modals.insert(0, gum::Sequence< std::string >()); 421 modals[0].insert("false"); 422 modals[0].insert("true"); 423 modals[0].insert("big"); 424 425 modals.insert(2, gum::Sequence< std::string >()); 426 modals[2].insert("big"); 427 modals[2].insert("bigbig"); 428 modals[2].insert("true"); 429 modals[2].insert("bigbigbig"); 430 modals[2].insert("false"); 431 432 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 433 // GET_RESSOURCES_PATH("csv/asia3.csv"), modals, true); 434 435 learner.useGreedyHillClimbing(); 436 learner.setMaxIndegree(10); 437 learner.useScoreLog2Likelihood(); 438 439 TS_ASSERT_THROWS(learner.useScoreBD(), gum::IncompatibleScoreApriori) 440 TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBDeu()); 441 learner.useScoreLog2Likelihood(); 442 443 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 444 // learner.addForbiddenArc ( gum::Arc (4,3) ); 445 // learner.addForbiddenArc ( gum::Arc (5,1) ); 446 // learner.addForbiddenArc ( gum::Arc (5,7) ); 447 448 // learner.addMandatoryArc ( gum::Arc ( learner.nodeId ( "bronchitis" ), 449 // learner.nodeId ( "lung_cancer" ) 450 // ) ); 451 452 learner.addMandatoryArc("bronchitis", "lung_cancer"); 453 454 learner.useAprioriSmoothing(); 455 // learner.useAprioriDirichlet ( GET_RESSOURCES_PATH( "asia.csv" ) ); 456 457 gum::NodeProperty< gum::Size > slice_order{std::make_pair(gum::NodeId(0), (gum::Size)1), 458 std::make_pair(gum::NodeId(3), (gum::Size)0), 459 std::make_pair(gum::NodeId(1), (gum::Size)0)}; 460 learner.setSliceOrder(slice_order); 461 462 const std::vector< std::string >& names = learner.names(); 463 TS_ASSERT(!names.empty()) 464 465 try { 466 gum::BayesNet< double > bn = learner.learnBN(); 467 TS_ASSERT_EQUALS(bn.variable(0).domainSize(), (gum::Size)2) 468 TS_ASSERT_EQUALS(bn.variable(2).domainSize(), (gum::Size)2) 469 TS_ASSERT_EQUALS(bn.variable(0).label(0), "false") 470 TS_ASSERT_EQUALS(bn.variable(0).label(1), "true") 471 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 472 } 473 xtest_asia_with_user_modalities_string_incorrect()474 void xtest_asia_with_user_modalities_string_incorrect() { 475 gum::NodeProperty< gum::Sequence< std::string > > modals; 476 modals.insert(0, gum::Sequence< std::string >()); 477 modals[0].insert("False"); 478 modals[0].insert("true"); 479 modals[0].insert("big"); 480 481 modals.insert(2, gum::Sequence< std::string >()); 482 modals[2].insert("big"); 483 modals[2].insert("bigbig"); 484 modals[2].insert("true"); 485 modals[2].insert("bigbigbig"); 486 modals[2].insert("false"); 487 488 bool except = false; 489 490 try { 491 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 492 // GET_RESSOURCES_PATH("csv/asia3.csv"), modals); 493 learner.useAprioriSmoothing(); 494 } catch (gum::UnknownLabelInDatabase&) { except = true; } 495 496 TS_ASSERT(except) 497 } 498 xtest_asia_with_user_modalities_numbers()499 void xtest_asia_with_user_modalities_numbers() { 500 gum::NodeProperty< gum::Sequence< std::string > > modals; 501 modals.insert(0, gum::Sequence< std::string >()); 502 modals[0].insert("0"); 503 modals[0].insert("1"); 504 modals[0].insert("big"); 505 506 modals.insert(2, gum::Sequence< std::string >()); 507 modals[2].insert("big"); 508 modals[2].insert("bigbig"); 509 modals[2].insert("1"); 510 modals[2].insert("bigbigbig"); 511 modals[2].insert("0"); 512 513 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv")); 514 // learner(GET_RESSOURCES_PATH("csv/asia.csv"), modals); 515 learner.useGreedyHillClimbing(); 516 learner.setMaxIndegree(10); 517 learner.useScoreLog2Likelihood(); 518 519 TS_ASSERT_THROWS(learner.useScoreBD(), gum::IncompatibleScoreApriori) 520 TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBDeu()); 521 learner.useScoreLog2Likelihood(); 522 523 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 524 learner.addForbiddenArc(gum::Arc(4, 3)); 525 learner.addForbiddenArc(gum::Arc(5, 1)); 526 learner.addForbiddenArc(gum::Arc(5, 7)); 527 528 learner.addMandatoryArc("bronchitis", "lung_cancer"); 529 530 learner.useAprioriSmoothing(); 531 // learner.useAprioriDirichlet ( GET_RESSOURCES_PATH( "asia.csv" ) ); 532 533 gum::NodeProperty< gum::Size > slice_order{std::make_pair(gum::NodeId(0), (gum::Size)1), 534 std::make_pair(gum::NodeId(3), (gum::Size)0), 535 std::make_pair(gum::NodeId(1), (gum::Size)0)}; 536 learner.setSliceOrder(slice_order); 537 538 const std::vector< std::string >& names = learner.names(); 539 TS_ASSERT(!names.empty()) 540 541 try { 542 gum::BayesNet< double > bn = learner.learnBN(); 543 TS_ASSERT_EQUALS(bn.variable(0).domainSize(), (gum::Size)3) 544 TS_ASSERT_EQUALS(bn.variable(2).domainSize(), (gum::Size)5) 545 TS_ASSERT_EQUALS(bn.variable(0).label(0), "0") 546 TS_ASSERT_EQUALS(bn.variable(0).label(1), "1") 547 TS_ASSERT_EQUALS(bn.variable(0).label(2), "big") 548 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 549 } 550 xtest_asia_with_user_modalities_numbers_incorrect()551 void xtest_asia_with_user_modalities_numbers_incorrect() { 552 gum::NodeProperty< gum::Sequence< std::string > > modals; 553 modals.insert(0, gum::Sequence< std::string >()); 554 modals[0].insert("1"); 555 modals[0].insert("2"); 556 modals[0].insert("big"); 557 558 modals.insert(2, gum::Sequence< std::string >()); 559 modals[2].insert("big"); 560 modals[2].insert("bigbig"); 561 modals[2].insert("3"); 562 modals[2].insert("bigbigbig"); 563 modals[2].insert("0"); 564 565 bool except = false; 566 567 try { 568 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv")); 569 // learner(GET_RESSOURCES_PATH("csv/asia.csv"), modals); 570 learner.useAprioriSmoothing(); 571 } catch (gum::UnknownLabelInDatabase&) { except = true; } 572 573 TS_ASSERT(except) 574 } 575 test_asia_param()576 void test_asia_param() { 577 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 578 579 gum::DAG dag; 580 581 for (unsigned int i = 0; i < 8; ++i) { 582 dag.addNodeWithId(i); 583 } 584 585 for (unsigned int i = 0; i < 7; ++i) { 586 dag.addArc(i, i + 1); 587 } 588 589 dag.addArc(0, 7); 590 dag.addArc(2, 4); 591 dag.addArc(5, 7); 592 dag.addArc(3, 6); 593 594 learner.useNoApriori(); 595 596 try { 597 gum::BayesNet< double > bn = learner.learnParameters(dag); 598 TS_ASSERT_EQUALS(bn.dim(), (gum::Size)25) 599 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 600 } 601 test_asia_param_from_bn()602 void test_asia_param_from_bn() { 603 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 604 605 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 606 gum::BayesNet< double > bn = learner.learnBN(); 607 608 try { 609 gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag()); 610 TS_ASSERT_EQUALS(bn2.dag().arcs().size(), bn.dag().arcs().size()) 611 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 612 } 613 614 test_asia_param_float()615 void test_asia_param_float() { 616 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 617 618 gum::DAG dag; 619 620 for (unsigned int i = 0; i < 8; ++i) { 621 dag.addNodeWithId(i); 622 } 623 624 for (unsigned int i = 0; i < 7; ++i) { 625 dag.addArc(i, i + 1); 626 } 627 628 dag.addArc(0, 7); 629 dag.addArc(2, 4); 630 dag.addArc(5, 7); 631 dag.addArc(3, 6); 632 633 learner.useNoApriori(); 634 635 try { 636 gum::BayesNet< double > bn = learner.learnParameters(dag); 637 TS_ASSERT_EQUALS(bn.dim(), (gum::Size)25) 638 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 639 } 640 test_asia_param_from_bn_float()641 void test_asia_param_from_bn_float() { 642 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 643 644 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 645 gum::BayesNet< double > bn = learner.learnBN(); 646 647 try { 648 gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag()); 649 TS_ASSERT_EQUALS(bn2.dag().arcs().size(), bn.dag().arcs().size()) 650 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 651 } 652 test_asia_param_bn()653 void test_asia_param_bn() { 654 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true"); 655 // smoking,lung_cancer,bronchitis,visit_to_Asia,tuberculosis,tuberculos_or_cancer,dyspnoea,positive_XraY 656 auto s = createBoolVar("smoking"); 657 auto l = createBoolVar("lung_cancer"); 658 auto b = createBoolVar("bronchitis"); 659 auto v = createBoolVar("visit_to_Asia"); 660 auto t = createBoolVar("tuberculosis"); 661 auto o = createBoolVar("tuberculos_or_cancer"); 662 auto d = createBoolVar("dyspnoea"); 663 auto p = createBoolVar("positive_XraY"); 664 #undef createBoolVar 665 666 gum::BayesNet< double > bn; 667 gum::NodeId ns = bn.add(s); 668 gum::NodeId nl = bn.add(l); 669 gum::NodeId nb = bn.add(b); 670 gum::NodeId nv = bn.add(v); 671 gum::NodeId nt = bn.add(t); 672 gum::NodeId no = bn.add(o); 673 gum::NodeId nd = bn.add(d); 674 gum::NodeId np = bn.add(p); 675 676 bn.addArc(ns, nl); 677 bn.addArc(ns, nb); 678 bn.addArc(nl, no); 679 bn.addArc(nb, nd); 680 bn.addArc(nv, nt); 681 bn.addArc(nt, no); 682 bn.addArc(no, nd); 683 bn.addArc(no, np); 684 685 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn); 686 687 learner.useScoreLog2Likelihood(); 688 learner.useAprioriSmoothing(); 689 690 try { 691 gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag()); 692 TS_ASSERT_EQUALS(bn2.dim(), bn.dim()) 693 694 for (gum::NodeId node: bn.nodes()) { 695 gum::NodeId node2 = bn2.idFromName(bn.variable(node).name()); 696 TS_ASSERT_EQUALS(bn.variable(node).toString(), bn2.variable(node2).toString()) 697 } 698 699 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 700 } 701 test_asia_param_bn_with_not_matching_variable()702 void test_asia_param_bn_with_not_matching_variable() { 703 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true"); 704 auto s = createBoolVar("smoking"); 705 auto l = createBoolVar("lung_cancer"); 706 auto b = createBoolVar("bronchitis"); 707 auto v = createBoolVar("visit_to_Asia"); 708 auto t = createBoolVar("tuberculosis"); 709 auto o = createBoolVar("tuberculos_or_cancer"); 710 auto d = createBoolVar("dyspnoea"); 711 712 // uncorrect name is : will it be correctly handled 713 auto p = createBoolVar("ZORBLOBO"); 714 #undef createBoolVar 715 716 gum::BayesNet< double > bn; 717 bn.add(s); 718 bn.add(l); 719 bn.add(b); 720 bn.add(v); 721 bn.add(t); 722 bn.add(o); 723 bn.add(d); 724 bn.add(p); 725 726 TS_ASSERT_THROWS( 727 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn), 728 gum::MissingVariableInDatabase); 729 730 731 // learner.useScoreLog2Likelihood(); 732 // learner.useAprioriSmoothing(); 733 734 // TS_ASSERT_THROWS(gum::BayesNet< double > bn2 = 735 // learner.learnParameters(bn), 736 // gum::MissingVariableInDatabase); 737 } 738 test_asia_param_bn_with_subset_of_variables_in_base()739 void test_asia_param_bn_with_subset_of_variables_in_base() { 740 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true"); 741 auto s = createBoolVar("smoking"); 742 auto t = createBoolVar("tuberculosis"); 743 auto o = createBoolVar("tuberculos_or_cancer"); 744 auto d = createBoolVar("dyspnoea"); 745 #undef createBoolVar 746 747 gum::BayesNet< double > bn; 748 gum::NodeId ns = bn.add(s); 749 gum::NodeId nt = bn.add(t); 750 gum::NodeId no = bn.add(o); 751 gum::NodeId nd = bn.add(d); 752 753 bn.addArc(ns, nt); 754 bn.addArc(nt, no); 755 bn.addArc(no, nd); 756 757 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn); 758 759 760 learner.useScoreLog2Likelihood(); 761 learner.useAprioriSmoothing(); 762 763 gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag()); 764 } 765 test_asia_param_bn_with_unknow_modality()766 void test_asia_param_bn_with_unknow_modality() { 767 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true"); 768 auto s = createBoolVar("smoking"); 769 auto t = createBoolVar("tuberculosis"); 770 auto o = createBoolVar("tuberculos_or_cancer"); 771 auto d = createBoolVar("dyspnoea"); 772 #undef createBoolVar 773 774 gum::BayesNet< double > bn; 775 gum::NodeId ns = bn.add(s); 776 gum::NodeId nt = bn.add(t); 777 gum::NodeId no = bn.add(o); 778 gum::NodeId nd = bn.add(d); 779 780 bn.addArc(ns, nt); 781 bn.addArc(nt, no); 782 bn.addArc(no, nd); 783 784 // asia3-faulty contains a label "beurk" for variable "smoking" 785 // std::cout << "error test"; 786 787 TS_ASSERT_THROWS( 788 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3-faulty.csv"), 789 bn), 790 gum::UnknownLabelInDatabase); 791 } 792 test_listener()793 void test_listener() { 794 { 795 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv")); 796 aSimpleBNLeanerListener listen(learner); 797 798 learner.setVerbosity(true); 799 learner.setMaxIndegree(10); 800 learner.useScoreK2(); 801 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 802 803 gum::BayesNet< double > bn = learner.learnBN(); 804 805 TS_ASSERT_EQUALS(listen.getNbr(), (gum::Size)2) 806 TS_ASSERT_EQUALS(listen.getMess(), "stopped on request") 807 TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request") 808 } 809 { 810 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia2.csv")); 811 aSimpleBNLeanerListener listen(learner); 812 813 learner.setVerbosity(true); 814 learner.setMaxIndegree(10); 815 learner.useScoreK2(); 816 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 817 818 gum::BayesNet< double > bn = learner.learnBN(); 819 820 TS_ASSERT_EQUALS(listen.getNbr(), (gum::Size)3) 821 TS_ASSERT_EQUALS(listen.getMess(), "stopped on request") 822 TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request") 823 } 824 { 825 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv")); 826 aSimpleBNLeanerListener listen(learner); 827 828 learner.setVerbosity(true); 829 learner.setMaxIndegree(2); 830 learner.useLocalSearchWithTabuList(); 831 832 gum::BayesNet< double > bn = learner.learnBN(); 833 // std::cout << bn.dag () << std::endl; 834 835 TS_ASSERT_DELTA(listen.getNbr(), (gum::Size)15, 1); // 75 836 TS_ASSERT_EQUALS(listen.getMess(), "stopped on request") 837 TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request") 838 } 839 { 840 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv")); 841 aSimpleBNLeanerListener listen(learner); 842 843 learner.setVerbosity(true); 844 learner.setMaxIndegree(2); 845 learner.useGreedyHillClimbing(); 846 847 gum::BayesNet< double > bn = learner.learnBN(); 848 849 TS_ASSERT_DELTA(listen.getNbr(), (gum::Size)3, 1); // 2? 850 TS_ASSERT_EQUALS(listen.getMess(), "stopped on request") 851 TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request") 852 } 853 } 854 test_DBNTonda()855 void test_DBNTonda() { 856 gum::BayesNet< double > dbn; 857 gum::NodeId bf_0 = dbn.add(gum::LabelizedVariable("bf_0", "bf_0", 4)); 858 /*gum::NodeId bf_t =*/dbn.add(gum::LabelizedVariable("bf_t", "bf_t", 4)); 859 gum::NodeId c_0 = dbn.add(gum::LabelizedVariable("c_0", "c_0", 5)); 860 gum::NodeId c_t = dbn.add(gum::LabelizedVariable("c_t", "c_t", 5)); 861 gum::NodeId h_0 = dbn.add(gum::LabelizedVariable("h_0", "h_0", 5)); 862 gum::NodeId h_t = dbn.add(gum::LabelizedVariable("h_t", "h_t", 5)); 863 gum::NodeId tf_0 = dbn.add(gum::LabelizedVariable("tf_0", "tf_0", 5)); 864 /*gum::NodeId tf_t =*/dbn.add(gum::LabelizedVariable("tf_t", "tf_t", 5)); 865 gum::NodeId wl_0 = dbn.add(gum::LabelizedVariable("wl_0", "wl_0", 4)); 866 gum::NodeId wl_t = dbn.add(gum::LabelizedVariable("wl_t", "wl_t", 4)); 867 868 for (auto n: {c_t, h_t, wl_t}) { 869 dbn.addArc(tf_0, n); 870 dbn.addArc(bf_0, n); 871 } 872 dbn.addArc(c_0, c_t); 873 dbn.addArc(h_0, h_t); 874 dbn.addArc(wl_0, wl_t); 875 876 gum::BayesNet< double > learn1; 877 { 878 // inductive learning leads to scrambled modalities 879 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv")); 880 learner.useScoreLog2Likelihood(); 881 learner.useAprioriSmoothing(1.0); 882 learn1 = learner.learnParameters(dbn.dag()); 883 } 884 gum::BayesNet< double > learn2; 885 { 886 try { 887 /* 888 gum::NodeProperty< gum::Sequence< std::string > > modals; 889 auto ds = std::vector< unsigned int >({4, 4, 5, 5, 5, 5, 5, 5, 4, 4}); 890 auto labels = std::vector< std::string >({"0", "1", "2", "3", "4", "5"}); 891 892 for (auto i = 0U; i < ds.size(); i++) { 893 modals.insert(i, gum::Sequence< std::string >()); 894 895 for (auto k = 0U; k < ds[i]; k++) 896 modals[i].insert(labels[k]); 897 } 898 */ 899 900 // while explicit learning does the right thing 901 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"), 902 learn1); 903 learner.useScoreLog2Likelihood(); 904 learner.useAprioriSmoothing(1.0); 905 learn2 = learner.learnParameters(dbn.dag()); 906 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 907 } 908 gum::BayesNet< double > learn3; 909 { 910 // while explicit learning does the right thing 911 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"), dbn); 912 learner.useScoreLog2Likelihood(); 913 learner.useAprioriSmoothing(1.0); 914 learn3 = learner.learnParameters(dbn.dag()); 915 } 916 917 TS_ASSERT_EQUALS(learn1.variable(learn1.idFromName("wl_0")).toString(), "wl_0:Range([0,3])") 918 TS_ASSERT_EQUALS(learn2.variable(learn2.idFromName("wl_0")).toString(), "wl_0:Range([0,3])") 919 TS_ASSERT_EQUALS(learn2.variable(learn3.idFromName("wl_0")).toString(), "wl_0:Range([0,3])") 920 921 auto& p1 = learn1.cpt(learn1.idFromName("c_0")); 922 auto& p2 = learn2.cpt(learn2.idFromName("c_0")); 923 auto& p3 = learn3.cpt(learn3.idFromName("c_0")); 924 gum::Instantiation I1(p1), I2(p2), I3(p3); 925 926 for (I1.setFirst(), I2.setFirst(), I3.setFirst(); !I1.end(); I1.inc(), I2.inc(), I3.inc()) { 927 TS_ASSERT_EQUALS(I1.toString(), I2.toString()); // same modalities orders 928 TS_ASSERT_EQUALS(I1.toString(), I3.toString()); // same modalities orders 929 TS_ASSERT_EQUALS(p1[I1], p2[I2]); // same probabilities 930 TS_ASSERT_EQUALS(p1[I1], p3[I3]); // same probabilities 931 } 932 933 gum::BayesNet< double > learn4; 934 { 935 // inductive learning leads to scrambled modalities 936 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"), false); 937 learner.useScoreLog2Likelihood(); 938 learner.useAprioriSmoothing(1.0); 939 learn4 = learner.learnParameters(dbn.dag()); 940 } 941 TS_ASSERT_EQUALS(learn4.variable(learn1.idFromName("wl_0")).toString(), 942 "wl_0:Labelized(<0,1,2,3>)"); 943 } 944 945 test_asia_with_missing_values()946 void test_asia_with_missing_values() { 947 int nb = 0; 948 try { 949 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3-faulty.csv"), 950 true, 951 std::vector< std::string >{"BEURK"}); 952 learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7}); 953 learner.learnBN(); 954 } catch (gum::MissingValueInDatabase&) { nb = 1; } 955 956 TS_ASSERT_EQUALS(nb, 1) 957 } 958 test_BugDoumenc()959 void test_BugDoumenc() { 960 gum::BayesNet< double > templ; 961 std::vector< std::string > varBool{"S", 962 "DEP", 963 "TM", 964 "TE", 965 "TV", 966 "PSY", 967 "AL", 968 "PT", 969 "HYP", 970 "FRE", 971 "PC", 972 "C", 973 "MN", 974 "AM", 975 "PR", 976 "AR", 977 "DFM"}; // les vraibles booléennes du RB 978 979 std::vector< std::string > varTer{ 980 "NBC", 981 "MED", 982 "DEM", 983 "SP"}; // les variables pouvant prendre 3 valeurs possibles du RB 984 985 std::vector< std::string > varContinuous{"A", "ADL"}; // les variables continues du RB 986 987 988 std::vector< gum::NodeId > nodeList; // Liste des noeuds du RB 989 990 for (auto var: varBool) 991 nodeList.push_back(templ.add( 992 gum::LabelizedVariable(var, 993 var, 994 2))); // Ajout des variables booléennes à la liste des noeuds 995 996 for (auto var: varTer) 997 nodeList.push_back(templ.add( 998 gum::LabelizedVariable(var, 999 var, 1000 3))); // Ajout des variables ternaires à la liste des noeuds 1001 1002 gum::DiscretizedVariable< double > A("A", "A"); 1003 for (int i = 60; i <= 105; i += 5) { 1004 A.addTick(double(i)); 1005 } 1006 1007 gum::NodeId a_id = templ.add(A); 1008 nodeList.push_back(a_id); // Ajout de la variable Age allant de 60 1009 // à 100 ans à la liste des noeuds 1010 1011 // Ajout de la variable ADL allant de 0 à 6 à la liste des noeuds 1012 nodeList.push_back(templ.add(gum::RangeVariable("ADL", "ADL", 0, 6))); 1013 // Création du noeud central NRC (niveau de risque de chute) 1014 gum::LabelizedVariable NRC("NRC", "NRC", 0); 1015 1016 NRC.addLabel("faible"); 1017 NRC.addLabel("modere"); 1018 NRC.addLabel("eleve"); 1019 auto iNRC = templ.add(NRC); 1020 1021 // Création des arcs partant du noeud NRC vers les autres noeuds 1022 for (auto node: nodeList) { 1023 templ.addArc(iNRC, node); 1024 } 1025 1026 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/bugDoumenc.csv"), templ); 1027 learner.useScoreLog2Likelihood(); 1028 learner.useAprioriSmoothing(); 1029 auto bn = learner.learnParameters(templ.dag()); 1030 } 1031 test_BugDoumencWithInt()1032 void test_BugDoumencWithInt() { 1033 gum::BayesNet< double > templ; 1034 std::vector< std::string > varBool{"S", 1035 "DEP", 1036 "TM", 1037 "TE", 1038 "TV", 1039 "PSY", 1040 "AL", 1041 "PT", 1042 "HYP", 1043 "FRE", 1044 "PC", 1045 "C", 1046 "MN", 1047 "AM", 1048 "PR", 1049 "AR", 1050 "DFM"}; // les vraibles booléennes du RB 1051 1052 std::vector< std::string > varTer{ 1053 "NBC", 1054 "MED", 1055 "DEM", 1056 "SP"}; // les variables pouvant prendre 3 valeurs possibles du RB 1057 1058 std::vector< std::string > varContinuous{"A", "ADL"}; // les variables continues du RB 1059 1060 1061 std::vector< gum::NodeId > nodeList; // Liste des noeuds du RB 1062 1063 for (auto var: varBool) 1064 nodeList.push_back(templ.add( 1065 gum::LabelizedVariable(var, 1066 var, 1067 2))); // Ajout des variables booléennes à la liste des noeuds 1068 1069 for (auto var: varTer) 1070 nodeList.push_back(templ.add( 1071 gum::LabelizedVariable(var, 1072 var, 1073 3))); // Ajout des variables ternaires à la liste des noeuds 1074 1075 gum::DiscretizedVariable< int > A("A", "A"); 1076 for (int i = 60; i <= 105; i += 5) { 1077 A.addTick(i); 1078 } 1079 1080 nodeList.push_back(templ.add(A)); // Ajout de la variable Age allant de 60 1081 // à 100 ans à la liste des noeuds 1082 1083 // Ajout de la variable ADL allant de 0 à 6 à la liste des noeuds 1084 nodeList.push_back(templ.add(gum::RangeVariable("ADL", "ADL", 0, 6))); 1085 // Création du noeud central NRC (niveau de risque de chute) 1086 gum::LabelizedVariable NRC("NRC", "NRC", 0); 1087 1088 NRC.addLabel("faible"); 1089 NRC.addLabel("modere"); 1090 NRC.addLabel("eleve"); 1091 auto iNRC = templ.add(NRC); 1092 1093 // Création des arcs partant du noeud NRC vers les autres noeuds 1094 for (auto node: nodeList) { 1095 templ.addArc(iNRC, node); 1096 } 1097 1098 1099 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/bugDoumenc.csv"), templ); 1100 learner.useScoreLog2Likelihood(); 1101 learner.useAprioriSmoothing(); 1102 1103 auto bn = learner.learnParameters(templ.dag()); 1104 1105 const gum::DiscreteVariable& var_discr = bn.variable("A"); 1106 int good = 1; 1107 try { 1108 const gum::DiscretizedVariable< int >& xvar_discr 1109 = dynamic_cast< const gum::DiscretizedVariable< int >& >(var_discr); 1110 TS_ASSERT_EQUALS(xvar_discr.domainSize(), (gum::Size)9) 1111 TS_ASSERT_EQUALS(xvar_discr.label(0), "[60;65[") 1112 TS_ASSERT_EQUALS(xvar_discr.label(1), "[65;70[") 1113 TS_ASSERT_EQUALS(xvar_discr.label(8), "[100;105]") 1114 } catch (std::bad_cast&) { good = 0; } 1115 TS_ASSERT_EQUALS(good, 1) 1116 } 1117 test_setSliceOrderWithNames()1118 void test_setSliceOrderWithNames() { 1119 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1120 learner.setSliceOrder( 1121 {{"smoking", "lung_cancer"}, {"bronchitis", "visit_to_Asia"}, {"tuberculosis"}}); 1122 1123 1124 gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/asia3.csv")); 1125 TS_ASSERT_THROWS(learner2.setSliceOrder({{"smoking", "lung_cancer"}, 1126 {"bronchitis", "visit_to_Asia"}, 1127 {"smoking", "tuberculosis", "lung_cancer"}}), 1128 gum::DuplicateElement); 1129 1130 gum::learning::BNLearner< double > learner3(GET_RESSOURCES_PATH("csv/asia3.csv")); 1131 TS_ASSERT_THROWS( 1132 learner3.setSliceOrder( 1133 {{"smoking", "lung_cancer"}, {"bronchitis", "visit_to_Asia"}, {"CRUCRU"}}), 1134 gum::MissingVariableInDatabase); 1135 } 1136 test_dirichlet()1137 void test_dirichlet() { 1138 auto bn = gum::BayesNet< double >::fastPrototype("A->B<-C->D->E<-B"); 1139 1140 gum::learning::BNDatabaseGenerator< double > genere(bn); 1141 genere.setRandomVarOrder(); 1142 genere.drawSamples(2000); 1143 genere.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv")); 1144 1145 auto bn2 = gum::BayesNet< double >::fastPrototype("A->B->C->D->E"); 1146 gum::learning::BNDatabaseGenerator< double > genere2(bn2); 1147 genere2.drawSamples(100); 1148 genere2.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_database.csv")); 1149 1150 gum::learning::BNLearner< double > learner( 1151 GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"), 1152 bn); 1153 learner.useAprioriDirichlet(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"), 10); 1154 learner.useScoreAIC(); 1155 1156 try { 1157 gum::BayesNet< double > bn3 = learner.learnBN(); 1158 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 1159 } 1160 1161 test_dirichlet2()1162 void test_dirichlet2() { 1163 // read the learning database 1164 gum::learning::DBInitializerFromCSV<> initializer( 1165 GET_RESSOURCES_PATH("csv/db_dirichlet_learning.csv")); 1166 const auto& var_names = initializer.variableNames(); 1167 const std::size_t nb_vars = var_names.size(); 1168 1169 gum::learning::DBTranslatorSet<> translator_set; 1170 gum::learning::DBTranslator4LabelizedVariable<> translator; 1171 for (std::size_t i = 0; i < nb_vars; ++i) { 1172 translator_set.insertTranslator(translator, i); 1173 } 1174 1175 gum::learning::DatabaseTable<> database(translator_set); 1176 database.setVariableNames(initializer.variableNames()); 1177 initializer.fillDatabase(database); 1178 1179 1180 // read the apriori database 1181 gum::learning::DBInitializerFromCSV<> dirichlet_initializer( 1182 GET_RESSOURCES_PATH("csv/db_dirichlet_apriori.csv")); 1183 const auto& dirichlet_var_names = initializer.variableNames(); 1184 const std::size_t dirichlet_nb_vars = dirichlet_var_names.size(); 1185 1186 gum::learning::DBTranslatorSet<> dirichlet_translator_set; 1187 for (std::size_t i = 0; i < dirichlet_nb_vars; ++i) { 1188 dirichlet_translator_set.insertTranslator(translator, i); 1189 } 1190 1191 gum::learning::DatabaseTable<> dirichlet_database(dirichlet_translator_set); 1192 dirichlet_database.setVariableNames(dirichlet_initializer.variableNames()); 1193 dirichlet_initializer.fillDatabase(dirichlet_database); 1194 dirichlet_database.reorder(); 1195 1196 1197 // create the score and the apriori 1198 gum::learning::DBRowGeneratorSet<> dirichlet_genset; 1199 gum::learning::DBRowGeneratorParser<> dirichlet_parser(dirichlet_database.handler(), 1200 dirichlet_genset); 1201 gum::learning::AprioriDirichletFromDatabase<> apriori(dirichlet_database, dirichlet_parser); 1202 1203 gum::learning::DBRowGeneratorSet<> genset; 1204 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 1205 1206 std::vector< double > weights{0, 1.0, 5.0, 10.0, 1000.0, 7000.0, 100000.0}; 1207 1208 gum::learning::BNLearner< double > learner( 1209 GET_RESSOURCES_PATH("csv/db_dirichlet_learning.csv")); 1210 learner.useScoreBIC(); 1211 1212 for (const auto weight: weights) { 1213 apriori.setWeight(weight); 1214 gum::learning::ScoreBIC<> score(parser, apriori); 1215 1216 // finalize the learning algorithm 1217 gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintDAG > 1218 struct_constraint; 1219 1220 gum::learning::ParamEstimatorML<> estimator(parser, apriori, score.internalApriori()); 1221 1222 gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) > op_set( 1223 struct_constraint); 1224 1225 gum::learning::GraphChangesSelector4DiGraph< decltype(struct_constraint), decltype(op_set) > 1226 selector(score, struct_constraint, op_set); 1227 1228 gum::learning::GreedyHillClimbing search; 1229 1230 gum::BayesNet< double > bn = search.learnBN(selector, estimator); 1231 // std::cout << dag << std::endl; 1232 1233 1234 learner.useAprioriDirichlet(GET_RESSOURCES_PATH("csv/db_dirichlet_apriori.csv"), weight); 1235 1236 gum::BayesNet< double > xbn = learner.learnBN(); 1237 1238 TS_ASSERT_EQUALS(xbn.moralGraph(), bn.moralGraph()) 1239 } 1240 } 1241 test_EM()1242 void test_EM() { 1243 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/EM.csv"), 1244 true, 1245 std::vector< std::string >{"?"}); 1246 1247 TS_ASSERT(learner.hasMissingValues()) 1248 1249 gum::DAG dag; 1250 for (std::size_t i = std::size_t(0); i < learner.database().nbVariables(); ++i) { 1251 dag.addNodeWithId(gum::NodeId(i)); 1252 } 1253 dag.addArc(gum::NodeId(1), gum::NodeId(0)); 1254 dag.addArc(gum::NodeId(2), gum::NodeId(1)); 1255 dag.addArc(gum::NodeId(3), gum::NodeId(2)); 1256 1257 TS_ASSERT_THROWS(learner.learnParameters(dag), gum::MissingValueInDatabase) 1258 1259 learner.useEM(1e-3); 1260 learner.useAprioriSmoothing(); 1261 1262 TS_GUM_ASSERT_THROWS_NOTHING(learner.learnParameters(dag, false)); 1263 TS_GUM_ASSERT_THROWS_NOTHING(learner.nbrIterations()); 1264 } 1265 test_chi2()1266 void test_chi2() { 1267 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1268 1269 auto reschi2 = learner.chi2("smoking", "lung_cancer"); 1270 TS_ASSERT_DELTA(reschi2.first, 36.2256, 1e-4) 1271 TS_ASSERT_DELTA(reschi2.second, 0, 1e-4) 1272 1273 reschi2 = learner.chi2("smoking", "visit_to_Asia"); 1274 TS_ASSERT_DELTA(reschi2.first, 1.1257, 1e-4) 1275 TS_ASSERT_DELTA(reschi2.second, 0.2886, 1e-4) 1276 1277 reschi2 = learner.chi2("lung_cancer", "tuberculosis"); 1278 TS_ASSERT_DELTA(reschi2.first, 0.6297, 1e-4) 1279 TS_ASSERT_DELTA(reschi2.second, 0.4274, 1e-4) 1280 1281 reschi2 = learner.chi2("lung_cancer", "tuberculosis", {"tuberculos_or_cancer"}); 1282 TS_ASSERT_DELTA(reschi2.first, 58.0, 1e-4) 1283 TS_ASSERT_DELTA(reschi2.second, 0.0, 1e-4) 1284 1285 // see IndepTestChi2TestSuite::test_statistics 1286 gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/chi2.csv")); 1287 1288 auto stat = learner2.chi2("A", "C"); 1289 TS_ASSERT_DELTA(stat.first, 0.0007, 1e-3) 1290 TS_ASSERT_DELTA(stat.second, 0.978, 1e-3) 1291 1292 stat = learner2.chi2("A", "B"); 1293 TS_ASSERT_DELTA(stat.first, 21.4348, 1e-3) 1294 TS_ASSERT_DELTA(stat.second, 3.6e-6, TS_GUM_SMALL_ERROR) 1295 1296 stat = learner2.chi2("B", "A"); 1297 TS_ASSERT_DELTA(stat.first, 21.4348, 1e-3) 1298 TS_ASSERT_DELTA(stat.second, 3.6e-6, TS_GUM_SMALL_ERROR) 1299 1300 stat = learner2.chi2("B", "D"); 1301 TS_ASSERT_DELTA(stat.first, 0.903, 1e-3) 1302 TS_ASSERT_DELTA(stat.second, 0.341, 1e-3) 1303 1304 stat = learner2.chi2("A", "C", {"B"}); 1305 TS_ASSERT_DELTA(stat.first, 15.2205, 1e-3) 1306 TS_ASSERT_DELTA(stat.second, 0.0005, 1e-4) 1307 } 1308 test_G2()1309 void test_G2() { 1310 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1311 auto resg2 = learner.G2("smoking", "lung_cancer"); 1312 TS_ASSERT_DELTA(resg2.first, 43.0321, 1e-4) 1313 TS_ASSERT_DELTA(resg2.second, 0, 1e-4) 1314 1315 resg2 = learner.G2("smoking", "visit_to_Asia"); 1316 TS_ASSERT_DELTA(resg2.first, 1.1418, 1e-4) 1317 TS_ASSERT_DELTA(resg2.second, 0.2852, 1e-4) 1318 1319 resg2 = learner.G2("lung_cancer", "tuberculosis"); 1320 TS_ASSERT_DELTA(resg2.first, 1.2201, 1e-4) 1321 TS_ASSERT_DELTA(resg2.second, 0.2693, 1e-4) 1322 1323 resg2 = learner.G2("lung_cancer", "tuberculosis", {"tuberculos_or_cancer"}); 1324 TS_ASSERT_DELTA(resg2.first, 59.1386, 1e-4) 1325 TS_ASSERT_DELTA(resg2.second, 0.0, 1e-4) 1326 1327 // see IndepTestChi2TestSuite::test_statistics 1328 gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/chi2.csv")); 1329 1330 auto stat = learner2.G2("A", "C"); 1331 TS_ASSERT_DELTA(stat.first, 0.0007, 1e-3) 1332 TS_ASSERT_DELTA(stat.second, 0.978, 1e-4) 1333 1334 stat = learner2.G2("A", "B"); 1335 TS_ASSERT_DELTA(stat.first, 21.5846, 1e-3) 1336 TS_ASSERT_DELTA(stat.second, 3.6e-6, 1e-4) 1337 1338 stat = learner2.G2("B", "A"); 1339 TS_ASSERT_DELTA(stat.first, 21.5846, 1e-3) 1340 TS_ASSERT_DELTA(stat.second, 3.6e-6, 1e-4) 1341 1342 stat = learner2.G2("B", "D"); 1343 TS_ASSERT_DELTA(stat.first, 0.903, 1e-3) 1344 TS_ASSERT_DELTA(stat.second, 0.342, 1e-4) 1345 1346 stat = learner2.G2("A", "C", {"B"}); 1347 TS_ASSERT_DELTA(stat.first, 16.3470, 1e-3) 1348 TS_ASSERT_DELTA(stat.second, 0.0002, 1e-4) 1349 } 1350 test_cmpG2Chi2()1351 void test_cmpG2Chi2() { 1352 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/testXYbase.csv")); 1353 auto statChi2 = learner.chi2("X", "Y"); 1354 TS_ASSERT_DELTA(statChi2.first, 15.3389, 1e-3) 1355 TS_ASSERT_DELTA(statChi2.second, 0.01777843046460533, 1e-6) 1356 auto statG2 = learner.G2("X", "Y"); 1357 TS_ASSERT_DELTA(statG2.first, 16.6066, 1e-3) 1358 TS_ASSERT_DELTA(statG2.second, 0.0108433, 1e-6) 1359 } 1360 test_loglikelihood()1361 void test_loglikelihood() { 1362 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/chi2.csv")); 1363 TS_ASSERT_EQUALS(learner.nbRows(), (gum::Size)500) 1364 TS_ASSERT_EQUALS(learner.nbCols(), (gum::Size)4) 1365 1366 double siz = -1.0 * learner.database().size(); 1367 learner.useNoApriori(); 1368 1369 auto stat = learner.logLikelihood({"A"}) / siz; // LL=-N.H 1370 TS_ASSERT_DELTA(stat, 0.99943499, TS_GUM_SMALL_ERROR) 1371 stat = learner.logLikelihood({"B"}) / siz; // LL=-N.H 1372 TS_ASSERT_DELTA(stat, 0.9986032, TS_GUM_SMALL_ERROR) 1373 stat = learner.logLikelihood({std::string("A"), "B"}) / siz; // LL=-N.H 1374 TS_ASSERT_DELTA(stat, 1.9668973, TS_GUM_SMALL_ERROR) 1375 stat = learner.logLikelihood({std::string("A")}, {"B"}) / siz; // LL=-N.H 1376 TS_ASSERT_DELTA(stat, 1.9668973 - 0.9986032, TS_GUM_SMALL_ERROR) 1377 1378 stat = learner.logLikelihood({"C"}) / siz; // LL=-N.H 1379 TS_ASSERT_DELTA(stat, 0.99860302, TS_GUM_SMALL_ERROR) 1380 stat = learner.logLikelihood({"D"}) / siz; // LL=-N.H 1381 TS_ASSERT_DELTA(stat, 0.40217919, TS_GUM_SMALL_ERROR) 1382 stat = learner.logLikelihood({std::string("C"), "D"}) / siz; // LL=-N.H 1383 TS_ASSERT_DELTA(stat, 1.40077995, TS_GUM_SMALL_ERROR) 1384 stat = learner.logLikelihood({std::string("C")}, {"D"}) / siz; // LL=-N.H 1385 TS_ASSERT_DELTA(stat, 1.40077995 - 0.40217919, TS_GUM_SMALL_ERROR) 1386 } 1387 test_errorFromPyagrum()1388 void test_errorFromPyagrum() { 1389 try { 1390 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/sample_asia.csv")); 1391 learner.use3off2(); 1392 learner.useNMLCorrection(); 1393 auto ge3off2 = learner.learnMixedStructure(); 1394 } catch (gum::Exception& e) { GUM_SHOWERROR(e); } 1395 } 1396 test_PossibleEdges()1397 void test_PossibleEdges() { 1398 //[smoking , lung_cancer , bronchitis , visit_to_Asia , tuberculosis , 1399 // tuberculos_or_cancer , dyspnoea , positive_XraY] 1400 { 1401 // possible edges are not relevant 1402 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1403 learner.addPossibleEdge("visit_to_Asia", "lung_cancer"); 1404 learner.addPossibleEdge("visit_to_Asia", "smoking"); 1405 1406 gum::BayesNet< double > bn = learner.learnBN(); 1407 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)0) 1408 } 1409 1410 { 1411 // possible edges are relevant 1412 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1413 learner.addPossibleEdge("smoking", "lung_cancer"); 1414 learner.addPossibleEdge("bronchitis", "smoking"); 1415 1416 gum::BayesNet< double > bn = learner.learnBN(); 1417 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2) 1418 TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking"))) 1419 TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("smoking"))) 1420 } 1421 1422 { 1423 // possible edges are relevant 1424 // mixed with a forbidden arcs 1425 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1426 learner.addPossibleEdge("smoking", "lung_cancer"); 1427 learner.addPossibleEdge("bronchitis", "smoking"); 1428 learner.addForbiddenArc("smoking", "bronchitis"); 1429 1430 gum::BayesNet< double > bn = learner.learnBN(); 1431 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2) 1432 TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking"))) 1433 TS_ASSERT(bn.parents("smoking").contains(bn.idFromName("bronchitis"))) 1434 } 1435 1436 { 1437 // possible edges are relevant 1438 // mixed with a mandatory arcs 1439 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1440 learner.addPossibleEdge("smoking", "lung_cancer"); 1441 learner.addPossibleEdge("bronchitis", "smoking"); 1442 learner.addMandatoryArc("visit_to_Asia", "bronchitis"); 1443 1444 gum::BayesNet< double > bn = learner.learnBN(); 1445 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)3) 1446 TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking"))) 1447 TS_ASSERT(bn.parents("smoking").contains(bn.idFromName("bronchitis"))) 1448 TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("visit_to_Asia"))) 1449 } 1450 } 1451 test_PossibleEdgesTabu()1452 void test_PossibleEdgesTabu() { 1453 //[smoking , lung_cancer , bronchitis , visit_to_Asia , tuberculosis , 1454 // tuberculos_or_cancer , dyspnoea , positive_XraY] 1455 { 1456 // possible edges are not relevant 1457 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1458 learner.useLocalSearchWithTabuList(); 1459 learner.addPossibleEdge("visit_to_Asia", "lung_cancer"); 1460 learner.addPossibleEdge("visit_to_Asia", "smoking"); 1461 1462 gum::BayesNet< double > bn = learner.learnBN(); 1463 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)0) 1464 } 1465 1466 { 1467 // possible edges are relevant 1468 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1469 learner.useLocalSearchWithTabuList(); 1470 learner.addPossibleEdge("smoking", "lung_cancer"); 1471 learner.addPossibleEdge("bronchitis", "smoking"); 1472 1473 gum::BayesNet< double > bn = learner.learnBN(); 1474 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2) 1475 TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking"))) 1476 TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("smoking"))) 1477 } 1478 1479 { 1480 // possible edges are relevant 1481 // mixed with a forbidden arcs 1482 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1483 learner.useLocalSearchWithTabuList(); 1484 learner.addPossibleEdge("smoking", "lung_cancer"); 1485 learner.addPossibleEdge("bronchitis", "smoking"); 1486 learner.addForbiddenArc("smoking", "bronchitis"); 1487 1488 gum::BayesNet< double > bn = learner.learnBN(); 1489 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2) 1490 TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking"))) 1491 TS_ASSERT(bn.parents("smoking").contains(bn.idFromName("bronchitis"))) 1492 } 1493 1494 { 1495 // possible edges are relevant 1496 // mixed with a mandatory arcs 1497 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv")); 1498 learner.useLocalSearchWithTabuList(); 1499 learner.addPossibleEdge("smoking", "lung_cancer"); 1500 learner.addPossibleEdge("bronchitis", "smoking"); 1501 learner.addMandatoryArc("visit_to_Asia", "bronchitis"); 1502 1503 gum::BayesNet< double > bn = learner.learnBN(); 1504 TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)3) 1505 TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking"))) 1506 TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("smoking"))) 1507 TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("visit_to_Asia"))) 1508 } 1509 } 1510 testPseudoCount()1511 void testPseudoCount() { 1512 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/minimal.csv")); 1513 TS_ASSERT_EQUALS(learner.domainSize(0), 2u) 1514 TS_ASSERT_EQUALS(learner.domainSize("X"), 2u) 1515 TS_ASSERT_EQUALS(learner.domainSize(1), 2u) 1516 TS_ASSERT_EQUALS(learner.domainSize("Y"), 2u) 1517 TS_ASSERT_EQUALS(learner.domainSize(2), 3u) 1518 TS_ASSERT_EQUALS(learner.domainSize("Z"), 3u) 1519 learner.useNoApriori(); 1520 1521 TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< gum::NodeId >({0})), 1522 std::vector< double >({3, 4})); 1523 TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< gum::NodeId >({0, 2})), 1524 std::vector< double >({2, 1, 1, 1, 0, 2})); 1525 TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< gum::NodeId >({2, 0})), 1526 std::vector< double >({2, 1, 0, 1, 1, 2})); 1527 1528 TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< std::string >({"X"})), 1529 std::vector< double >({3, 4})); 1530 TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< std::string >({"X", "Z"})), 1531 std::vector< double >({2, 1, 1, 1, 0, 2})); 1532 TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< std::string >({"Z", "X"})), 1533 std::vector< double >({2, 1, 0, 1, 1, 2})); 1534 } 1535 testNonRegressionZeroCount()1536 void testNonRegressionZeroCount() { 1537 ////////////////////////// 1538 // without specific score 1539 auto templ12 = gum::BayesNet< double >::fastPrototype("smoking->lung_cancer"); 1540 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"), templ12); 1541 auto bn = learner.learnParameters(templ12.dag()); 1542 1543 gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/asia.csv"), templ12); 1544 auto bn2 = learner2.learnParameters(templ12.dag()); 1545 TS_ASSERT_EQUALS(bn.cpt("lung_cancer").toString(), bn2.cpt("lung_cancer").toString()) 1546 1547 ////////////////////////// 1548 // with score AIC 1549 auto templ34 = gum::BayesNet< double >::fastPrototype("smoking[3]->lung_cancer[3]"); 1550 gum::learning::BNLearner< double > learner3(GET_RESSOURCES_PATH("csv/asia.csv"), templ34); 1551 learner3.useScoreAIC(); 1552 learner3.useAprioriSmoothing(1e-6); 1553 1554 auto bn3 = learner3.learnParameters(templ34.dag()); 1555 { 1556 const gum::Potential< double >& p = bn.cpt("lung_cancer"); 1557 const gum::Potential< double >& q = bn3.cpt("lung_cancer"); 1558 1559 auto I = gum::Instantiation(p); 1560 auto J = gum::Instantiation(q); 1561 1562 TS_ASSERT_DELTA(p[I], q[J], 1e-6) 1563 ++I; 1564 ++J; 1565 TS_ASSERT_DELTA(p[I], q[J], 1e-6) 1566 ++J; 1567 TS_ASSERT_DELTA(0.0, q[J], 1e-6) 1568 ++I; 1569 ++J; 1570 TS_ASSERT_DELTA(p[I], q[J], 1e-6) 1571 ++I; 1572 ++J; 1573 TS_ASSERT_DELTA(p[I], q[J], 1e-6) 1574 ++J; 1575 TS_ASSERT_DELTA(0.0, q[J], 1e-6) 1576 ++J; 1577 TS_ASSERT_DELTA(1.0 / 3.0, q[J], 1e-6) 1578 ++J; 1579 TS_ASSERT_DELTA(1.0 / 3.0, q[J], 1e-6) 1580 ++J; 1581 TS_ASSERT_DELTA(1.0 / 3.0, q[J], 1e-6) 1582 } 1583 1584 auto templ35 = gum::BayesNet< double >::fastPrototype("smoking[3]->lung_cancer[3]"); 1585 gum::learning::BNLearner< double > learner4(GET_RESSOURCES_PATH("csv/asia.csv"), templ35); 1586 learner4.useScoreAIC(); 1587 1588 TS_ASSERT_THROWS(learner4.learnParameters(templ34.dag()), gum::DatabaseError) 1589 } 1590 test_misorientation_MIIC()1591 void test_misorientation_MIIC() { 1592 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/renewal.csv")); 1593 1594 learner.useMIIC(); 1595 learner.useNMLCorrection(); 1596 1597 auto bn = learner.learnBN(); 1598 auto expected_arcs = std::vector< std::pair< std::string, std::string > >( 1599 {{"coupon", "loyalty"}, 1600 {"coupon", "recent visit"}, 1601 {"loyalty", "renewal"}, 1602 {"loyalty", "recent visit"}, 1603 {"corporate customer", "loyalty"}, 1604 {"corporate customer", "yearly consumption"}, 1605 {"yearly consumption", "loyalty"}, 1606 {"yearly consumption", "coupon"}}); 1607 for (auto a: expected_arcs) { 1608 TS_ASSERT(bn.existsArc(a.first, a.second)) 1609 } 1610 } 1611 testState()1612 void testState() { 1613 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/renewal.csv")); 1614 { 1615 auto state = learner.state(); 1616 TS_ASSERT_EQUALS(state.size(), (gum::Size)8) 1617 TS_ASSERT_EQUALS(std::get< 0 >(state[0]), "Filename") 1618 TS_ASSERT_EQUALS(std::get< 1 >(state[0]), GET_RESSOURCES_PATH("csv/renewal.csv")) 1619 1620 TS_ASSERT_EQUALS(std::get< 0 >(state[1]), "Size") 1621 TS_ASSERT_EQUALS(std::get< 1 >(state[1]), "(50000,6)") 1622 1623 TS_ASSERT_EQUALS(std::get< 0 >(state[2]), "Variables") 1624 TS_ASSERT_EQUALS(std::get< 1 >(state[2]), 1625 "loyalty[2], renewal[2], yearly consumption[5], corporate customer[2], " 1626 "coupon[2], recent visit[2]") 1627 1628 TS_ASSERT_EQUALS(std::get< 0 >(state[3]), "Induced types") 1629 TS_ASSERT_EQUALS(std::get< 1 >(state[3]), "True") 1630 1631 TS_ASSERT_EQUALS(std::get< 0 >(state[4]), "Missing values") 1632 TS_ASSERT_EQUALS(std::get< 1 >(state[4]), "False") 1633 1634 TS_ASSERT_EQUALS(std::get< 0 >(state[5]), "Algorithm") 1635 TS_ASSERT_EQUALS(std::get< 1 >(state[5]), "Greedy Hill Climbing") 1636 1637 TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "Score") 1638 TS_ASSERT_EQUALS(std::get< 1 >(state[6]), "BDeu") 1639 1640 TS_ASSERT_EQUALS(std::get< 0 >(state[7]), "Prior") 1641 TS_ASSERT_EQUALS(std::get< 1 >(state[7]), "-") 1642 } 1643 1644 learner.useMIIC(); 1645 learner.useNMLCorrection(); 1646 1647 { 1648 auto state = learner.state(); 1649 TS_ASSERT_EQUALS(state.size(), (gum::Size)8) 1650 TS_ASSERT_EQUALS(std::get< 0 >(state[0]), "Filename") 1651 TS_ASSERT_EQUALS(std::get< 1 >(state[0]), GET_RESSOURCES_PATH("csv/renewal.csv")) 1652 1653 TS_ASSERT_EQUALS(std::get< 0 >(state[1]), "Size") 1654 TS_ASSERT_EQUALS(std::get< 1 >(state[1]), "(50000,6)") 1655 1656 TS_ASSERT_EQUALS(std::get< 0 >(state[2]), "Variables") 1657 TS_ASSERT_EQUALS(std::get< 1 >(state[2]), 1658 "loyalty[2], renewal[2], yearly consumption[5], corporate customer[2], " 1659 "coupon[2], recent visit[2]") 1660 1661 1662 TS_ASSERT_EQUALS(std::get< 0 >(state[3]), "Induced types") 1663 TS_ASSERT_EQUALS(std::get< 1 >(state[3]), "True") 1664 1665 TS_ASSERT_EQUALS(std::get< 0 >(state[4]), "Missing values") 1666 TS_ASSERT_EQUALS(std::get< 1 >(state[4]), "False") 1667 1668 TS_ASSERT_EQUALS(std::get< 0 >(state[5]), "Algorithm") 1669 TS_ASSERT_EQUALS(std::get< 1 >(state[5]), "MIIC") 1670 1671 TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "Correction") 1672 TS_ASSERT_EQUALS(std::get< 1 >(state[6]), "NML") 1673 1674 TS_ASSERT_EQUALS(std::get< 0 >(state[7]), "Prior") 1675 TS_ASSERT_EQUALS(std::get< 1 >(state[7]), "-") 1676 } 1677 1678 learner.addPossibleEdge("loyalty", "renewal"); 1679 learner.setSliceOrder({{"loyalty", "renewal"}, {"recent visit", "corporate customer"}}); 1680 { 1681 auto state = learner.state(); 1682 TS_ASSERT_EQUALS(state.size(), (gum::Size)10) 1683 TS_ASSERT_EQUALS(std::get< 0 >(state[0]), "Filename") 1684 TS_ASSERT_EQUALS(std::get< 1 >(state[0]), GET_RESSOURCES_PATH("csv/renewal.csv")) 1685 1686 TS_ASSERT_EQUALS(std::get< 0 >(state[1]), "Size") 1687 TS_ASSERT_EQUALS(std::get< 1 >(state[1]), "(50000,6)") 1688 1689 TS_ASSERT_EQUALS(std::get< 0 >(state[2]), "Variables") 1690 TS_ASSERT_EQUALS(std::get< 1 >(state[2]), 1691 "loyalty[2], renewal[2], yearly consumption[5], corporate customer[2], " 1692 "coupon[2], recent visit[2]") 1693 1694 1695 TS_ASSERT_EQUALS(std::get< 0 >(state[3]), "Induced types") 1696 TS_ASSERT_EQUALS(std::get< 1 >(state[3]), "True") 1697 1698 TS_ASSERT_EQUALS(std::get< 0 >(state[4]), "Missing values") 1699 TS_ASSERT_EQUALS(std::get< 1 >(state[4]), "False") 1700 1701 TS_ASSERT_EQUALS(std::get< 0 >(state[5]), "Algorithm") 1702 TS_ASSERT_EQUALS(std::get< 1 >(state[5]), "MIIC") 1703 1704 TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "Correction") 1705 TS_ASSERT_EQUALS(std::get< 1 >(state[6]), "NML") 1706 1707 TS_ASSERT_EQUALS(std::get< 0 >(state[7]), "Prior") 1708 TS_ASSERT_EQUALS(std::get< 1 >(state[7]), "-") 1709 1710 TS_ASSERT_EQUALS(std::get< 0 >(state[8]), "Constraint Possible Edges") 1711 TS_ASSERT_EQUALS(std::get< 1 >(state[8]), "{loyalty--renewal}") 1712 1713 TS_ASSERT_EQUALS(std::get< 0 >(state[9]), "Constraint Slice Order") 1714 TS_ASSERT_EQUALS(std::get< 1 >(state[9]), 1715 "{corporate customer:1, renewal:0, loyalty:0, recent visit:1}") 1716 } 1717 1718 gum::DAG dag; 1719 dag.addNodes(learner.nbCols()); 1720 dag.addArc(0, 1); 1721 learner.setInitialDAG(dag); 1722 { 1723 auto state = learner.state(); 1724 TS_ASSERT_EQUALS(state.size(), (gum::Size)11) 1725 TS_ASSERT_EQUALS(std::get< 0 >(state[10]), "Initial DAG") 1726 TS_ASSERT_EQUALS(std::get< 1 >(state[10]), "True") 1727 } 1728 } 1729 testStateContinued()1730 void testStateContinued() { 1731 gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/renewal.csv")); 1732 learner.setDatabaseWeight(1000); 1733 learner.useK2(std::vector< gum::NodeId >{5, 4, 3, 2, 1, 0}); 1734 { 1735 auto state = learner.state(); 1736 TS_ASSERT_EQUALS(state.size(), (gum::Size)10) 1737 TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "K2 order") 1738 TS_ASSERT_EQUALS( 1739 std::get< 1 >(state[6]), 1740 "recent visit, coupon, corporate customer, yearly consumption, renewal, loyalty") 1741 1742 TS_ASSERT_EQUALS(std::get< 0 >(state[9]), "Database weight") 1743 TS_ASSERT_EQUALS(std::get< 1 >(state[9]), "1000.000000") 1744 } 1745 learner.useScoreAIC(); 1746 learner.useAprioriBDeu(); 1747 { 1748 auto state = learner.state(); 1749 TS_ASSERT_EQUALS(state.size(), (gum::Size)11) 1750 TS_ASSERT_EQUALS(std::get< 0 >(state[8]), "Prior") 1751 TS_ASSERT_DIFFERS(std::get< 2 >(state[8]), "") // there is a comment about AIC versus BDeu 1752 } 1753 } 1754 }; // class BNLearnerTestSuite 1755 } /* namespace gum_tests */ 1756