/***************************************************************************
 *   Copyright (c) 2007 by Lionel Torti                                    *
 *   info_at_agrum_dot_org                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/
20 
#include <gumtest/AgrumTestSuite.h>
#include <gumtest/testsuite_utils.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <ressources/include/simpleDebugGenerator.h>
#include <ressources/include/evenDebugGenerator.h>

#include <agrum/BN/inference/lazyPropagation.h>
#include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
#include <agrum/tools/database/DBTranslator4ContinuousVariable.h>
#include <agrum/tools/database/DBRowGeneratorParser.h>
#include <agrum/tools/database/DBRowGeneratorEM.h>
33 
34 namespace gum_tests {
35 
36   class DBRowGeneratorParserTestSuite : public CxxTest::TestSuite {
37   private:
38     gum::Potential<double>
_infer_(const gum::BayesNet<double> & bn,const std::vector<std::size_t> & targets,const gum::learning::DBRow<gum::learning::DBTranslatedValue> & row)39      _infer_( const gum::BayesNet<double>& bn,
40              const std::vector<std::size_t>& targets,
41              const gum::learning::DBRow< gum::learning::DBTranslatedValue >& row) {
42       gum::LazyPropagation<double> ve(&bn);
43 
44       gum::NodeSet target_set;
45       for (auto target : targets) target_set.insert(gum::NodeId(target));
46       ve.addJointTarget(target_set);
47 
48       const auto xrow = row.row();
49       const auto row_size = xrow.size();
50       for (std::size_t col = std::size_t(0); col < row_size; ++col) {
51         if ( xrow[col].discr_val != std::numeric_limits<std::size_t>::max() ) {
52           ve.addEvidence(gum::NodeId(col), xrow[col].discr_val);
53         }
54       }
55 
56       gum::Potential<double> prob = ve.jointPosterior(target_set);
57       return prob;
58     }
59 
60 
61     public:
test_simple()62     void test_simple() {
63       gum::learning::DBTranslator4LabelizedVariable<>  translator_lab;
64       gum::learning::DBTranslator4ContinuousVariable<> translator_cont;
65       gum::learning::DBTranslatorSet<> set;
66       set.insertTranslator ( translator_lab,  0 );
67       set.insertTranslator ( translator_lab,  1 );
68       set.insertTranslator ( translator_cont, 2 );
69       set.insertTranslator ( translator_lab,  3 );
70       set[0].setVariableName ( "v0" );
71       set[1].setVariableName ( "v1" );
72       set[2].setVariableName ( "v2" );
73       set[3].setVariableName ( "v3" );
74 
75       gum::learning::DatabaseTable<> database ( set );
76       std::vector<std::string> row { "A0", "B0", "3.003", "C0" };
77       database.insertRow( row );
78 
79       row[0] = "A1";
80       row[1] = "B1";
81       row[2] = "3.113";
82       row[3] = "C1";
83       database.insertRow( row );
84 
85       row[0] = "A2";
86       row[1] = "B2";
87       row[2] = "3.223";
88       row[3] = "C2";
89       database.insertRow( row );
90 
91       row[0] = "A3";
92       row[1] = "B3";
93       row[2] = "3.333";
94       row[3] = "C3";
95       database.insertRow( row );
96 
97       const std::vector<gum::learning::DBTranslatedValueType>
98         col_types { gum::learning::DBTranslatedValueType::DISCRETE,
99                     gum::learning::DBTranslatedValueType::DISCRETE,
100                     gum::learning::DBTranslatedValueType::CONTINUOUS,
101                     gum::learning::DBTranslatedValueType::DISCRETE };
102 
103       gum::learning::SimpleDebugGenerator<>  generator1 ( col_types, 6 );
104       gum::learning::EvenDebugGenerator<> generator2 ( col_types, 4 );
105       gum::learning::DBRowGeneratorSet<> genset;
106       genset.insertGenerator ( generator1 );
107       genset.insertGenerator ( generator2 );
108 
109       gum::learning::DBRowGeneratorParser<>
110         parser ( database.handler (), genset );
111 
112       std::size_t nb_rows = std::size_t(0);
113       while ( parser.hasRows () ) {
114         const auto& row = parser.row().row();
115         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
116         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
117         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
118         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
119         ++nb_rows;
120       }
121       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
122 
123       nb_rows = std::size_t(0);
124       while ( parser.hasRows () ) {
125         parser.row().row();
126          ++nb_rows;
127       }
128       TS_ASSERT ( nb_rows == std::size_t( 0 ) )
129 
130       parser.reset ();
131       nb_rows = std::size_t(0);
132       while ( parser.hasRows () ) {
133         parser.row().row();
134          ++nb_rows;
135       }
136       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
137 
138       TS_ASSERT ( ( parser.handler().range() ==
139                     std::make_pair<std::size_t,std::size_t> (0,4) ) );
140 
141       const auto& xgenset = parser.generatorSet ();
142       TS_ASSERT ( xgenset.nbGenerators() == std::size_t(2) )
143 
144 
145       const std::vector<std::size_t> cols_of_interest { std::size_t(0),
146                                                         std::size_t(2),
147                                                         std::size_t(3) };
148       parser.setColumnsOfInterest ( cols_of_interest );
149       parser.reset ();
150       const auto& cols0 = xgenset[0].columnsOfInterest ();
151       const auto& cols1 = xgenset[1].columnsOfInterest ();
152       TS_ASSERT ( cols_of_interest == cols0 )
153       TS_ASSERT ( cols_of_interest == cols1 )
154 
155       gum::learning::DBRowGeneratorParser<> parser2 ( parser );
156       parser2.reset ();
157       nb_rows = std::size_t(0);
158       while ( parser2.hasRows () ) {
159         const auto& row = parser2.row().row();
160         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
161         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
162         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
163         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
164          ++nb_rows;
165       }
166       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
167       const auto& ygenset = parser2.generatorSet ();
168       const auto& ycols0  = ygenset[0].columnsOfInterest ();
169       const auto& ycols1  = ygenset[1].columnsOfInterest ();
170       TS_ASSERT ( cols_of_interest == ycols0 )
171       TS_ASSERT ( cols_of_interest == ycols1 )
172 
173       gum::learning::DBRowGeneratorParser<>
174         parser3 ( parser, std::allocator<gum::learning::DBTranslatedValue> () );
175       parser3.reset ();
176       nb_rows = std::size_t(0);
177       while ( parser3.hasRows () ) {
178         const auto& row = parser3.row().row();
179         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
180         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
181         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
182         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
183          ++nb_rows;
184       }
185       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
186       const auto& zgenset = parser3.generatorSet ();
187       const auto& zcols0  = zgenset[0].columnsOfInterest ();
188       const auto& zcols1  = zgenset[1].columnsOfInterest ();
189       TS_ASSERT ( cols_of_interest == zcols0 )
190       TS_ASSERT ( cols_of_interest == zcols1 )
191 
192       gum::learning::DBRowGeneratorParser<>
193         parser4 ( std::move ( parser3 ),
194                   std::allocator<gum::learning::DBTranslatedValue> () );
195       parser4.reset ();
196       nb_rows = std::size_t(0);
197       while ( parser4.hasRows () ) {
198         const auto& row = parser4.row().row();
199         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
200         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
201         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
202         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
203          ++nb_rows;
204       }
205       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
206       const auto& pgenset = parser4.generatorSet ();
207       const auto& pcols0  = pgenset[0].columnsOfInterest ();
208       const auto& pcols1  = pgenset[1].columnsOfInterest ();
209       TS_ASSERT ( cols_of_interest == pcols0 )
210       TS_ASSERT ( cols_of_interest == pcols1 )
211 
212       gum::learning::DBRowGeneratorParser<> parser5 ( std::move ( parser4 ) );
213       parser5.reset ();
214       nb_rows = std::size_t(0);
215       while ( parser5.hasRows () ) {
216         const auto& row = parser5.row().row();
217         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
218         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
219         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
220         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
221          ++nb_rows;
222       }
223       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
224       const auto& qgenset = parser5.generatorSet ();
225       const auto& qcols0  = qgenset[0].columnsOfInterest ();
226       const auto& qcols1  = qgenset[1].columnsOfInterest ();
227       TS_ASSERT ( cols_of_interest == qcols0 )
228       TS_ASSERT ( cols_of_interest == qcols1 )
229 
230       gum::learning::DBRowGeneratorParser<>* parser6 = parser.clone ();
231       parser6->reset ();
232       nb_rows = std::size_t(0);
233       while ( parser6->hasRows () ) {
234         const auto& row = parser6->row().row();
235         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
236         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
237         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
238         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
239          ++nb_rows;
240       }
241       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
242       const auto& rgenset = parser6->generatorSet ();
243       const auto& rcols0  = rgenset[0].columnsOfInterest ();
244       const auto& rcols1  = rgenset[1].columnsOfInterest ();
245       TS_ASSERT ( cols_of_interest == rcols0 )
246       TS_ASSERT ( cols_of_interest == rcols1 )
247 
248       delete parser6;
249 
250       gum::learning::DBRowGeneratorParser<>* parser7 =
251         parser.clone (std::allocator<gum::learning::DBTranslatedValue>());
252       parser7->reset ();
253       nb_rows = std::size_t(0);
254       while ( parser7->hasRows () ) {
255         const auto& row = parser7->row().row();
256         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
257         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
258         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
259         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
260          ++nb_rows;
261       }
262       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
263       const auto& sgenset = parser7->generatorSet ();
264       const auto& scols0  = sgenset[0].columnsOfInterest ();
265       const auto& scols1  = sgenset[1].columnsOfInterest ();
266       TS_ASSERT ( cols_of_interest == scols0 )
267       TS_ASSERT ( cols_of_interest == scols1 )
268 
269       delete parser7;
270 
271       gum::learning::DBRowGeneratorParser<>
272         parser8( database.handler (), gum::learning::DBRowGeneratorSet<> () );
273       nb_rows = std::size_t(0);
274       while ( parser8.hasRows () ) {
275         const auto& row = parser8.row().row();
276         TS_ASSERT ( row[0].discr_val == nb_rows )
277         TS_ASSERT ( row[1].discr_val == nb_rows )
278         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows * 11) / 100.0f )
279         TS_ASSERT ( row[3].discr_val == nb_rows )
280         ++nb_rows;
281       }
282       TS_ASSERT ( nb_rows == std::size_t( 4 ) )
283       const auto& tgenset = parser8.generatorSet ();
284       TS_ASSERT_EQUALS( tgenset.size() , std::size_t(0) )
285 
286       gum::learning::DBRowGeneratorParser<> parser9 ( parser8 );
287       parser9.reset ();
288       nb_rows = std::size_t(0);
289       while ( parser9.hasRows () ) {
290         const auto& row = parser9.row().row();
291         TS_ASSERT ( row[0].discr_val == nb_rows )
292         TS_ASSERT ( row[1].discr_val == nb_rows )
293         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows * 11) / 100.0f )
294         TS_ASSERT ( row[3].discr_val == nb_rows )
295         ++nb_rows;
296       }
297       TS_ASSERT ( nb_rows == std::size_t( 4 ) )
298       TS_ASSERT_EQUALS( parser9.generatorSet().size() , std::size_t(0) )
299 
300       parser8 = parser2;
301       parser8.reset ();
302       nb_rows = std::size_t(0);
303       while ( parser8.hasRows () ) {
304         const auto& row = parser8.row().row();
305         TS_ASSERT ( row[0].discr_val == nb_rows / 12 )
306         TS_ASSERT ( row[1].discr_val == nb_rows / 12 )
307         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows / 12) * 11 / 100.0f )
308         TS_ASSERT ( row[3].discr_val == nb_rows / 12 )
309          ++nb_rows;
310       }
311       TS_ASSERT ( nb_rows == std::size_t( 48 ) )
312       const auto& agenset = parser8.generatorSet ();
313       const auto& acols0  = agenset[0].columnsOfInterest ();
314       const auto& acols1  = agenset[1].columnsOfInterest ();
315       TS_ASSERT ( cols_of_interest == acols0 )
316       TS_ASSERT ( cols_of_interest == acols1 )
317 
318       parser2 = std::move ( parser9 );
319 
320       parser2.reset ();
321       nb_rows = std::size_t(0);
322       while ( parser2.hasRows () ) {
323         const auto& row = parser2.row().row();
324         TS_ASSERT ( row[0].discr_val == nb_rows )
325         TS_ASSERT ( row[1].discr_val == nb_rows )
326         TS_ASSERT ( row[2].cont_val  == 3.003f + (nb_rows * 11) / 100.0f )
327         TS_ASSERT ( row[3].discr_val == nb_rows )
328         ++nb_rows;
329       }
330       TS_ASSERT ( nb_rows == std::size_t( 4 ) )
331       TS_ASSERT_EQUALS( parser2.generatorSet().size() , std::size_t(0) )
332 
333     }
334 
335 
testEM()336     void testEM () {
337       const std::vector< gum::learning::DBTranslatedValueType > col_types{
338         gum::learning::DBTranslatedValueType::DISCRETE,
339         gum::learning::DBTranslatedValueType::DISCRETE,
340         gum::learning::DBTranslatedValueType::DISCRETE,
341         gum::learning::DBTranslatedValueType::DISCRETE};
342 
343       auto bn0 = gum::BayesNet< double >::fastPrototype("A;B;C;D");
344       bn0.cpt("A").fillWith({0.3, 0.7});
345       bn0.cpt("B").fillWith({0.3, 0.7});
346       bn0.cpt("C").fillWith({0.3, 0.7});
347       bn0.cpt("D").fillWith({0.3, 0.7});
348 
349       gum::LabelizedVariable var("x", "", 0);
350       var.addLabel("0");
351       var.addLabel("1");
352       const std::vector<std::string> miss {"N/A","?"};
353       gum::learning::DBTranslator4LabelizedVariable<> translator(var,miss);
354       gum::learning::DBTranslatorSet<> set;
355       for ( std::size_t i = std::size_t(0); i < std::size_t(4); ++i)
356         set.insertTranslator ( translator, i );
357 
358       set[0].setVariableName ( "v0" );
359       set[1].setVariableName ( "v1" );
360       set[2].setVariableName ( "v2" );
361       set[3].setVariableName ( "v3" );
362 
363       gum::learning::DatabaseTable<> database ( set );
364       std::vector<std::string> row1 { "0", "1", "1", "0" };
365       std::vector<std::string> row2 { "0", "?", "1", "0" };
366       std::vector<std::string> row3 { "0", "?", "?", "0" };
367       std::vector<std::string> row4 { "?", "?", "1", "0" };
368       std::vector<std::string> row5 { "?", "?", "?", "?" };
369       database.insertRow( row1 );
370       database.insertRow( row2 );
371       database.insertRow( row3 );
372       database.insertRow( row4 );
373       database.insertRow( row5 );
374 
375       auto handler = database.handler();
376 
377       gum::learning::DBRowGeneratorIdentity<> generator1(col_types);
378       gum::learning::DBRowGeneratorEM<>       generator2(col_types,bn0);
379       gum::learning::DBRowGenerator<>&        gen2 = generator2; // fix for g++-4.8
380       gum::learning::DBRowGeneratorIdentity<> generator3(col_types);
381       gum::learning::DBRowGeneratorEM<>       generator4(col_types,bn0);
382       gum::learning::DBRowGenerator<>&        gen4 = generator4; // fix for g++-4.8
383 
384       gum::learning::DBRowGeneratorSet<> genset;
385       genset.insertGenerator(generator1);
386       genset.insertGenerator(gen2);
387       genset.insertGenerator(generator3);
388       genset.insertGenerator(gen4);
389 
390       gum::learning::DBRowGeneratorParser<>
391         parser ( database.handler (), genset );
392 
393       auto bn = gum::BayesNet< double >::fastPrototype("A->B->C<-D");
394       bn.cpt("A").fillWith({0.3, 0.7});
395       bn.cpt("B").fillWith({0.4, 0.6, 0.7, 0.3});
396       bn.cpt("C").fillWith({0.2, 0.8, 0.3, 0.7, 0.4, 0.6, 0.5, 0.5});
397       bn.cpt("D").fillWith({0.8, 0.2});
398 
399       parser.setBayesNet(bn);
400 
401       const std::vector< std::size_t > cols_of_interest{std::size_t(0),
402                                                         std::size_t(1)};
403 
404       parser.setColumnsOfInterest(cols_of_interest);
405       TS_ASSERT(parser.hasRows())
406       {
407         const auto& row = parser.row();
408         const auto& xrow = row.row();
409 
410         TS_ASSERT_EQUALS(row.weight(), 1.0)
411         TS_ASSERT_EQUALS(xrow[0].discr_val, std::size_t(0))
412         TS_ASSERT_EQUALS(xrow[1].discr_val, std::size_t(1))
413       }
414 
415       for (int i = 0; i < 2; ++i ) {
416         ++handler;
417         TS_ASSERT(parser.hasRows())
418 
419         gum::Potential<double> proba =  _infer_(bn, {std::size_t(1)},
420                                                handler.row());
421         gum::Instantiation inst(proba);
422 
423         const auto& fill_row1  = parser.row();
424         const auto& xfill_row1 = fill_row1.row();
425         TS_ASSERT_EQUALS(xfill_row1[0].discr_val,std::size_t(0))
426         TS_ASSERT_EQUALS(xfill_row1[1].discr_val, std::size_t(0))
427         TS_ASSERT_DELTA(fill_row1.weight(), proba.get(inst),0.001)
428 
429         ++inst;
430         const auto& fill_row2  = parser.row();
431         const auto& xfill_row2 = fill_row2.row();
432         TS_ASSERT_EQUALS(xfill_row2[0].discr_val, std::size_t(0))
433         TS_ASSERT_EQUALS(xfill_row2[1].discr_val, std::size_t(1))
434         TS_ASSERT_DELTA(fill_row2.weight(), proba.get(inst), 0.001)
435       }
436 
437       for (int i = 0; i < 2; ++i ) {
438         ++handler;
439         TS_ASSERT(parser.hasRows())
440 
441         gum::Potential<double> proba =  _infer_(bn, {std::size_t(0),std::size_t(1)},
442                                                handler.row());
443 
444         std::vector<double> xproba (4);
445         std::vector<bool> observed(4, false);
446         std::size_t idx;
447        for (gum::Instantiation inst(proba); !inst.end(); ++inst) {
448          if ( proba.variablesSequence()[0]->name() == "A")
449            idx = inst.val(0) + std::size_t(2) * inst.val(1);
450          else
451            idx = inst.val(1) + std::size_t(2) * inst.val(0);
452           xproba[idx] = proba.get(inst);
453         }
454 
455         const auto& fill_row1  = parser.row();
456         const auto& xfill_row1 = fill_row1.row();
457         idx = xfill_row1[0].discr_val + std::size_t(2) * xfill_row1[1].discr_val;
458         observed[idx] = true;
459         TS_ASSERT_DELTA(fill_row1.weight(), xproba[idx], 0.001)
460 
461         const auto& fill_row2  = parser.row();
462         const auto& xfill_row2 = fill_row2.row();
463         idx = xfill_row2[0].discr_val + std::size_t(2) * xfill_row2[1].discr_val;
464         observed[idx] = true;
465         TS_ASSERT_DELTA(fill_row2.weight(), xproba[idx], 0.001)
466 
467         const auto& fill_row3  = parser.row();
468         const auto& xfill_row3 = fill_row3.row();
469         idx = xfill_row3[0].discr_val + std::size_t(2) * xfill_row3[1].discr_val;
470         observed[idx] = true;
471         TS_ASSERT_DELTA(fill_row3.weight(), xproba[idx],0.001)
472 
473         const auto& fill_row4  = parser.row();
474         const auto& xfill_row4 = fill_row4.row();
475         idx = xfill_row4[0].discr_val + std::size_t(2) * xfill_row4[1].discr_val;
476         observed[idx] = true;
477         TS_ASSERT_DELTA(fill_row4.weight(), xproba[idx],0.001)
478 
479         int nb_observed = 0;
480         for ( auto obs : observed)
481           if (obs) ++nb_observed;
482         TS_ASSERT_EQUALS(nb_observed, 4)
483       }
484 
485     }
486 
487   };
488 
489 
490 } /* namespace gum_tests */
491