1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 #include <gumtest/AgrumTestSuite.h> 23 #include <gumtest/testsuite_utils.h> 24 #include <iostream> 25 26 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 27 #include <agrum/tools/database/DBTranslatorSet.h> 28 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h> 29 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 30 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h> 31 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h> 32 #include <agrum/tools/database/DBRowGenerator4CompleteRows.h> 33 #include <agrum/tools/database/DBRowGeneratorEM.h> 34 35 namespace gum_tests { 36 37 class DAG2BNLearnerTestSuite: public CxxTest::TestSuite { 38 private: _normalize_(const std::vector<double> & vin)39 std::vector< double > _normalize_(const std::vector< double >& vin) { 40 double sum = 0; 41 for (const auto& val: vin) 42 sum += val; 43 std::vector< double > vout(vin); 44 for (auto& val: vout) 45 val /= sum; 46 return vout; 47 } 48 _xnormalize_(const std::vector<double> & vin)49 std::vector< double > _xnormalize_(const std::vector< double >& vin) { 50 std::vector< double > vout(vin); 51 for (std::size_t i = 0; i < vin.size(); i += 3) { 52 double sum = 0; 53 for (std::size_t j = std::size_t(0); j < 3; ++j) 54 sum += vin[i + j]; 55 for (std::size_t j = std::size_t(0); j < 3; ++j) 56 vout[i + j] /= sum; 57 } 58 return vout; 59 } 60 _getProba_(const gum::BayesNet<double> & bn,const gum::NodeId id)61 std::vector< double > _getProba_(const gum::BayesNet< double >& bn, const gum::NodeId id) { 62 const gum::Potential< double >& pot = bn.cpt(id); 63 std::vector< double > vect; 64 for (gum::Instantiation inst(pot); !inst.end(); ++inst) { 65 vect.push_back(pot.get(inst)); 66 } 67 return vect; 68 } 69 70 public: test1()71 void test1() { 72 // create the translator set 73 gum::LabelizedVariable var("X1", "", 0); 74 var.addLabel("0"); 75 var.addLabel("1"); 76 var.addLabel("2"); 77 78 gum::learning::DBTranslatorSet<> trans_set; 79 { 80 const std::vector< std::string > miss; 81 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 82 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 83 84 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 85 translator.setVariableName(names[i]); 86 trans_set.insertTranslator(translator, i); 87 } 88 } 89 90 // create the database 91 gum::learning::DatabaseTable<> database(trans_set); 92 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 93 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 94 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 95 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 96 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 97 for (int i = 0; i < 1000; ++i) 98 database.insertRow(row0); 99 for (int i = 0; i < 50; ++i) 100 database.insertRow(row1); 101 for (int i = 0; i < 75; ++i) 102 database.insertRow(row2); 103 for (int i = 0; i < 75; ++i) 104 database.insertRow(row3); 105 for (int i = 0; i < 200; ++i) 106 database.insertRow(row4); 107 108 // create the parser 109 gum::learning::DBRowGeneratorSet<> genset; 110 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 111 gum::learning::AprioriSmoothing<> extern_apriori(database); 112 gum::learning::AprioriNoApriori<> intern_apriori(database); 113 114 gum::learning::ParamEstimatorML<> param_estimator(parser, extern_apriori, intern_apriori); 115 116 gum::learning::DAG2BNLearner<> learner; 117 118 gum::DAG dag; 119 for (std::size_t i = std::size_t(0); i < database.nbVariables(); ++i) { 120 dag.addNodeWithId(gum::NodeId(i)); 121 } 122 dag.addArc(0, 1); 123 dag.addArc(2, 0); 124 125 auto bn1 = learner.createBN(param_estimator, dag); 126 127 auto v2 = _getProba_(bn1, 2); 128 std::vector< double > xv2 = _normalize_({1401, 1, 1}); 129 TS_ASSERT_EQUALS(v2, xv2) 130 131 auto v02 = _getProba_(bn1, 0); 132 std::vector< double > xv02 = _xnormalize_({1201, 126, 76, 1, 1, 1, 1, 1, 1}); 133 TS_ASSERT_EQUALS(v02, xv02) 134 } 135 136 testEM()137 void testEM() { 138 gum::LabelizedVariable var("x", "", 0); 139 var.addLabel("0"); 140 var.addLabel("1"); 141 gum::learning::DBTranslatorSet<> trans_set; 142 { 143 const std::vector< std::string > miss{"N/A", "?"}; 144 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 145 std::vector< std::string > names{"A", "B", "C", "D"}; 146 147 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 148 translator.setVariableName(names[i]); 149 trans_set.insertTranslator(translator, i); 150 } 151 } 152 153 gum::learning::DatabaseTable<> database(trans_set); 154 std::vector< std::string > row1{"0", "1", "1", "0"}; 155 std::vector< std::string > row2{"0", "?", "0", "1"}; 156 std::vector< std::string > row3{"1", "?", "?", "0"}; 157 std::vector< std::string > row4{"?", "?", "1", "0"}; 158 std::vector< std::string > row5{"?", "0", "?", "?"}; 159 for (int i = 0; i < 100; ++i) { 160 database.insertRow(row1); 161 database.insertRow(row2); 162 database.insertRow(row3); 163 database.insertRow(row4); 164 database.insertRow(row5); 165 } 166 167 const std::vector< gum::learning::DBTranslatedValueType > col_types{ 168 gum::learning::DBTranslatedValueType::DISCRETE, 169 gum::learning::DBTranslatedValueType::DISCRETE, 170 gum::learning::DBTranslatedValueType::DISCRETE, 171 gum::learning::DBTranslatedValueType::DISCRETE}; 172 173 auto bn = gum::BayesNet< double >::fastPrototype("A;B;C;D"); 174 bn.cpt("A").fillWith({0.3, 0.7}); 175 bn.cpt("B").fillWith({0.3, 0.7}); 176 bn.cpt("C").fillWith({0.3, 0.7}); 177 bn.cpt("D").fillWith({0.3, 0.7}); 178 179 // bugfix for parallel exceution of VariableElimination 180 { 181 const gum::DAG& dag = bn.dag(); 182 for (const auto node: dag) { 183 dag.parents(node); 184 dag.children(node); 185 } 186 } 187 188 // create the parser 189 gum::learning::DBRowGenerator4CompleteRows<> generator_id(col_types); 190 gum::learning::DBRowGeneratorSet<> genset_id; 191 genset_id.insertGenerator(generator_id); 192 gum::learning::DBRowGeneratorParser<> parser_id(database.handler(), genset_id); 193 194 gum::learning::AprioriSmoothing<> extern_apriori(database); 195 gum::learning::AprioriNoApriori<> intern_apriori(database); 196 gum::learning::ParamEstimatorML<> param_estimator_id(parser_id, 197 extern_apriori, 198 intern_apriori); 199 200 gum::learning::DBRowGeneratorEM<> generator_EM(col_types, bn); 201 gum::learning::DBRowGenerator<>& gen_EM = generator_EM; // fix for g++-4.8 202 gum::learning::DBRowGeneratorSet<> genset_EM; 203 genset_EM.insertGenerator(gen_EM); 204 gum::learning::DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM); 205 gum::learning::ParamEstimatorML<> param_estimator_EM(parser_EM, 206 extern_apriori, 207 intern_apriori); 208 209 gum::learning::DAG2BNLearner<> learner; 210 211 gum::DAG dag; 212 for (std::size_t i = std::size_t(0); i < database.nbVariables(); ++i) { 213 dag.addNodeWithId(gum::NodeId(i)); 214 } 215 dag.addArc(gum::NodeId(1), gum::NodeId(0)); 216 dag.addArc(gum::NodeId(2), gum::NodeId(1)); 217 dag.addArc(gum::NodeId(3), gum::NodeId(2)); 218 219 learner.setEpsilon(1e-3); 220 bool ok; 221 for (int i = 0; i < 10; i++) { 222 ok = true; 223 auto bn1 = learner.createBN(param_estimator_id, param_estimator_EM, dag); 224 auto margB 225 = (bn1.cpt("D") * bn1.cpt("C") * bn1.cpt("B")) 226 .margSumIn(gum::Set< const gum::DiscreteVariable* >({&bn1.variableFromName("B")})); 227 if ((bn1.cpt("D").max() < 0.8) && (bn1.cpt("D").max() > 0.6) && (margB.max() > 0.5) 228 && (margB.max() < 0.6)) 229 break; 230 ok = false; 231 } 232 TS_ASSERT(ok) 233 } 234 }; 235 236 } // namespace gum_tests 237