1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 // floating point env
23 #include <cfenv>
24 #include <vector>
25 #include <string>
26 #include <iostream>
27 
28 #include <gumtest/AgrumTestSuite.h>
29 #include <gumtest/testsuite_utils.h>
30 
31 #include <agrum/BN/learning/BNLearner.h>
32 
33 #include <agrum/tools/core/approximations/approximationSchemeListener.h>
34 
35 #include <agrum/BN/database/BNDatabaseGenerator.h>
36 
37 namespace gum_tests {
38 
39   class aSimpleBNLeanerListener: public gum::ApproximationSchemeListener {
40     private:
41     gum::Size   _nbr_;
42     std::string _mess_;
43 
44     public:
aSimpleBNLeanerListener(gum::IApproximationSchemeConfiguration & sch)45     aSimpleBNLeanerListener(gum::IApproximationSchemeConfiguration& sch) :
46         gum::ApproximationSchemeListener(sch), _nbr_(0), _mess_(""){};
47 
whenProgress(const void * buffer,const gum::Size a,const double b,const double c)48     void whenProgress(const void* buffer, const gum::Size a, const double b, const double c) {
49       _nbr_++;
50     }
51 
whenStop(const void * buffer,const std::string s)52     void whenStop(const void* buffer, const std::string s) { _mess_ = s; }
53 
getNbr()54     gum::Size getNbr() { return _nbr_; }
55 
getMess()56     std::string getMess() { return _mess_; }
57   };
58 
59   class BNLearnerTestSuite: public CxxTest::TestSuite {
60     public:
test_asia()61     void test_asia() {
62       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
63 
64       learner.useLocalSearchWithTabuList(100, 1);
65       learner.setMaxIndegree(10);
66       learner.useScoreLog2Likelihood();
67 
68       TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBD())
69       TS_ASSERT_DIFFERS("", learner.checkScoreAprioriCompatibility())
70       TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBDeu())
71       TS_ASSERT_EQUALS("", learner.checkScoreAprioriCompatibility())
72       learner.useScoreLog2Likelihood();
73 
74       learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
75       // learner.addForbiddenArc ( gum::Arc (4,3) );
76       // learner.addForbiddenArc ( gum::Arc (5,1) );
77       // learner.addForbiddenArc ( gum::Arc (5,7) );
78 
79       // learner.addMandatoryArc ( gum::Arc ( learner.nodeId ( "bronchitis" ),
80       //                                      learner.nodeId ( "lung_cancer" )
81       //                                      ) );
82 
83       learner.addMandatoryArc("bronchitis", "lung_cancer");
84 
85       learner.useAprioriSmoothing();
86       // learner.useAprioriDirichlet (  GET_RESSOURCES_PATH( "asia.csv" ) );
87 
88       gum::NodeProperty< gum::Size > slice_order{std::make_pair(gum::NodeId(0), (gum::Size)1),
89                                                  std::make_pair(gum::NodeId(3), (gum::Size)0),
90                                                  std::make_pair(gum::NodeId(1), (gum::Size)0)};
91       learner.setSliceOrder(slice_order);
92 
93       const std::vector< std::string >& names = learner.names();
94       TS_ASSERT(!names.empty())
95 
96       try {
97         gum::BayesNet< double > bn = learner.learnBN();
98         TS_ASSERT_EQUALS(bn.dag().arcs().size(), (gum::Size)9)
99       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
100 
101       learner.setDatabaseWeight(10.0);
102       const auto&  db     = learner.database();
103       const double weight = 10.0 / double(db.nbRows());
104       for (const auto& row: db) {
105         TS_ASSERT_EQUALS(row.weight(), weight)
106       }
107       TS_ASSERT_DELTA(learner.databaseWeight(), 10.0, 1e-4)
108 
109       const std::size_t nbr = db.nbRows();
110       for (std::size_t i = std::size_t(0); i < nbr; ++i) {
111         if (i % 2) learner.setRecordWeight(i, 2.0);
112       }
113 
114       std::size_t index = std::size_t(0);
115       for (const auto& row: db) {
116         if (index % 2) {
117           TS_ASSERT_EQUALS(row.weight(), 2.0)
118           TS_ASSERT_EQUALS(learner.recordWeight(index), 2.0)
119         } else {
120           TS_ASSERT_EQUALS(row.weight(), 10.0)
121           TS_ASSERT_EQUALS(learner.recordWeight(index), 10.0)
122         }
123         ++index;
124       }
125     }
126 
test_induceTypes()127     void test_induceTypes() {
128       {
129         gum::learning::BNLearner< double > learner1(GET_RESSOURCES_PATH("csv/asia.csv"));
130         learner1.useScoreBDeu();
131         learner1.useNoApriori();
132         gum::BayesNet< double > bn1 = learner1.learnBN();
133 
134         gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/asia.csv"),
135                                                     true,
136                                                     {"?"});
137         for (const auto trans2: learner2.database().translatorSet().translators()) {
138           const auto& var2 = trans2->variable();
139           TS_ASSERT((var2->varType() == gum::VarType::Range)
140                     || (var2->varType() == gum::VarType::Integer));
141         }
142 
143         learner2.useScoreBDeu();
144         learner2.useNoApriori();
145         gum::BayesNet< double > bn2 = learner2.learnBN();
146 
147         TS_ASSERT_EQUALS(bn1.dag(), bn2.dag())
148       }
149 
150       {
151         gum::learning::BNLearner< double > learner1(GET_RESSOURCES_PATH("csv/alarm.csv"));
152         learner1.useScoreBDeu();
153         learner1.useNoApriori();
154         gum::BayesNet< double > bn1 = learner1.learnBN();
155 
156         gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/alarm.csv"),
157                                                     true,
158                                                     {"?"});
159 
160         for (const auto trans2: learner2.database().translatorSet().translators()) {
161           const auto& var2 = trans2->variable();
162           TS_ASSERT((var2->varType() == gum::VarType::Range)
163                     || (var2->varType() == gum::VarType::Integer));
164         }
165 
166         learner2.useScoreBDeu();
167         learner2.useNoApriori();
168         gum::BayesNet< double > bn2 = learner2.learnBN();
169 
170         TS_ASSERT_EQUALS(bn1.dag(), bn2.dag())
171       }
172 
173       {
174         auto bn = gum::BayesNet< double >::fastPrototype("A->B<-C->D->E<-B");
175         gum::learning::BNDatabaseGenerator< double > genere(bn);
176         genere.setRandomVarOrder();
177         genere.drawSamples(2000);
178         genere.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"));
179 
180         auto bn2 = gum::BayesNet< double >::fastPrototype("A->B->C->D->E");
181         gum::learning::BNDatabaseGenerator< double > genere2(bn2);
182         genere2.drawSamples(100);
183         genere2.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"));
184 
185         gum::learning::BNLearner< double > learner1(
186            GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"));
187         learner1.useAprioriDirichlet(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"), 10);
188         learner1.useScoreAIC();
189         gum::BayesNet< double > xbn1 = learner1.learnBN();
190 
191         gum::learning::BNLearner< double > learner2(
192            GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"),
193            true,
194            {"?"});
195 
196         for (const auto trans2: learner2.database().translatorSet().translators()) {
197           const auto& var2 = trans2->variable();
198           TS_ASSERT((var2->varType() == gum::VarType::Range)
199                     || (var2->varType() == gum::VarType::Integer));
200         }
201 
202         learner2.useAprioriDirichlet(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"), 10);
203         learner2.useScoreAIC();
204         gum::BayesNet< double > xbn2 = learner2.learnBN();
205 
206         TS_ASSERT_EQUALS(xbn1.dag(), xbn2.dag())
207       }
208     }
209 
210 
test_guill()211     void test_guill() {
212       try {
213         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3_withguill.csv"));
214         TS_FAIL("asia3_withguill.csv contains syntax error (with \").");
215       } catch (gum::SyntaxError& e) {};
216     }
217 
test_ranges()218     void test_ranges() {
219       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
220 
221       learner.useGreedyHillClimbing();
222       learner.useScoreBIC();
223       learner.useAprioriSmoothing();
224 
225       const std::size_t k        = 5;
226       const auto&       database = learner.database();
227       const std::size_t dbsize   = database.nbRows();
228       std::size_t       foldSize = dbsize / k;
229 
230       gum::learning::DBRowGeneratorSet<>    genset;
231       gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
232       gum::learning::AprioriSmoothing<>     apriori(database);
233       apriori.setWeight(1);
234 
235       gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintDAG >
236          struct_constraint;
237 
238       gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) > op_set(
239          struct_constraint);
240 
241       gum::learning::GreedyHillClimbing search;
242 
243       gum::learning::ScoreBIC<>         score(parser, apriori);
244       gum::learning::ParamEstimatorML<> estimator(parser, apriori, score.internalApriori());
245       for (std::size_t fold = 0; fold < k; fold++) {
246         // create the ranges of rows over which we perform the learning
247         const std::size_t unfold_deb = fold * foldSize;
248         const std::size_t unfold_end = unfold_deb + foldSize;
249 
250         std::vector< std::pair< std::size_t, std::size_t > > ranges;
251         if (fold == std::size_t(0)) {
252           ranges.push_back(std::pair< std::size_t, std::size_t >(unfold_end, dbsize));
253         } else {
254           ranges.push_back(std::pair< std::size_t, std::size_t >(std::size_t(0), unfold_deb));
255 
256           if (fold != k - 1) {
257             ranges.push_back(std::pair< std::size_t, std::size_t >(unfold_end, dbsize));
258           }
259         }
260 
261         learner.useDatabaseRanges(ranges);
262         TS_ASSERT_EQUALS(learner.databaseRanges(), ranges)
263 
264         learner.clearDatabaseRanges();
265         TS_ASSERT_DIFFERS(learner.databaseRanges(), ranges)
266 
267         learner.useCrossValidationFold(fold, k);
268         TS_ASSERT_EQUALS(learner.databaseRanges(), ranges)
269 
270         gum::BayesNet< double > bn1 = learner.learnBN();
271 
272 
273         score.setRanges(ranges);
274         estimator.setRanges(ranges);
275         gum::learning::GraphChangesSelector4DiGraph< decltype(struct_constraint), decltype(op_set) >
276                                 selector(score, struct_constraint, op_set);
277         gum::BayesNet< double > bn2 = search.learnBN< double >(selector, estimator);
278 
279         TS_ASSERT_EQUALS(bn1.dag(), bn2.dag())
280 
281         gum::Instantiation I1, I2;
282 
283         for (auto& name: database.variableNames()) {
284           I1.add(bn1.variableFromName(name));
285           I2.add(bn2.variableFromName(name));
286         }
287 
288         double            LL1 = 0.0, LL2 = 0.0;
289         const std::size_t nbCol = database.nbVariables();
290         parser.setRange(unfold_deb, unfold_end);
291         while (parser.hasRows()) {
292           const gum::learning::DBRow< gum::learning::DBTranslatedValue >& row = parser.row();
293           for (std::size_t i = 0; i < nbCol; ++i) {
294             I1.chgVal(i, row[i].discr_val);
295             I2.chgVal(i, row[i].discr_val);
296           }
297 
298           LL1 += bn1.log2JointProbability(I1) * row.weight();
299           LL2 += bn2.log2JointProbability(I2) * row.weight();
300         }
301 
302         TS_ASSERT_EQUALS(LL1, LL2)
303       }
304     }
305 
306 
test_asia_3off2()307     void test_asia_3off2() {
308       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"));
309 
310       aSimpleBNLeanerListener listen(learner);
311 
312       learner.useGreedyHillClimbing();
313 
314       learner.use3off2();
315       learner.useNMLCorrection();
316       learner.addForbiddenArc(gum::Arc(4, 1));
317       // learner.addForbiddenArc ( gum::Arc (5,1) );
318       // learner.addForbiddenArc ( gum::Arc (5,7) );
319 
320       learner.addMandatoryArc(gum::Arc(7, 5));
321       gum::DAG i_dag;
322       for (gum::NodeId i = 0; i < 8; ++i) {
323         i_dag.addNodeWithId(i);
324       }
325       learner.setInitialDAG(i_dag);
326       // learner.addMandatoryArc( "bronchitis", "lung_cancer" );
327 
328       const std::vector< std::string >& names = learner.names();
329       TS_ASSERT(!names.empty())
330 
331       try {
332         gum::BayesNet< double > bn = learner.learnBN();
333         TS_ASSERT_EQUALS(bn.dag().arcs().size(), (gum::Size)9)
334         // TS_ASSERT_EQUALS(listen.getNbr(), (gum::Size)86)
335         TS_ASSERT(!bn.dag().existsArc(4, 1))
336         TS_ASSERT(bn.dag().existsArc(7, 5))
337 
338         gum::MixedGraph mg = learner.learnMixedStructure();
339         TS_ASSERT_EQUALS(mg.arcs().size(), (gum::Size)8)
340         TS_ASSERT_EQUALS(mg.edges().size(), (gum::Size)1)
341         TS_ASSERT(!mg.existsArc(4, 1))
342         TS_ASSERT(mg.existsArc(7, 5))
343         std::vector< gum::Arc > latents = learner.latentVariables();
344         TS_ASSERT_EQUALS(latents.size(), (gum::Size)2)
345       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
346     }
347 
// WARNING: this test is commented out on purpose: it needs a running database
// with a table filled with the content of the asia.csv file. You will also
// need a proper ODBC configuration (under Linux and macOS you will need
// unixODBC and the ODBC drivers specific to your database).
352     // void test_asia_db() {
353     //   try {
354     //     auto db = gum::learning::DatabaseFromSQL(
355     //         "PostgreSQL",
356     //         "lto",
357     //         "Password2Change",
358     //         "select smoking , lung_cancer , bronchitis , visit_to_asia , "
359     //         "tuberculosis , tuberculos_or_cancer , dyspnoea , positive_xray "
360     //         "from asia;" );
361     //     gum::learning::BNLearner<double> learner( db );
362 
363     //     learner.useLocalSearchWithTabuList( 100, 1 );
364     //     learner.setMaxIndegree( 10 );
365     //     learner.useScoreLog2Likelihood();
366 
367     //     TS_ASSERT_THROWS( learner.useScoreBD(), gum::IncompatibleScoreApriori
368     //     );
369     //     TS_GUM_ASSERT_THROWS_NOTHING( learner.useScoreBDeu() );
370     //     learner.useScoreLog2Likelihood();
371 
372     //     learner.useK2( std::vector<gum::NodeId>{1, 5, 2, 6, 0, 3, 4, 7} );
373     //     learner.addMandatoryArc( "bronchitis", "lung_cancer" );
374     //     learner.useAprioriSmoothing();
375 
376     //     gum::NodeProperty<unsigned int> slice_order{
377     //         std::make_pair( gum::NodeId( 0 ), 1 ),
378     //         std::make_pair( gum::NodeId( 3 ), 0 ),
379     //         std::make_pair( gum::NodeId( 1 ), 0 )};
380 
381     //     const std::vector<std::string>& names = learner.names();
382     //     TS_ASSERT( !names.empty() )
383 
384     //     try {
385     //       gum::BayesNet<double> bn = learner.learnBN();
386     //       TS_ASSERT_EQUALS( bn.dag().arcs().size() , 9 )
387     //     } catch ( gum::Exception& e ) {
388     //       GUM_SHOWERROR( e );
389     //     }
390     //   } catch ( gum::Exception& e ) {
391     //     GUM_TRACE( e.errorType() );
392     //     GUM_TRACE( e.errorContent() );
393     //     GUM_TRACE( e.errorCallStack() );
394     //     TS_FAIL("plop");
395     //   }
396     // }
397 
test_asia_with_domain_sizes()398     void test_asia_with_domain_sizes() {
399       gum::learning::BNLearner< double > learn(GET_RESSOURCES_PATH("csv/asia3.csv"));
400       const auto&                        database = learn.database();
401 
402       gum::BayesNet< double > bn;
403       for (auto& name: database.variableNames()) {
404         gum::LabelizedVariable var(name, name, {"false", "true", "big"});
405         bn.add(var);
406       }
407 
408       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn);
409       learner.useScoreBIC();
410       learner.useAprioriSmoothing();
411 
412       gum::BayesNet< double > bn2 = learner.learnBN();
413       for (auto& name: database.variableNames()) {
414         TS_ASSERT_EQUALS(bn2.variableFromName(name).domainSize(), (gum::Size)3)
415       }
416     }
417 
xtest_asia_with_user_modalities_string_min()418     void xtest_asia_with_user_modalities_string_min() {
419       gum::NodeProperty< gum::Sequence< std::string > > modals;
420       modals.insert(0, gum::Sequence< std::string >());
421       modals[0].insert("false");
422       modals[0].insert("true");
423       modals[0].insert("big");
424 
425       modals.insert(2, gum::Sequence< std::string >());
426       modals[2].insert("big");
427       modals[2].insert("bigbig");
428       modals[2].insert("true");
429       modals[2].insert("bigbigbig");
430       modals[2].insert("false");
431 
432       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
433       // GET_RESSOURCES_PATH("csv/asia3.csv"), modals, true);
434 
435       learner.useGreedyHillClimbing();
436       learner.setMaxIndegree(10);
437       learner.useScoreLog2Likelihood();
438 
439       TS_ASSERT_THROWS(learner.useScoreBD(), gum::IncompatibleScoreApriori)
440       TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBDeu());
441       learner.useScoreLog2Likelihood();
442 
443       learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
444       // learner.addForbiddenArc ( gum::Arc (4,3) );
445       // learner.addForbiddenArc ( gum::Arc (5,1) );
446       // learner.addForbiddenArc ( gum::Arc (5,7) );
447 
448       // learner.addMandatoryArc ( gum::Arc ( learner.nodeId ( "bronchitis" ),
449       //                                      learner.nodeId ( "lung_cancer" )
450       //                                      ) );
451 
452       learner.addMandatoryArc("bronchitis", "lung_cancer");
453 
454       learner.useAprioriSmoothing();
455       // learner.useAprioriDirichlet (  GET_RESSOURCES_PATH( "asia.csv" ) );
456 
457       gum::NodeProperty< gum::Size > slice_order{std::make_pair(gum::NodeId(0), (gum::Size)1),
458                                                  std::make_pair(gum::NodeId(3), (gum::Size)0),
459                                                  std::make_pair(gum::NodeId(1), (gum::Size)0)};
460       learner.setSliceOrder(slice_order);
461 
462       const std::vector< std::string >& names = learner.names();
463       TS_ASSERT(!names.empty())
464 
465       try {
466         gum::BayesNet< double > bn = learner.learnBN();
467         TS_ASSERT_EQUALS(bn.variable(0).domainSize(), (gum::Size)2)
468         TS_ASSERT_EQUALS(bn.variable(2).domainSize(), (gum::Size)2)
469         TS_ASSERT_EQUALS(bn.variable(0).label(0), "false")
470         TS_ASSERT_EQUALS(bn.variable(0).label(1), "true")
471       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
472     }
473 
xtest_asia_with_user_modalities_string_incorrect()474     void xtest_asia_with_user_modalities_string_incorrect() {
475       gum::NodeProperty< gum::Sequence< std::string > > modals;
476       modals.insert(0, gum::Sequence< std::string >());
477       modals[0].insert("False");
478       modals[0].insert("true");
479       modals[0].insert("big");
480 
481       modals.insert(2, gum::Sequence< std::string >());
482       modals[2].insert("big");
483       modals[2].insert("bigbig");
484       modals[2].insert("true");
485       modals[2].insert("bigbigbig");
486       modals[2].insert("false");
487 
488       bool except = false;
489 
490       try {
491         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
492         // GET_RESSOURCES_PATH("csv/asia3.csv"), modals);
493         learner.useAprioriSmoothing();
494       } catch (gum::UnknownLabelInDatabase&) { except = true; }
495 
496       TS_ASSERT(except)
497     }
498 
xtest_asia_with_user_modalities_numbers()499     void xtest_asia_with_user_modalities_numbers() {
500       gum::NodeProperty< gum::Sequence< std::string > > modals;
501       modals.insert(0, gum::Sequence< std::string >());
502       modals[0].insert("0");
503       modals[0].insert("1");
504       modals[0].insert("big");
505 
506       modals.insert(2, gum::Sequence< std::string >());
507       modals[2].insert("big");
508       modals[2].insert("bigbig");
509       modals[2].insert("1");
510       modals[2].insert("bigbigbig");
511       modals[2].insert("0");
512 
513       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"));
514       // learner(GET_RESSOURCES_PATH("csv/asia.csv"), modals);
515       learner.useGreedyHillClimbing();
516       learner.setMaxIndegree(10);
517       learner.useScoreLog2Likelihood();
518 
519       TS_ASSERT_THROWS(learner.useScoreBD(), gum::IncompatibleScoreApriori)
520       TS_GUM_ASSERT_THROWS_NOTHING(learner.useScoreBDeu());
521       learner.useScoreLog2Likelihood();
522 
523       learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
524       learner.addForbiddenArc(gum::Arc(4, 3));
525       learner.addForbiddenArc(gum::Arc(5, 1));
526       learner.addForbiddenArc(gum::Arc(5, 7));
527 
528       learner.addMandatoryArc("bronchitis", "lung_cancer");
529 
530       learner.useAprioriSmoothing();
531       // learner.useAprioriDirichlet (  GET_RESSOURCES_PATH( "asia.csv" ) );
532 
533       gum::NodeProperty< gum::Size > slice_order{std::make_pair(gum::NodeId(0), (gum::Size)1),
534                                                  std::make_pair(gum::NodeId(3), (gum::Size)0),
535                                                  std::make_pair(gum::NodeId(1), (gum::Size)0)};
536       learner.setSliceOrder(slice_order);
537 
538       const std::vector< std::string >& names = learner.names();
539       TS_ASSERT(!names.empty())
540 
541       try {
542         gum::BayesNet< double > bn = learner.learnBN();
543         TS_ASSERT_EQUALS(bn.variable(0).domainSize(), (gum::Size)3)
544         TS_ASSERT_EQUALS(bn.variable(2).domainSize(), (gum::Size)5)
545         TS_ASSERT_EQUALS(bn.variable(0).label(0), "0")
546         TS_ASSERT_EQUALS(bn.variable(0).label(1), "1")
547         TS_ASSERT_EQUALS(bn.variable(0).label(2), "big")
548       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
549     }
550 
xtest_asia_with_user_modalities_numbers_incorrect()551     void xtest_asia_with_user_modalities_numbers_incorrect() {
552       gum::NodeProperty< gum::Sequence< std::string > > modals;
553       modals.insert(0, gum::Sequence< std::string >());
554       modals[0].insert("1");
555       modals[0].insert("2");
556       modals[0].insert("big");
557 
558       modals.insert(2, gum::Sequence< std::string >());
559       modals[2].insert("big");
560       modals[2].insert("bigbig");
561       modals[2].insert("3");
562       modals[2].insert("bigbigbig");
563       modals[2].insert("0");
564 
565       bool except = false;
566 
567       try {
568         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"));
569         // learner(GET_RESSOURCES_PATH("csv/asia.csv"), modals);
570         learner.useAprioriSmoothing();
571       } catch (gum::UnknownLabelInDatabase&) { except = true; }
572 
573       TS_ASSERT(except)
574     }
575 
test_asia_param()576     void test_asia_param() {
577       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
578 
579       gum::DAG dag;
580 
581       for (unsigned int i = 0; i < 8; ++i) {
582         dag.addNodeWithId(i);
583       }
584 
585       for (unsigned int i = 0; i < 7; ++i) {
586         dag.addArc(i, i + 1);
587       }
588 
589       dag.addArc(0, 7);
590       dag.addArc(2, 4);
591       dag.addArc(5, 7);
592       dag.addArc(3, 6);
593 
594       learner.useNoApriori();
595 
596       try {
597         gum::BayesNet< double > bn = learner.learnParameters(dag);
598         TS_ASSERT_EQUALS(bn.dim(), (gum::Size)25)
599       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
600     }
601 
test_asia_param_from_bn()602     void test_asia_param_from_bn() {
603       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
604 
605       learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
606       gum::BayesNet< double > bn = learner.learnBN();
607 
608       try {
609         gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag());
610         TS_ASSERT_EQUALS(bn2.dag().arcs().size(), bn.dag().arcs().size())
611       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
612     }
613 
614 
test_asia_param_float()615     void test_asia_param_float() {
616       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
617 
618       gum::DAG dag;
619 
620       for (unsigned int i = 0; i < 8; ++i) {
621         dag.addNodeWithId(i);
622       }
623 
624       for (unsigned int i = 0; i < 7; ++i) {
625         dag.addArc(i, i + 1);
626       }
627 
628       dag.addArc(0, 7);
629       dag.addArc(2, 4);
630       dag.addArc(5, 7);
631       dag.addArc(3, 6);
632 
633       learner.useNoApriori();
634 
635       try {
636         gum::BayesNet< double > bn = learner.learnParameters(dag);
637         TS_ASSERT_EQUALS(bn.dim(), (gum::Size)25)
638       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
639     }
640 
test_asia_param_from_bn_float()641     void test_asia_param_from_bn_float() {
642       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
643 
644       learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
645       gum::BayesNet< double > bn = learner.learnBN();
646 
647       try {
648         gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag());
649         TS_ASSERT_EQUALS(bn2.dag().arcs().size(), bn.dag().arcs().size())
650       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
651     }
652 
test_asia_param_bn()653     void test_asia_param_bn() {
654 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true");
655       // smoking,lung_cancer,bronchitis,visit_to_Asia,tuberculosis,tuberculos_or_cancer,dyspnoea,positive_XraY
656       auto s = createBoolVar("smoking");
657       auto l = createBoolVar("lung_cancer");
658       auto b = createBoolVar("bronchitis");
659       auto v = createBoolVar("visit_to_Asia");
660       auto t = createBoolVar("tuberculosis");
661       auto o = createBoolVar("tuberculos_or_cancer");
662       auto d = createBoolVar("dyspnoea");
663       auto p = createBoolVar("positive_XraY");
664 #undef createBoolVar
665 
666       gum::BayesNet< double > bn;
667       gum::NodeId             ns = bn.add(s);
668       gum::NodeId             nl = bn.add(l);
669       gum::NodeId             nb = bn.add(b);
670       gum::NodeId             nv = bn.add(v);
671       gum::NodeId             nt = bn.add(t);
672       gum::NodeId             no = bn.add(o);
673       gum::NodeId             nd = bn.add(d);
674       gum::NodeId             np = bn.add(p);
675 
676       bn.addArc(ns, nl);
677       bn.addArc(ns, nb);
678       bn.addArc(nl, no);
679       bn.addArc(nb, nd);
680       bn.addArc(nv, nt);
681       bn.addArc(nt, no);
682       bn.addArc(no, nd);
683       bn.addArc(no, np);
684 
685       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn);
686 
687       learner.useScoreLog2Likelihood();
688       learner.useAprioriSmoothing();
689 
690       try {
691         gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag());
692         TS_ASSERT_EQUALS(bn2.dim(), bn.dim())
693 
694         for (gum::NodeId node: bn.nodes()) {
695           gum::NodeId node2 = bn2.idFromName(bn.variable(node).name());
696           TS_ASSERT_EQUALS(bn.variable(node).toString(), bn2.variable(node2).toString())
697         }
698 
699       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
700     }
701 
test_asia_param_bn_with_not_matching_variable()702     void test_asia_param_bn_with_not_matching_variable() {
703 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true");
704       auto s = createBoolVar("smoking");
705       auto l = createBoolVar("lung_cancer");
706       auto b = createBoolVar("bronchitis");
707       auto v = createBoolVar("visit_to_Asia");
708       auto t = createBoolVar("tuberculosis");
709       auto o = createBoolVar("tuberculos_or_cancer");
710       auto d = createBoolVar("dyspnoea");
711 
712       // uncorrect name is : will it be correctly handled
713       auto p = createBoolVar("ZORBLOBO");
714 #undef createBoolVar
715 
716       gum::BayesNet< double > bn;
717       bn.add(s);
718       bn.add(l);
719       bn.add(b);
720       bn.add(v);
721       bn.add(t);
722       bn.add(o);
723       bn.add(d);
724       bn.add(p);
725 
726       TS_ASSERT_THROWS(
727          gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn),
728          gum::MissingVariableInDatabase);
729 
730 
731       // learner.useScoreLog2Likelihood();
732       // learner.useAprioriSmoothing();
733 
734       // TS_ASSERT_THROWS(gum::BayesNet< double > bn2 =
735       // learner.learnParameters(bn),
736       //                 gum::MissingVariableInDatabase);
737     }
738 
test_asia_param_bn_with_subset_of_variables_in_base()739     void test_asia_param_bn_with_subset_of_variables_in_base() {
740 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true");
741       auto s = createBoolVar("smoking");
742       auto t = createBoolVar("tuberculosis");
743       auto o = createBoolVar("tuberculos_or_cancer");
744       auto d = createBoolVar("dyspnoea");
745 #undef createBoolVar
746 
747       gum::BayesNet< double > bn;
748       gum::NodeId             ns = bn.add(s);
749       gum::NodeId             nt = bn.add(t);
750       gum::NodeId             no = bn.add(o);
751       gum::NodeId             nd = bn.add(d);
752 
753       bn.addArc(ns, nt);
754       bn.addArc(nt, no);
755       bn.addArc(no, nd);
756 
757       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"), bn);
758 
759 
760       learner.useScoreLog2Likelihood();
761       learner.useAprioriSmoothing();
762 
763       gum::BayesNet< double > bn2 = learner.learnParameters(bn.dag());
764     }
765 
test_asia_param_bn_with_unknow_modality()766     void test_asia_param_bn_with_unknow_modality() {
767 #define createBoolVar(s) gum::LabelizedVariable(s, s, 0).addLabel("false").addLabel("true");
768       auto s = createBoolVar("smoking");
769       auto t = createBoolVar("tuberculosis");
770       auto o = createBoolVar("tuberculos_or_cancer");
771       auto d = createBoolVar("dyspnoea");
772 #undef createBoolVar
773 
774       gum::BayesNet< double > bn;
775       gum::NodeId             ns = bn.add(s);
776       gum::NodeId             nt = bn.add(t);
777       gum::NodeId             no = bn.add(o);
778       gum::NodeId             nd = bn.add(d);
779 
780       bn.addArc(ns, nt);
781       bn.addArc(nt, no);
782       bn.addArc(no, nd);
783 
784       // asia3-faulty contains a label "beurk" for variable "smoking"
785       // std::cout << "error test";
786 
787       TS_ASSERT_THROWS(
788          gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3-faulty.csv"),
789                                                     bn),
790          gum::UnknownLabelInDatabase);
791     }
792 
test_listener()793     void test_listener() {
794       {
795         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"));
796         aSimpleBNLeanerListener            listen(learner);
797 
798         learner.setVerbosity(true);
799         learner.setMaxIndegree(10);
800         learner.useScoreK2();
801         learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
802 
803         gum::BayesNet< double > bn = learner.learnBN();
804 
805         TS_ASSERT_EQUALS(listen.getNbr(), (gum::Size)2)
806         TS_ASSERT_EQUALS(listen.getMess(), "stopped on request")
807         TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request")
808       }
809       {
810         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia2.csv"));
811         aSimpleBNLeanerListener            listen(learner);
812 
813         learner.setVerbosity(true);
814         learner.setMaxIndegree(10);
815         learner.useScoreK2();
816         learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
817 
818         gum::BayesNet< double > bn = learner.learnBN();
819 
820         TS_ASSERT_EQUALS(listen.getNbr(), (gum::Size)3)
821         TS_ASSERT_EQUALS(listen.getMess(), "stopped on request")
822         TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request")
823       }
824       {
825         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"));
826         aSimpleBNLeanerListener            listen(learner);
827 
828         learner.setVerbosity(true);
829         learner.setMaxIndegree(2);
830         learner.useLocalSearchWithTabuList();
831 
832         gum::BayesNet< double > bn = learner.learnBN();
833         // std::cout << bn.dag () << std::endl;
834 
835         TS_ASSERT_DELTA(listen.getNbr(), (gum::Size)15, 1);   // 75
836         TS_ASSERT_EQUALS(listen.getMess(), "stopped on request")
837         TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request")
838       }
839       {
840         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"));
841         aSimpleBNLeanerListener            listen(learner);
842 
843         learner.setVerbosity(true);
844         learner.setMaxIndegree(2);
845         learner.useGreedyHillClimbing();
846 
847         gum::BayesNet< double > bn = learner.learnBN();
848 
849         TS_ASSERT_DELTA(listen.getNbr(), (gum::Size)3, 1);   // 2?
850         TS_ASSERT_EQUALS(listen.getMess(), "stopped on request")
851         TS_ASSERT_EQUALS(learner.messageApproximationScheme(), "stopped on request")
852       }
853     }
854 
test_DBNTonda()855     void test_DBNTonda() {
856       gum::BayesNet< double > dbn;
857       gum::NodeId             bf_0 = dbn.add(gum::LabelizedVariable("bf_0", "bf_0", 4));
858       /*gum::NodeId bf_t =*/dbn.add(gum::LabelizedVariable("bf_t", "bf_t", 4));
859       gum::NodeId c_0  = dbn.add(gum::LabelizedVariable("c_0", "c_0", 5));
860       gum::NodeId c_t  = dbn.add(gum::LabelizedVariable("c_t", "c_t", 5));
861       gum::NodeId h_0  = dbn.add(gum::LabelizedVariable("h_0", "h_0", 5));
862       gum::NodeId h_t  = dbn.add(gum::LabelizedVariable("h_t", "h_t", 5));
863       gum::NodeId tf_0 = dbn.add(gum::LabelizedVariable("tf_0", "tf_0", 5));
864       /*gum::NodeId tf_t =*/dbn.add(gum::LabelizedVariable("tf_t", "tf_t", 5));
865       gum::NodeId wl_0 = dbn.add(gum::LabelizedVariable("wl_0", "wl_0", 4));
866       gum::NodeId wl_t = dbn.add(gum::LabelizedVariable("wl_t", "wl_t", 4));
867 
868       for (auto n: {c_t, h_t, wl_t}) {
869         dbn.addArc(tf_0, n);
870         dbn.addArc(bf_0, n);
871       }
872       dbn.addArc(c_0, c_t);
873       dbn.addArc(h_0, h_t);
874       dbn.addArc(wl_0, wl_t);
875 
876       gum::BayesNet< double > learn1;
877       {
878         // inductive learning leads to scrambled modalities
879         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"));
880         learner.useScoreLog2Likelihood();
881         learner.useAprioriSmoothing(1.0);
882         learn1 = learner.learnParameters(dbn.dag());
883       }
884       gum::BayesNet< double > learn2;
885       {
886         try {
887           /*
888           gum::NodeProperty< gum::Sequence< std::string > > modals;
889           auto ds = std::vector< unsigned int >({4, 4, 5, 5, 5, 5, 5, 5, 4, 4});
890           auto labels = std::vector< std::string >({"0", "1", "2", "3", "4", "5"});
891 
892           for (auto i = 0U; i < ds.size(); i++) {
893             modals.insert(i, gum::Sequence< std::string >());
894 
895             for (auto k = 0U; k < ds[i]; k++)
896               modals[i].insert(labels[k]);
897           }
898           */
899 
900           // while explicit learning does the right thing
901           gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"),
902                                                      learn1);
903           learner.useScoreLog2Likelihood();
904           learner.useAprioriSmoothing(1.0);
905           learn2 = learner.learnParameters(dbn.dag());
906         } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
907       }
908       gum::BayesNet< double > learn3;
909       {
910         // while explicit learning does the right thing
911         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"), dbn);
912         learner.useScoreLog2Likelihood();
913         learner.useAprioriSmoothing(1.0);
914         learn3 = learner.learnParameters(dbn.dag());
915       }
916 
917       TS_ASSERT_EQUALS(learn1.variable(learn1.idFromName("wl_0")).toString(), "wl_0:Range([0,3])")
918       TS_ASSERT_EQUALS(learn2.variable(learn2.idFromName("wl_0")).toString(), "wl_0:Range([0,3])")
919       TS_ASSERT_EQUALS(learn2.variable(learn3.idFromName("wl_0")).toString(), "wl_0:Range([0,3])")
920 
921       auto&              p1 = learn1.cpt(learn1.idFromName("c_0"));
922       auto&              p2 = learn2.cpt(learn2.idFromName("c_0"));
923       auto&              p3 = learn3.cpt(learn3.idFromName("c_0"));
924       gum::Instantiation I1(p1), I2(p2), I3(p3);
925 
926       for (I1.setFirst(), I2.setFirst(), I3.setFirst(); !I1.end(); I1.inc(), I2.inc(), I3.inc()) {
927         TS_ASSERT_EQUALS(I1.toString(), I2.toString());   // same modalities orders
928         TS_ASSERT_EQUALS(I1.toString(), I3.toString());   // same modalities orders
929         TS_ASSERT_EQUALS(p1[I1], p2[I2]);                 // same probabilities
930         TS_ASSERT_EQUALS(p1[I1], p3[I3]);                 // same probabilities
931       }
932 
933       gum::BayesNet< double > learn4;
934       {
935         // inductive learning leads to scrambled modalities
936         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/DBN_Tonda.csv"), false);
937         learner.useScoreLog2Likelihood();
938         learner.useAprioriSmoothing(1.0);
939         learn4 = learner.learnParameters(dbn.dag());
940       }
941       TS_ASSERT_EQUALS(learn4.variable(learn1.idFromName("wl_0")).toString(),
942                        "wl_0:Labelized(<0,1,2,3>)");
943     }
944 
945 
test_asia_with_missing_values()946     void test_asia_with_missing_values() {
947       int nb = 0;
948       try {
949         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3-faulty.csv"),
950                                                    true,
951                                                    std::vector< std::string >{"BEURK"});
952         learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
953         learner.learnBN();
954       } catch (gum::MissingValueInDatabase&) { nb = 1; }
955 
956       TS_ASSERT_EQUALS(nb, 1)
957     }
958 
test_BugDoumenc()959     void test_BugDoumenc() {
960       gum::BayesNet< double >    templ;
961       std::vector< std::string > varBool{"S",
962                                          "DEP",
963                                          "TM",
964                                          "TE",
965                                          "TV",
966                                          "PSY",
967                                          "AL",
968                                          "PT",
969                                          "HYP",
970                                          "FRE",
971                                          "PC",
972                                          "C",
973                                          "MN",
974                                          "AM",
975                                          "PR",
976                                          "AR",
977                                          "DFM"};   // les vraibles booléennes du RB
978 
979       std::vector< std::string > varTer{
980          "NBC",
981          "MED",
982          "DEM",
983          "SP"};   // les variables pouvant prendre 3 valeurs possibles du RB
984 
985       std::vector< std::string > varContinuous{"A", "ADL"};   // les variables continues du RB
986 
987 
988       std::vector< gum::NodeId > nodeList;   // Liste des noeuds du RB
989 
990       for (auto var: varBool)
991         nodeList.push_back(templ.add(
992            gum::LabelizedVariable(var,
993                                   var,
994                                   2)));   // Ajout des variables booléennes à la liste des noeuds
995 
996       for (auto var: varTer)
997         nodeList.push_back(templ.add(
998            gum::LabelizedVariable(var,
999                                   var,
1000                                   3)));   // Ajout des variables ternaires à la liste des noeuds
1001 
1002       gum::DiscretizedVariable< double > A("A", "A");
1003       for (int i = 60; i <= 105; i += 5) {
1004         A.addTick(double(i));
1005       }
1006 
1007       gum::NodeId a_id = templ.add(A);
1008       nodeList.push_back(a_id);   // Ajout de la variable Age allant de 60
1009       // à 100 ans à la liste des noeuds
1010 
1011       // Ajout de la variable ADL allant de 0 à 6 à la liste des noeuds
1012       nodeList.push_back(templ.add(gum::RangeVariable("ADL", "ADL", 0, 6)));
1013       // Création du noeud central NRC (niveau de risque de chute)
1014       gum::LabelizedVariable NRC("NRC", "NRC", 0);
1015 
1016       NRC.addLabel("faible");
1017       NRC.addLabel("modere");
1018       NRC.addLabel("eleve");
1019       auto iNRC = templ.add(NRC);
1020 
1021       // Création des arcs partant du noeud NRC vers les autres noeuds
1022       for (auto node: nodeList) {
1023         templ.addArc(iNRC, node);
1024       }
1025 
1026       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/bugDoumenc.csv"), templ);
1027       learner.useScoreLog2Likelihood();
1028       learner.useAprioriSmoothing();
1029       auto bn = learner.learnParameters(templ.dag());
1030     }
1031 
test_BugDoumencWithInt()1032     void test_BugDoumencWithInt() {
1033       gum::BayesNet< double >    templ;
1034       std::vector< std::string > varBool{"S",
1035                                          "DEP",
1036                                          "TM",
1037                                          "TE",
1038                                          "TV",
1039                                          "PSY",
1040                                          "AL",
1041                                          "PT",
1042                                          "HYP",
1043                                          "FRE",
1044                                          "PC",
1045                                          "C",
1046                                          "MN",
1047                                          "AM",
1048                                          "PR",
1049                                          "AR",
1050                                          "DFM"};   // les vraibles booléennes du RB
1051 
1052       std::vector< std::string > varTer{
1053          "NBC",
1054          "MED",
1055          "DEM",
1056          "SP"};   // les variables pouvant prendre 3 valeurs possibles du RB
1057 
1058       std::vector< std::string > varContinuous{"A", "ADL"};   // les variables continues du RB
1059 
1060 
1061       std::vector< gum::NodeId > nodeList;   // Liste des noeuds du RB
1062 
1063       for (auto var: varBool)
1064         nodeList.push_back(templ.add(
1065            gum::LabelizedVariable(var,
1066                                   var,
1067                                   2)));   // Ajout des variables booléennes à la liste des noeuds
1068 
1069       for (auto var: varTer)
1070         nodeList.push_back(templ.add(
1071            gum::LabelizedVariable(var,
1072                                   var,
1073                                   3)));   // Ajout des variables ternaires à la liste des noeuds
1074 
1075       gum::DiscretizedVariable< int > A("A", "A");
1076       for (int i = 60; i <= 105; i += 5) {
1077         A.addTick(i);
1078       }
1079 
1080       nodeList.push_back(templ.add(A));   // Ajout de la variable Age allant de 60
1081       // à 100 ans à la liste des noeuds
1082 
1083       // Ajout de la variable ADL allant de 0 à 6 à la liste des noeuds
1084       nodeList.push_back(templ.add(gum::RangeVariable("ADL", "ADL", 0, 6)));
1085       // Création du noeud central NRC (niveau de risque de chute)
1086       gum::LabelizedVariable NRC("NRC", "NRC", 0);
1087 
1088       NRC.addLabel("faible");
1089       NRC.addLabel("modere");
1090       NRC.addLabel("eleve");
1091       auto iNRC = templ.add(NRC);
1092 
1093       // Création des arcs partant du noeud NRC vers les autres noeuds
1094       for (auto node: nodeList) {
1095         templ.addArc(iNRC, node);
1096       }
1097 
1098 
1099       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/bugDoumenc.csv"), templ);
1100       learner.useScoreLog2Likelihood();
1101       learner.useAprioriSmoothing();
1102 
1103       auto bn = learner.learnParameters(templ.dag());
1104 
1105       const gum::DiscreteVariable& var_discr = bn.variable("A");
1106       int                          good      = 1;
1107       try {
1108         const gum::DiscretizedVariable< int >& xvar_discr
1109            = dynamic_cast< const gum::DiscretizedVariable< int >& >(var_discr);
1110         TS_ASSERT_EQUALS(xvar_discr.domainSize(), (gum::Size)9)
1111         TS_ASSERT_EQUALS(xvar_discr.label(0), "[60;65[")
1112         TS_ASSERT_EQUALS(xvar_discr.label(1), "[65;70[")
1113         TS_ASSERT_EQUALS(xvar_discr.label(8), "[100;105]")
1114       } catch (std::bad_cast&) { good = 0; }
1115       TS_ASSERT_EQUALS(good, 1)
1116     }
1117 
test_setSliceOrderWithNames()1118     void test_setSliceOrderWithNames() {
1119       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1120       learner.setSliceOrder(
1121          {{"smoking", "lung_cancer"}, {"bronchitis", "visit_to_Asia"}, {"tuberculosis"}});
1122 
1123 
1124       gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/asia3.csv"));
1125       TS_ASSERT_THROWS(learner2.setSliceOrder({{"smoking", "lung_cancer"},
1126                                                {"bronchitis", "visit_to_Asia"},
1127                                                {"smoking", "tuberculosis", "lung_cancer"}}),
1128                        gum::DuplicateElement);
1129 
1130       gum::learning::BNLearner< double > learner3(GET_RESSOURCES_PATH("csv/asia3.csv"));
1131       TS_ASSERT_THROWS(
1132          learner3.setSliceOrder(
1133             {{"smoking", "lung_cancer"}, {"bronchitis", "visit_to_Asia"}, {"CRUCRU"}}),
1134          gum::MissingVariableInDatabase);
1135     }
1136 
test_dirichlet()1137     void test_dirichlet() {
1138       auto bn = gum::BayesNet< double >::fastPrototype("A->B<-C->D->E<-B");
1139 
1140       gum::learning::BNDatabaseGenerator< double > genere(bn);
1141       genere.setRandomVarOrder();
1142       genere.drawSamples(2000);
1143       genere.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"));
1144 
1145       auto bn2 = gum::BayesNet< double >::fastPrototype("A->B->C->D->E");
1146       gum::learning::BNDatabaseGenerator< double > genere2(bn2);
1147       genere2.drawSamples(100);
1148       genere2.toCSV(GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"));
1149 
1150       gum::learning::BNLearner< double > learner(
1151          GET_RESSOURCES_PATH("outputs/bnlearner_database.csv"),
1152          bn);
1153       learner.useAprioriDirichlet(GET_RESSOURCES_PATH("outputs/bnlearner_dirichlet.csv"), 10);
1154       learner.useScoreAIC();
1155 
1156       try {
1157         gum::BayesNet< double > bn3 = learner.learnBN();
1158       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
1159     }
1160 
1161 
test_dirichlet2()1162     void test_dirichlet2() {
1163       // read the learning database
1164       gum::learning::DBInitializerFromCSV<> initializer(
1165          GET_RESSOURCES_PATH("csv/db_dirichlet_learning.csv"));
1166       const auto&       var_names = initializer.variableNames();
1167       const std::size_t nb_vars   = var_names.size();
1168 
1169       gum::learning::DBTranslatorSet<>                translator_set;
1170       gum::learning::DBTranslator4LabelizedVariable<> translator;
1171       for (std::size_t i = 0; i < nb_vars; ++i) {
1172         translator_set.insertTranslator(translator, i);
1173       }
1174 
1175       gum::learning::DatabaseTable<> database(translator_set);
1176       database.setVariableNames(initializer.variableNames());
1177       initializer.fillDatabase(database);
1178 
1179 
1180       // read the apriori database
1181       gum::learning::DBInitializerFromCSV<> dirichlet_initializer(
1182          GET_RESSOURCES_PATH("csv/db_dirichlet_apriori.csv"));
1183       const auto&       dirichlet_var_names = initializer.variableNames();
1184       const std::size_t dirichlet_nb_vars   = dirichlet_var_names.size();
1185 
1186       gum::learning::DBTranslatorSet<> dirichlet_translator_set;
1187       for (std::size_t i = 0; i < dirichlet_nb_vars; ++i) {
1188         dirichlet_translator_set.insertTranslator(translator, i);
1189       }
1190 
1191       gum::learning::DatabaseTable<> dirichlet_database(dirichlet_translator_set);
1192       dirichlet_database.setVariableNames(dirichlet_initializer.variableNames());
1193       dirichlet_initializer.fillDatabase(dirichlet_database);
1194       dirichlet_database.reorder();
1195 
1196 
1197       // create the score and the apriori
1198       gum::learning::DBRowGeneratorSet<>            dirichlet_genset;
1199       gum::learning::DBRowGeneratorParser<>         dirichlet_parser(dirichlet_database.handler(),
1200                                                              dirichlet_genset);
1201       gum::learning::AprioriDirichletFromDatabase<> apriori(dirichlet_database, dirichlet_parser);
1202 
1203       gum::learning::DBRowGeneratorSet<>    genset;
1204       gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
1205 
1206       std::vector< double > weights{0, 1.0, 5.0, 10.0, 1000.0, 7000.0, 100000.0};
1207 
1208       gum::learning::BNLearner< double > learner(
1209          GET_RESSOURCES_PATH("csv/db_dirichlet_learning.csv"));
1210       learner.useScoreBIC();
1211 
1212       for (const auto weight: weights) {
1213         apriori.setWeight(weight);
1214         gum::learning::ScoreBIC<> score(parser, apriori);
1215 
1216         // finalize the learning algorithm
1217         gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintDAG >
1218            struct_constraint;
1219 
1220         gum::learning::ParamEstimatorML<> estimator(parser, apriori, score.internalApriori());
1221 
1222         gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) > op_set(
1223            struct_constraint);
1224 
1225         gum::learning::GraphChangesSelector4DiGraph< decltype(struct_constraint), decltype(op_set) >
1226            selector(score, struct_constraint, op_set);
1227 
1228         gum::learning::GreedyHillClimbing search;
1229 
1230         gum::BayesNet< double > bn = search.learnBN(selector, estimator);
1231         // std::cout << dag << std::endl;
1232 
1233 
1234         learner.useAprioriDirichlet(GET_RESSOURCES_PATH("csv/db_dirichlet_apriori.csv"), weight);
1235 
1236         gum::BayesNet< double > xbn = learner.learnBN();
1237 
1238         TS_ASSERT_EQUALS(xbn.moralGraph(), bn.moralGraph())
1239       }
1240     }
1241 
test_EM()1242     void test_EM() {
1243       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/EM.csv"),
1244                                                  true,
1245                                                  std::vector< std::string >{"?"});
1246 
1247       TS_ASSERT(learner.hasMissingValues())
1248 
1249       gum::DAG dag;
1250       for (std::size_t i = std::size_t(0); i < learner.database().nbVariables(); ++i) {
1251         dag.addNodeWithId(gum::NodeId(i));
1252       }
1253       dag.addArc(gum::NodeId(1), gum::NodeId(0));
1254       dag.addArc(gum::NodeId(2), gum::NodeId(1));
1255       dag.addArc(gum::NodeId(3), gum::NodeId(2));
1256 
1257       TS_ASSERT_THROWS(learner.learnParameters(dag), gum::MissingValueInDatabase)
1258 
1259       learner.useEM(1e-3);
1260       learner.useAprioriSmoothing();
1261 
1262       TS_GUM_ASSERT_THROWS_NOTHING(learner.learnParameters(dag, false));
1263       TS_GUM_ASSERT_THROWS_NOTHING(learner.nbrIterations());
1264     }
1265 
test_chi2()1266     void test_chi2() {
1267       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1268 
1269       auto reschi2 = learner.chi2("smoking", "lung_cancer");
1270       TS_ASSERT_DELTA(reschi2.first, 36.2256, 1e-4)
1271       TS_ASSERT_DELTA(reschi2.second, 0, 1e-4)
1272 
1273       reschi2 = learner.chi2("smoking", "visit_to_Asia");
1274       TS_ASSERT_DELTA(reschi2.first, 1.1257, 1e-4)
1275       TS_ASSERT_DELTA(reschi2.second, 0.2886, 1e-4)
1276 
1277       reschi2 = learner.chi2("lung_cancer", "tuberculosis");
1278       TS_ASSERT_DELTA(reschi2.first, 0.6297, 1e-4)
1279       TS_ASSERT_DELTA(reschi2.second, 0.4274, 1e-4)
1280 
1281       reschi2 = learner.chi2("lung_cancer", "tuberculosis", {"tuberculos_or_cancer"});
1282       TS_ASSERT_DELTA(reschi2.first, 58.0, 1e-4)
1283       TS_ASSERT_DELTA(reschi2.second, 0.0, 1e-4)
1284 
1285       // see IndepTestChi2TestSuite::test_statistics
1286       gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/chi2.csv"));
1287 
1288       auto stat = learner2.chi2("A", "C");
1289       TS_ASSERT_DELTA(stat.first, 0.0007, 1e-3)
1290       TS_ASSERT_DELTA(stat.second, 0.978, 1e-3)
1291 
1292       stat = learner2.chi2("A", "B");
1293       TS_ASSERT_DELTA(stat.first, 21.4348, 1e-3)
1294       TS_ASSERT_DELTA(stat.second, 3.6e-6, TS_GUM_SMALL_ERROR)
1295 
1296       stat = learner2.chi2("B", "A");
1297       TS_ASSERT_DELTA(stat.first, 21.4348, 1e-3)
1298       TS_ASSERT_DELTA(stat.second, 3.6e-6, TS_GUM_SMALL_ERROR)
1299 
1300       stat = learner2.chi2("B", "D");
1301       TS_ASSERT_DELTA(stat.first, 0.903, 1e-3)
1302       TS_ASSERT_DELTA(stat.second, 0.341, 1e-3)
1303 
1304       stat = learner2.chi2("A", "C", {"B"});
1305       TS_ASSERT_DELTA(stat.first, 15.2205, 1e-3)
1306       TS_ASSERT_DELTA(stat.second, 0.0005, 1e-4)
1307     }
1308 
test_G2()1309     void test_G2() {
1310       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1311       auto                               resg2 = learner.G2("smoking", "lung_cancer");
1312       TS_ASSERT_DELTA(resg2.first, 43.0321, 1e-4)
1313       TS_ASSERT_DELTA(resg2.second, 0, 1e-4)
1314 
1315       resg2 = learner.G2("smoking", "visit_to_Asia");
1316       TS_ASSERT_DELTA(resg2.first, 1.1418, 1e-4)
1317       TS_ASSERT_DELTA(resg2.second, 0.2852, 1e-4)
1318 
1319       resg2 = learner.G2("lung_cancer", "tuberculosis");
1320       TS_ASSERT_DELTA(resg2.first, 1.2201, 1e-4)
1321       TS_ASSERT_DELTA(resg2.second, 0.2693, 1e-4)
1322 
1323       resg2 = learner.G2("lung_cancer", "tuberculosis", {"tuberculos_or_cancer"});
1324       TS_ASSERT_DELTA(resg2.first, 59.1386, 1e-4)
1325       TS_ASSERT_DELTA(resg2.second, 0.0, 1e-4)
1326 
1327       // see IndepTestChi2TestSuite::test_statistics
1328       gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/chi2.csv"));
1329 
1330       auto stat = learner2.G2("A", "C");
1331       TS_ASSERT_DELTA(stat.first, 0.0007, 1e-3)
1332       TS_ASSERT_DELTA(stat.second, 0.978, 1e-4)
1333 
1334       stat = learner2.G2("A", "B");
1335       TS_ASSERT_DELTA(stat.first, 21.5846, 1e-3)
1336       TS_ASSERT_DELTA(stat.second, 3.6e-6, 1e-4)
1337 
1338       stat = learner2.G2("B", "A");
1339       TS_ASSERT_DELTA(stat.first, 21.5846, 1e-3)
1340       TS_ASSERT_DELTA(stat.second, 3.6e-6, 1e-4)
1341 
1342       stat = learner2.G2("B", "D");
1343       TS_ASSERT_DELTA(stat.first, 0.903, 1e-3)
1344       TS_ASSERT_DELTA(stat.second, 0.342, 1e-4)
1345 
1346       stat = learner2.G2("A", "C", {"B"});
1347       TS_ASSERT_DELTA(stat.first, 16.3470, 1e-3)
1348       TS_ASSERT_DELTA(stat.second, 0.0002, 1e-4)
1349     }
1350 
test_cmpG2Chi2()1351     void test_cmpG2Chi2() {
1352       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/testXYbase.csv"));
1353       auto                               statChi2 = learner.chi2("X", "Y");
1354       TS_ASSERT_DELTA(statChi2.first, 15.3389, 1e-3)
1355       TS_ASSERT_DELTA(statChi2.second, 0.01777843046460533, 1e-6)
1356       auto statG2 = learner.G2("X", "Y");
1357       TS_ASSERT_DELTA(statG2.first, 16.6066, 1e-3)
1358       TS_ASSERT_DELTA(statG2.second, 0.0108433, 1e-6)
1359     }
1360 
test_loglikelihood()1361     void test_loglikelihood() {
1362       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/chi2.csv"));
1363       TS_ASSERT_EQUALS(learner.nbRows(), (gum::Size)500)
1364       TS_ASSERT_EQUALS(learner.nbCols(), (gum::Size)4)
1365 
1366       double siz = -1.0 * learner.database().size();
1367       learner.useNoApriori();
1368 
1369       auto stat = learner.logLikelihood({"A"}) / siz;   // LL=-N.H
1370       TS_ASSERT_DELTA(stat, 0.99943499, TS_GUM_SMALL_ERROR)
1371       stat = learner.logLikelihood({"B"}) / siz;   // LL=-N.H
1372       TS_ASSERT_DELTA(stat, 0.9986032, TS_GUM_SMALL_ERROR)
1373       stat = learner.logLikelihood({std::string("A"), "B"}) / siz;   // LL=-N.H
1374       TS_ASSERT_DELTA(stat, 1.9668973, TS_GUM_SMALL_ERROR)
1375       stat = learner.logLikelihood({std::string("A")}, {"B"}) / siz;   // LL=-N.H
1376       TS_ASSERT_DELTA(stat, 1.9668973 - 0.9986032, TS_GUM_SMALL_ERROR)
1377 
1378       stat = learner.logLikelihood({"C"}) / siz;   // LL=-N.H
1379       TS_ASSERT_DELTA(stat, 0.99860302, TS_GUM_SMALL_ERROR)
1380       stat = learner.logLikelihood({"D"}) / siz;   // LL=-N.H
1381       TS_ASSERT_DELTA(stat, 0.40217919, TS_GUM_SMALL_ERROR)
1382       stat = learner.logLikelihood({std::string("C"), "D"}) / siz;   // LL=-N.H
1383       TS_ASSERT_DELTA(stat, 1.40077995, TS_GUM_SMALL_ERROR)
1384       stat = learner.logLikelihood({std::string("C")}, {"D"}) / siz;   // LL=-N.H
1385       TS_ASSERT_DELTA(stat, 1.40077995 - 0.40217919, TS_GUM_SMALL_ERROR)
1386     }
1387 
test_errorFromPyagrum()1388     void test_errorFromPyagrum() {
1389       try {
1390         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/sample_asia.csv"));
1391         learner.use3off2();
1392         learner.useNMLCorrection();
1393         auto ge3off2 = learner.learnMixedStructure();
1394       } catch (gum::Exception& e) { GUM_SHOWERROR(e); }
1395     }
1396 
test_PossibleEdges()1397     void test_PossibleEdges() {
1398       //[smoking , lung_cancer , bronchitis , visit_to_Asia , tuberculosis ,
1399       // tuberculos_or_cancer , dyspnoea , positive_XraY]
1400       {
1401         // possible edges are not relevant
1402         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1403         learner.addPossibleEdge("visit_to_Asia", "lung_cancer");
1404         learner.addPossibleEdge("visit_to_Asia", "smoking");
1405 
1406         gum::BayesNet< double > bn = learner.learnBN();
1407         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)0)
1408       }
1409 
1410       {
1411         // possible edges are relevant
1412         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1413         learner.addPossibleEdge("smoking", "lung_cancer");
1414         learner.addPossibleEdge("bronchitis", "smoking");
1415 
1416         gum::BayesNet< double > bn = learner.learnBN();
1417         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2)
1418         TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking")))
1419         TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("smoking")))
1420       }
1421 
1422       {
1423         // possible edges are relevant
1424         // mixed with a forbidden arcs
1425         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1426         learner.addPossibleEdge("smoking", "lung_cancer");
1427         learner.addPossibleEdge("bronchitis", "smoking");
1428         learner.addForbiddenArc("smoking", "bronchitis");
1429 
1430         gum::BayesNet< double > bn = learner.learnBN();
1431         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2)
1432         TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking")))
1433         TS_ASSERT(bn.parents("smoking").contains(bn.idFromName("bronchitis")))
1434       }
1435 
1436       {
1437         // possible edges are relevant
1438         // mixed with a mandatory arcs
1439         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1440         learner.addPossibleEdge("smoking", "lung_cancer");
1441         learner.addPossibleEdge("bronchitis", "smoking");
1442         learner.addMandatoryArc("visit_to_Asia", "bronchitis");
1443 
1444         gum::BayesNet< double > bn = learner.learnBN();
1445         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)3)
1446         TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking")))
1447         TS_ASSERT(bn.parents("smoking").contains(bn.idFromName("bronchitis")))
1448         TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("visit_to_Asia")))
1449       }
1450     }
1451 
test_PossibleEdgesTabu()1452     void test_PossibleEdgesTabu() {
1453       //[smoking , lung_cancer , bronchitis , visit_to_Asia , tuberculosis ,
1454       // tuberculos_or_cancer , dyspnoea , positive_XraY]
1455       {
1456         // possible edges are not relevant
1457         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1458         learner.useLocalSearchWithTabuList();
1459         learner.addPossibleEdge("visit_to_Asia", "lung_cancer");
1460         learner.addPossibleEdge("visit_to_Asia", "smoking");
1461 
1462         gum::BayesNet< double > bn = learner.learnBN();
1463         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)0)
1464       }
1465 
1466       {
1467         // possible edges are relevant
1468         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1469         learner.useLocalSearchWithTabuList();
1470         learner.addPossibleEdge("smoking", "lung_cancer");
1471         learner.addPossibleEdge("bronchitis", "smoking");
1472 
1473         gum::BayesNet< double > bn = learner.learnBN();
1474         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2)
1475         TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking")))
1476         TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("smoking")))
1477       }
1478 
1479       {
1480         // possible edges are relevant
1481         // mixed with a forbidden arcs
1482         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1483         learner.useLocalSearchWithTabuList();
1484         learner.addPossibleEdge("smoking", "lung_cancer");
1485         learner.addPossibleEdge("bronchitis", "smoking");
1486         learner.addForbiddenArc("smoking", "bronchitis");
1487 
1488         gum::BayesNet< double > bn = learner.learnBN();
1489         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)2)
1490         TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking")))
1491         TS_ASSERT(bn.parents("smoking").contains(bn.idFromName("bronchitis")))
1492       }
1493 
1494       {
1495         // possible edges are relevant
1496         // mixed with a mandatory arcs
1497         gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia3.csv"));
1498         learner.useLocalSearchWithTabuList();
1499         learner.addPossibleEdge("smoking", "lung_cancer");
1500         learner.addPossibleEdge("bronchitis", "smoking");
1501         learner.addMandatoryArc("visit_to_Asia", "bronchitis");
1502 
1503         gum::BayesNet< double > bn = learner.learnBN();
1504         TS_ASSERT_EQUALS(bn.sizeArcs(), (gum::Size)3)
1505         TS_ASSERT(bn.parents("lung_cancer").contains(bn.idFromName("smoking")))
1506         TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("smoking")))
1507         TS_ASSERT(bn.parents("bronchitis").contains(bn.idFromName("visit_to_Asia")))
1508       }
1509     }
1510 
testPseudoCount()1511     void testPseudoCount() {
1512       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/minimal.csv"));
1513       TS_ASSERT_EQUALS(learner.domainSize(0), 2u)
1514       TS_ASSERT_EQUALS(learner.domainSize("X"), 2u)
1515       TS_ASSERT_EQUALS(learner.domainSize(1), 2u)
1516       TS_ASSERT_EQUALS(learner.domainSize("Y"), 2u)
1517       TS_ASSERT_EQUALS(learner.domainSize(2), 3u)
1518       TS_ASSERT_EQUALS(learner.domainSize("Z"), 3u)
1519       learner.useNoApriori();
1520 
1521       TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< gum::NodeId >({0})),
1522                        std::vector< double >({3, 4}));
1523       TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< gum::NodeId >({0, 2})),
1524                        std::vector< double >({2, 1, 1, 1, 0, 2}));
1525       TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< gum::NodeId >({2, 0})),
1526                        std::vector< double >({2, 1, 0, 1, 1, 2}));
1527 
1528       TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< std::string >({"X"})),
1529                        std::vector< double >({3, 4}));
1530       TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< std::string >({"X", "Z"})),
1531                        std::vector< double >({2, 1, 1, 1, 0, 2}));
1532       TS_ASSERT_EQUALS(learner.rawPseudoCount(std::vector< std::string >({"Z", "X"})),
1533                        std::vector< double >({2, 1, 0, 1, 1, 2}));
1534     }
1535 
testNonRegressionZeroCount()1536     void testNonRegressionZeroCount() {
1537       //////////////////////////
1538       // without specific score
1539       auto templ12 = gum::BayesNet< double >::fastPrototype("smoking->lung_cancer");
1540       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/asia.csv"), templ12);
1541       auto                               bn = learner.learnParameters(templ12.dag());
1542 
1543       gum::learning::BNLearner< double > learner2(GET_RESSOURCES_PATH("csv/asia.csv"), templ12);
1544       auto                               bn2 = learner2.learnParameters(templ12.dag());
1545       TS_ASSERT_EQUALS(bn.cpt("lung_cancer").toString(), bn2.cpt("lung_cancer").toString())
1546 
1547       //////////////////////////
1548       // with score AIC
1549       auto templ34 = gum::BayesNet< double >::fastPrototype("smoking[3]->lung_cancer[3]");
1550       gum::learning::BNLearner< double > learner3(GET_RESSOURCES_PATH("csv/asia.csv"), templ34);
1551       learner3.useScoreAIC();
1552       learner3.useAprioriSmoothing(1e-6);
1553 
1554       auto bn3 = learner3.learnParameters(templ34.dag());
1555       {
1556         const gum::Potential< double >& p = bn.cpt("lung_cancer");
1557         const gum::Potential< double >& q = bn3.cpt("lung_cancer");
1558 
1559         auto I = gum::Instantiation(p);
1560         auto J = gum::Instantiation(q);
1561 
1562         TS_ASSERT_DELTA(p[I], q[J], 1e-6)
1563         ++I;
1564         ++J;
1565         TS_ASSERT_DELTA(p[I], q[J], 1e-6)
1566         ++J;
1567         TS_ASSERT_DELTA(0.0, q[J], 1e-6)
1568         ++I;
1569         ++J;
1570         TS_ASSERT_DELTA(p[I], q[J], 1e-6)
1571         ++I;
1572         ++J;
1573         TS_ASSERT_DELTA(p[I], q[J], 1e-6)
1574         ++J;
1575         TS_ASSERT_DELTA(0.0, q[J], 1e-6)
1576         ++J;
1577         TS_ASSERT_DELTA(1.0 / 3.0, q[J], 1e-6)
1578         ++J;
1579         TS_ASSERT_DELTA(1.0 / 3.0, q[J], 1e-6)
1580         ++J;
1581         TS_ASSERT_DELTA(1.0 / 3.0, q[J], 1e-6)
1582       }
1583 
1584       auto templ35 = gum::BayesNet< double >::fastPrototype("smoking[3]->lung_cancer[3]");
1585       gum::learning::BNLearner< double > learner4(GET_RESSOURCES_PATH("csv/asia.csv"), templ35);
1586       learner4.useScoreAIC();
1587 
1588       TS_ASSERT_THROWS(learner4.learnParameters(templ34.dag()), gum::DatabaseError)
1589     }
1590 
test_misorientation_MIIC()1591     void test_misorientation_MIIC() {
1592       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/renewal.csv"));
1593 
1594       learner.useMIIC();
1595       learner.useNMLCorrection();
1596 
1597       auto bn            = learner.learnBN();
1598       auto expected_arcs = std::vector< std::pair< std::string, std::string > >(
1599          {{"coupon", "loyalty"},
1600           {"coupon", "recent visit"},
1601           {"loyalty", "renewal"},
1602           {"loyalty", "recent visit"},
1603           {"corporate customer", "loyalty"},
1604           {"corporate customer", "yearly consumption"},
1605           {"yearly consumption", "loyalty"},
1606           {"yearly consumption", "coupon"}});
1607       for (auto a: expected_arcs) {
1608         TS_ASSERT(bn.existsArc(a.first, a.second))
1609       }
1610     }
1611 
testState()1612     void testState() {
1613       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/renewal.csv"));
1614       {
1615         auto state = learner.state();
1616         TS_ASSERT_EQUALS(state.size(), (gum::Size)8)
1617         TS_ASSERT_EQUALS(std::get< 0 >(state[0]), "Filename")
1618         TS_ASSERT_EQUALS(std::get< 1 >(state[0]), GET_RESSOURCES_PATH("csv/renewal.csv"))
1619 
1620         TS_ASSERT_EQUALS(std::get< 0 >(state[1]), "Size")
1621         TS_ASSERT_EQUALS(std::get< 1 >(state[1]), "(50000,6)")
1622 
1623         TS_ASSERT_EQUALS(std::get< 0 >(state[2]), "Variables")
1624         TS_ASSERT_EQUALS(std::get< 1 >(state[2]),
1625                          "loyalty[2], renewal[2], yearly consumption[5], corporate customer[2], "
1626                          "coupon[2], recent visit[2]")
1627 
1628         TS_ASSERT_EQUALS(std::get< 0 >(state[3]), "Induced types")
1629         TS_ASSERT_EQUALS(std::get< 1 >(state[3]), "True")
1630 
1631         TS_ASSERT_EQUALS(std::get< 0 >(state[4]), "Missing values")
1632         TS_ASSERT_EQUALS(std::get< 1 >(state[4]), "False")
1633 
1634         TS_ASSERT_EQUALS(std::get< 0 >(state[5]), "Algorithm")
1635         TS_ASSERT_EQUALS(std::get< 1 >(state[5]), "Greedy Hill Climbing")
1636 
1637         TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "Score")
1638         TS_ASSERT_EQUALS(std::get< 1 >(state[6]), "BDeu")
1639 
1640         TS_ASSERT_EQUALS(std::get< 0 >(state[7]), "Prior")
1641         TS_ASSERT_EQUALS(std::get< 1 >(state[7]), "-")
1642       }
1643 
1644       learner.useMIIC();
1645       learner.useNMLCorrection();
1646 
1647       {
1648         auto state = learner.state();
1649         TS_ASSERT_EQUALS(state.size(), (gum::Size)8)
1650         TS_ASSERT_EQUALS(std::get< 0 >(state[0]), "Filename")
1651         TS_ASSERT_EQUALS(std::get< 1 >(state[0]), GET_RESSOURCES_PATH("csv/renewal.csv"))
1652 
1653         TS_ASSERT_EQUALS(std::get< 0 >(state[1]), "Size")
1654         TS_ASSERT_EQUALS(std::get< 1 >(state[1]), "(50000,6)")
1655 
1656         TS_ASSERT_EQUALS(std::get< 0 >(state[2]), "Variables")
1657         TS_ASSERT_EQUALS(std::get< 1 >(state[2]),
1658                          "loyalty[2], renewal[2], yearly consumption[5], corporate customer[2], "
1659                          "coupon[2], recent visit[2]")
1660 
1661 
1662         TS_ASSERT_EQUALS(std::get< 0 >(state[3]), "Induced types")
1663         TS_ASSERT_EQUALS(std::get< 1 >(state[3]), "True")
1664 
1665         TS_ASSERT_EQUALS(std::get< 0 >(state[4]), "Missing values")
1666         TS_ASSERT_EQUALS(std::get< 1 >(state[4]), "False")
1667 
1668         TS_ASSERT_EQUALS(std::get< 0 >(state[5]), "Algorithm")
1669         TS_ASSERT_EQUALS(std::get< 1 >(state[5]), "MIIC")
1670 
1671         TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "Correction")
1672         TS_ASSERT_EQUALS(std::get< 1 >(state[6]), "NML")
1673 
1674         TS_ASSERT_EQUALS(std::get< 0 >(state[7]), "Prior")
1675         TS_ASSERT_EQUALS(std::get< 1 >(state[7]), "-")
1676       }
1677 
1678       learner.addPossibleEdge("loyalty", "renewal");
1679       learner.setSliceOrder({{"loyalty", "renewal"}, {"recent visit", "corporate customer"}});
1680       {
1681         auto state = learner.state();
1682         TS_ASSERT_EQUALS(state.size(), (gum::Size)10)
1683         TS_ASSERT_EQUALS(std::get< 0 >(state[0]), "Filename")
1684         TS_ASSERT_EQUALS(std::get< 1 >(state[0]), GET_RESSOURCES_PATH("csv/renewal.csv"))
1685 
1686         TS_ASSERT_EQUALS(std::get< 0 >(state[1]), "Size")
1687         TS_ASSERT_EQUALS(std::get< 1 >(state[1]), "(50000,6)")
1688 
1689         TS_ASSERT_EQUALS(std::get< 0 >(state[2]), "Variables")
1690         TS_ASSERT_EQUALS(std::get< 1 >(state[2]),
1691                          "loyalty[2], renewal[2], yearly consumption[5], corporate customer[2], "
1692                          "coupon[2], recent visit[2]")
1693 
1694 
1695         TS_ASSERT_EQUALS(std::get< 0 >(state[3]), "Induced types")
1696         TS_ASSERT_EQUALS(std::get< 1 >(state[3]), "True")
1697 
1698         TS_ASSERT_EQUALS(std::get< 0 >(state[4]), "Missing values")
1699         TS_ASSERT_EQUALS(std::get< 1 >(state[4]), "False")
1700 
1701         TS_ASSERT_EQUALS(std::get< 0 >(state[5]), "Algorithm")
1702         TS_ASSERT_EQUALS(std::get< 1 >(state[5]), "MIIC")
1703 
1704         TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "Correction")
1705         TS_ASSERT_EQUALS(std::get< 1 >(state[6]), "NML")
1706 
1707         TS_ASSERT_EQUALS(std::get< 0 >(state[7]), "Prior")
1708         TS_ASSERT_EQUALS(std::get< 1 >(state[7]), "-")
1709 
1710         TS_ASSERT_EQUALS(std::get< 0 >(state[8]), "Constraint Possible Edges")
1711         TS_ASSERT_EQUALS(std::get< 1 >(state[8]), "{loyalty--renewal}")
1712 
1713         TS_ASSERT_EQUALS(std::get< 0 >(state[9]), "Constraint Slice Order")
1714         TS_ASSERT_EQUALS(std::get< 1 >(state[9]),
1715                          "{corporate customer:1, renewal:0, loyalty:0, recent visit:1}")
1716       }
1717 
1718       gum::DAG dag;
1719       dag.addNodes(learner.nbCols());
1720       dag.addArc(0, 1);
1721       learner.setInitialDAG(dag);
1722       {
1723         auto state = learner.state();
1724         TS_ASSERT_EQUALS(state.size(), (gum::Size)11)
1725         TS_ASSERT_EQUALS(std::get< 0 >(state[10]), "Initial DAG")
1726         TS_ASSERT_EQUALS(std::get< 1 >(state[10]), "True")
1727       }
1728     }
1729 
testStateContinued()1730     void testStateContinued() {
1731       gum::learning::BNLearner< double > learner(GET_RESSOURCES_PATH("csv/renewal.csv"));
1732       learner.setDatabaseWeight(1000);
1733       learner.useK2(std::vector< gum::NodeId >{5, 4, 3, 2, 1, 0});
1734       {
1735         auto state = learner.state();
1736         TS_ASSERT_EQUALS(state.size(), (gum::Size)10)
1737         TS_ASSERT_EQUALS(std::get< 0 >(state[6]), "K2 order")
1738         TS_ASSERT_EQUALS(
1739            std::get< 1 >(state[6]),
1740            "recent visit, coupon, corporate customer, yearly consumption, renewal, loyalty")
1741 
1742         TS_ASSERT_EQUALS(std::get< 0 >(state[9]), "Database weight")
1743         TS_ASSERT_EQUALS(std::get< 1 >(state[9]), "1000.000000")
1744       }
1745       learner.useScoreAIC();
1746       learner.useAprioriBDeu();
1747       {
1748         auto state = learner.state();
1749         TS_ASSERT_EQUALS(state.size(), (gum::Size)11)
1750         TS_ASSERT_EQUALS(std::get< 0 >(state[8]), "Prior")
1751         TS_ASSERT_DIFFERS(std::get< 2 >(state[8]), "")   // there is a comment about AIC versus BDeu
1752       }
1753     }
1754   };   // class BNLearnerTestSuite
1755 } /* namespace gum_tests */
1756