/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


#include <gumtest/AgrumTestSuite.h>
#include <gumtest/testsuite_utils.h>
#include <iostream>

#include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
#include <agrum/tools/database/DBTranslatorSet.h>
#include <agrum/BN/learning/aprioris/aprioriNoApriori.h>
#include <agrum/BN/learning/aprioris/aprioriSmoothing.h>
#include <agrum/BN/learning/paramUtils/paramEstimatorML.h>
#include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
#include <agrum/tools/database/DBRowGenerator4CompleteRows.h>
#include <agrum/tools/database/DBRowGeneratorEM.h>

namespace gum_tests {

  class DAG2BNLearnerTestSuite: public CxxTest::TestSuite {
    private:
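    // returns a copy of vin normalized so that its elements sum to 1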
    std::vector< double > _normalize_(const std::vector< double >& vin) {
      double sum = 0;
      for (const auto& val: vin)
        sum += val;
      std::vector< double > vout(vin);
      for (auto& val: vout)
        val /= sum;
      return vout;
    }

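    // returns a copy of vin in which each block of 3 consecutive values has
    // been normalized to sum to 1 (used to build expected conditional
    // distributions)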
    std::vector< double > _xnormalize_(const std::vector< double >& vin) {
      std::vector< double > vout(vin);
      for (std::size_t i = 0; i < vin.size(); i += 3) {
        double sum = 0;
        for (std::size_t j = std::size_t(0); j < 3; ++j)
          sum += vin[i + j];
        for (std::size_t j = std::size_t(0); j < 3; ++j)
          vout[i + j] /= sum;
      }
      return vout;
    }

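    // returns the content of the CPT of node id in bn, flattened into a vector
    // (one entry per instantiation of the potential)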
    std::vector< double > _getProba_(const gum::BayesNet< double >& bn, const gum::NodeId id) {
      const gum::Potential< double >& pot = bn.cpt(id);
      std::vector< double >           vect;
      for (gum::Instantiation inst(pot); !inst.end(); ++inst) {
        vect.push_back(pot.get(inst));
      }
      return vect;
    }

    public:
    void test1() {
      // create the translator set
      gum::LabelizedVariable var("X1", "", 0);
      var.addLabel("0");
      var.addLabel("1");
      var.addLabel("2");

      gum::learning::DBTranslatorSet<> trans_set;
      {
        const std::vector< std::string >                miss;
        gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss);
        std::vector< std::string >                      names{"A", "B", "C", "D", "E", "F"};

        for (std::size_t i = std::size_t(0); i < names.size(); ++i) {
          translator.setVariableName(names[i]);
          trans_set.insertTranslator(translator, i);
        }
      }

      // create the database
      gum::learning::DatabaseTable<> database(trans_set);
      std::vector< std::string >     row0{"0", "1", "0", "2", "1", "1"};
      std::vector< std::string >     row1{"1", "2", "0", "1", "2", "2"};
      std::vector< std::string >     row2{"2", "1", "0", "1", "1", "0"};
      std::vector< std::string >     row3{"1", "0", "0", "0", "0", "0"};
      std::vector< std::string >     row4{"0", "0", "0", "1", "1", "1"};
      for (int i = 0; i < 1000; ++i)
        database.insertRow(row0);
      for (int i = 0; i < 50; ++i)
        database.insertRow(row1);
      for (int i = 0; i < 75; ++i)
        database.insertRow(row2);
      for (int i = 0; i < 75; ++i)
        database.insertRow(row3);
      for (int i = 0; i < 200; ++i)
        database.insertRow(row4);

      // create the parser
      gum::learning::DBRowGeneratorSet<>    genset;
      gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
      gum::learning::AprioriSmoothing<>     extern_apriori(database);
      gum::learning::AprioriNoApriori<>     intern_apriori(database);

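      // create an ML parameter estimator relying on an external smoothing
      // apriori and no internal apriori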
      gum::learning::ParamEstimatorML<> param_estimator(parser, extern_apriori, intern_apriori);

      gum::learning::DAG2BNLearner<> learner;

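      // create the DAG to learn: 2 -> 0 -> 1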
      gum::DAG dag;
      for (std::size_t i = std::size_t(0); i < database.nbVariables(); ++i) {
        dag.addNodeWithId(gum::NodeId(i));
      }
      dag.addArc(0, 1);
      dag.addArc(2, 0);

      auto bn1 = learner.createBN(param_estimator, dag);

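      // the expected CPTs are the counts observed in the database plus the
      // smoothing weight (1), normalized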
      auto                  v2  = _getProba_(bn1, 2);
      std::vector< double > xv2 = _normalize_({1401, 1, 1});
      TS_ASSERT_EQUALS(v2, xv2)

      auto                  v02  = _getProba_(bn1, 0);
      std::vector< double > xv02 = _xnormalize_({1201, 126, 76, 1, 1, 1, 1, 1, 1});
      TS_ASSERT_EQUALS(v02, xv02)
    }


    void testEM() {
      gum::LabelizedVariable var("x", "", 0);
      var.addLabel("0");
      var.addLabel("1");
      gum::learning::DBTranslatorSet<> trans_set;
      {
        const std::vector< std::string >                miss{"N/A", "?"};
        gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss);
        std::vector< std::string >                      names{"A", "B", "C", "D"};

        for (std::size_t i = std::size_t(0); i < names.size(); ++i) {
          translator.setVariableName(names[i]);
          trans_set.insertTranslator(translator, i);
        }
      }

      gum::learning::DatabaseTable<> database(trans_set);
      std::vector< std::string >     row1{"0", "1", "1", "0"};
      std::vector< std::string >     row2{"0", "?", "0", "1"};
      std::vector< std::string >     row3{"1", "?", "?", "0"};
      std::vector< std::string >     row4{"?", "?", "1", "0"};
      std::vector< std::string >     row5{"?", "0", "?", "?"};
      for (int i = 0; i < 100; ++i) {
        database.insertRow(row1);
        database.insertRow(row2);
        database.insertRow(row3);
        database.insertRow(row4);
        database.insertRow(row5);
      }

      const std::vector< gum::learning::DBTranslatedValueType > col_types{
         gum::learning::DBTranslatedValueType::DISCRETE,
         gum::learning::DBTranslatedValueType::DISCRETE,
         gum::learning::DBTranslatedValueType::DISCRETE,
         gum::learning::DBTranslatedValueType::DISCRETE};

      auto bn = gum::BayesNet< double >::fastPrototype("A;B;C;D");
      bn.cpt("A").fillWith({0.3, 0.7});
      bn.cpt("B").fillWith({0.3, 0.7});
      bn.cpt("C").fillWith({0.3, 0.7});
      bn.cpt("D").fillWith({0.3, 0.7});

      // bugfix for parallel execution of VariableElimination
      {
        const gum::DAG& dag = bn.dag();
        for (const auto node: dag) {
          dag.parents(node);
          dag.children(node);
        }
      }

      // create the parser
      gum::learning::DBRowGenerator4CompleteRows<> generator_id(col_types);
      gum::learning::DBRowGeneratorSet<>           genset_id;
      genset_id.insertGenerator(generator_id);
      gum::learning::DBRowGeneratorParser<> parser_id(database.handler(), genset_id);

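      // first estimator: maximum likelihood over the complete rows only (the
      // DBRowGenerator4CompleteRows generator discards rows with missing
      // values); it provides the starting point for EM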
      gum::learning::AprioriSmoothing<> extern_apriori(database);
      gum::learning::AprioriNoApriori<> intern_apriori(database);
      gum::learning::ParamEstimatorML<> param_estimator_id(parser_id,
                                                           extern_apriori,
                                                           intern_apriori);

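      // second estimator: its parser fills in the missing values using the
      // current Bayesian network (EM steps)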
      gum::learning::DBRowGeneratorEM<>  generator_EM(col_types, bn);
      gum::learning::DBRowGenerator<>&   gen_EM = generator_EM;   // fix for g++-4.8
      gum::learning::DBRowGeneratorSet<> genset_EM;
      genset_EM.insertGenerator(gen_EM);
      gum::learning::DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM);
      gum::learning::ParamEstimatorML<>     param_estimator_EM(parser_EM,
                                                               extern_apriori,
                                                               intern_apriori);

      gum::learning::DAG2BNLearner<> learner;

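      // create the DAG to learn: 3 -> 2 -> 1 -> 0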
      gum::DAG dag;
      for (std::size_t i = std::size_t(0); i < database.nbVariables(); ++i) {
        dag.addNodeWithId(gum::NodeId(i));
      }
      dag.addArc(gum::NodeId(1), gum::NodeId(0));
      dag.addArc(gum::NodeId(2), gum::NodeId(1));
      dag.addArc(gum::NodeId(3), gum::NodeId(2));

      learner.setEpsilon(1e-3);
      bool ok;
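      // the EM result depends on its starting point, so the learning is
      // repeated up to 10 times; the test succeeds as soon as the estimated
      // CPTs fall within the expected ranges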
      for (int i = 0; i < 10; i++) {
        ok       = true;
        auto bn1 = learner.createBN(param_estimator_id, param_estimator_EM, dag);
        auto margB
           = (bn1.cpt("D") * bn1.cpt("C") * bn1.cpt("B"))
                .margSumIn(gum::Set< const gum::DiscreteVariable* >({&bn1.variableFromName("B")}));
        if ((bn1.cpt("D").max() < 0.8) && (bn1.cpt("D").max() > 0.6) && (margB.max() > 0.5)
            && (margB.max() < 0.6))
          break;
        ok = false;
      }
      TS_ASSERT(ok)
    }
  };

}   // namespace gum_tests