1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 #include <gumtest/AgrumTestSuite.h> 23 #include <gumtest/testsuite_utils.h> 24 #include <iostream> 25 #include <sstream> 26 27 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 28 #include <agrum/tools/database/DBTranslatorSet.h> 29 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 30 #include <agrum/BN/learning/scores_and_tests/scoreK2.h> 31 #include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h> 32 33 #include <agrum/BN/learning/constraints/structuralConstraintDAG.h> 34 #include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h> 35 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h> 36 37 namespace gum_tests { 38 39 class GraphChangesSelector4DiGraphTestSuite: public CxxTest::TestSuite { 40 private: _order_nodes_(const std::vector<std::vector<double>> & all_scores,const std::vector<gum::NodeId> & best_nodes,std::vector<std::pair<gum::NodeId,double>> & sorted_nodes)41 void _order_nodes_(const std::vector< std::vector< double > >& all_scores, 42 const std::vector< gum::NodeId >& best_nodes, 43 std::vector< std::pair< gum::NodeId, double > >& sorted_nodes) { 44 const std::size_t size = best_nodes.size(); 45 for (std::size_t i = std::size_t(0); i < size; ++i) { 46 sorted_nodes[i].first = gum::NodeId(i); 47 sorted_nodes[i].second = all_scores[i][best_nodes[i]]; 48 } 49 50 std::sort( 51 sorted_nodes.begin(), 52 sorted_nodes.end(), 53 [](const std::pair< gum::NodeId, double >& a, 54 const std::pair< gum::NodeId, double >& b) -> bool { return a.second > b.second; }); 55 } 56 _compute_scores_(gum::learning::ScoreK2<> & score,const gum::DAG & graph,std::vector<std::vector<double>> & all_scores,std::vector<gum::NodeId> & best_nodes,gum::NodeId & best_node)57 void _compute_scores_(gum::learning::ScoreK2<>& score, 58 const gum::DAG& graph, 59 std::vector< std::vector< double > >& all_scores, 60 std::vector< gum::NodeId >& best_nodes, 61 gum::NodeId& best_node) { 62 const std::size_t size = best_nodes.size(); 63 64 for (std::size_t i = 0; i < size; ++i) { 65 for (std::size_t j = 0; j < size; ++j) { 66 const auto& parents = graph.parents(i); 67 if (i != j) { 68 std::vector< gum::NodeId > pars; 69 for (const auto par: parents) 70 pars.push_back(par); 71 all_scores[i][j] = -score.score(i, pars); 72 73 if (!parents.exists(j)) { 74 pars.push_back(gum::NodeId(j)); 75 all_scores[i][j] += score.score(i, pars); 76 } else { 77 for (auto& par: pars) { 78 if (par == gum::NodeId(j)) { 79 par = *(pars.rbegin()); 80 pars.pop_back(); 81 break; 82 } 83 } 84 all_scores[i][j] += score.score(i, pars); 85 } 86 } else { 87 all_scores[i][j] = std::numeric_limits< double >::lowest(); 88 } 89 } 90 } 91 92 double best_xscore = std::numeric_limits< double >::lowest(); 93 best_node = 0; 94 for (std::size_t i = 0; i < size; ++i) { 95 double best = all_scores[i][0]; 96 best_nodes[i] = 0; 97 for (std::size_t j = 1; j < size; ++j) { 98 if (all_scores[i][j] > best) { 99 best = all_scores[i][j]; 100 best_nodes[i] = j; 101 if (best_xscore < best) { 102 best_xscore = best; 103 best_node = i; 104 } 105 } 106 } 107 } 108 } 109 110 public: test_K2()111 void test_K2() { 112 // create the translator set 113 gum::LabelizedVariable var("X1", "", 0); 114 var.addLabel("0"); 115 var.addLabel("1"); 116 var.addLabel("2"); 117 118 gum::learning::DBTranslatorSet<> trans_set; 119 { 120 const std::vector< std::string > miss; 121 gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss); 122 std::vector< std::string > names{"A", "B", "C", "D", "E", "F"}; 123 124 for (std::size_t i = std::size_t(0); i < names.size(); ++i) { 125 translator.setVariableName(names[i]); 126 trans_set.insertTranslator(translator, i); 127 } 128 } 129 130 // create the database 131 gum::learning::DatabaseTable<> database(trans_set); 132 std::vector< std::string > row0{"0", "1", "0", "2", "1", "1"}; 133 std::vector< std::string > row1{"1", "2", "0", "1", "2", "2"}; 134 std::vector< std::string > row2{"2", "1", "0", "1", "1", "0"}; 135 std::vector< std::string > row3{"1", "0", "0", "0", "0", "0"}; 136 std::vector< std::string > row4{"0", "0", "0", "1", "1", "1"}; 137 for (int i = 0; i < 1000; ++i) 138 database.insertRow(row0); 139 for (int i = 0; i < 50; ++i) 140 database.insertRow(row1); 141 for (int i = 0; i < 75; ++i) 142 database.insertRow(row2); 143 for (int i = 0; i < 75; ++i) 144 database.insertRow(row3); 145 for (int i = 0; i < 200; ++i) 146 database.insertRow(row4); 147 148 // create the parser 149 gum::learning::DBRowGeneratorSet<> genset; 150 gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset); 151 152 gum::learning::AprioriSmoothing<> apriori(database); 153 gum::learning::ScoreK2<> score(parser, apriori); 154 155 gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintDiGraph > 156 struct_constraint; 157 158 gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) > op_set( 159 struct_constraint); 160 161 gum::learning::GraphChangesSelector4DiGraph< decltype(struct_constraint), decltype(op_set) > 162 selector(score, struct_constraint, op_set); 163 164 gum::DAG graph; 165 selector.setGraph(graph); 166 167 TS_ASSERT(!selector.empty()) 168 for (const auto node: graph) { 169 TS_ASSERT(!selector.empty(node)) 170 } 171 172 selector.setGraph(graph); 173 174 TS_ASSERT(!selector.empty()) 175 for (const auto node: graph) { 176 TS_ASSERT(!selector.empty(node)) 177 } 178 179 gum::learning::GraphChange change(gum::learning::GraphChangeType::ARC_DELETION, 0, 1); 180 TS_ASSERT(!selector.isChangeValid(change)) 181 182 for (const auto node: graph) { 183 const auto& change = selector.bestChange(node); 184 TS_ASSERT_EQUALS(change.type(), gum::learning::GraphChangeType::ARC_ADDITION) 185 } 186 TS_ASSERT_EQUALS(selector.bestChange().type(), gum::learning::GraphChangeType::ARC_ADDITION) 187 188 std::vector< std::vector< double > > all_scores(6, std::vector< double >(6)); 189 std::vector< gum::NodeId > best_nodes(6); 190 gum::NodeId best_node; 191 _compute_scores_(score, graph, all_scores, best_nodes, best_node); 192 193 for (const auto node: graph) { 194 const auto& change = selector.bestChange(node); 195 TS_ASSERT_EQUALS(change.type(), gum::learning::GraphChangeType::ARC_ADDITION) 196 if (change.node1() == node) { 197 TS_ASSERT_EQUALS(change.node2(), best_nodes[node]) 198 } else { 199 TS_ASSERT_EQUALS(change.node1(), best_nodes[node]) 200 } 201 } 202 203 const double best_score = selector.bestScore(); 204 gum::NodeProperty< double > scores; 205 for (const auto node: graph) { 206 const double sc = selector.bestScore(node); 207 scores.insert(node, sc); 208 TS_ASSERT(sc <= best_score) 209 TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]]) 210 } 211 TS_ASSERT_EQUALS(best_score, all_scores[best_node][best_nodes[best_node]]) 212 213 gum::learning::GraphChange change2(gum::learning::GraphChangeType::ARC_ADDITION, 3, 1); 214 graph.addArc(change2.node1(), change2.node2()); 215 selector.applyChangeWithoutScoreUpdate(change2); 216 selector.updateScoresAfterAppliedChanges(); 217 218 _compute_scores_(score, graph, all_scores, best_nodes, best_node); 219 220 for (const auto node: graph) { 221 const double sc = selector.bestScore(node); 222 TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]]) 223 if (node != 1) { 224 TS_ASSERT_EQUALS(sc, scores[node]) 225 } else { 226 TS_ASSERT_DIFFERS(sc, scores[node]) 227 } 228 } 229 230 scores[1] = selector.bestScore(1); 231 scores[3] = selector.bestScore(3); 232 gum::learning::GraphChange change3(gum::learning::GraphChangeType::ARC_ADDITION, 3, 2); 233 graph.addArc(change3.node1(), change3.node2()); 234 selector.applyChange(change3); 235 236 _compute_scores_(score, graph, all_scores, best_nodes, best_node); 237 238 for (const auto node: graph) { 239 const double sc = selector.bestScore(node); 240 TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]]) 241 if ((node != 2)) { 242 TS_ASSERT_EQUALS(selector.bestScore(node), scores[node]) 243 } else { 244 TS_ASSERT_DIFFERS(selector.bestScore(node), scores[node]) 245 } 246 } 247 248 scores[2] = selector.bestScore(2); 249 scores[3] = selector.bestScore(3); 250 gum::learning::GraphChange change4(gum::learning::GraphChangeType::ARC_DELETION, 3, 1); 251 graph.eraseArc(gum::Arc(change4.node1(), change4.node2())); 252 selector.applyChange(change4); 253 254 _compute_scores_(score, graph, all_scores, best_nodes, best_node); 255 256 for (const auto node: graph) { 257 const double sc = selector.bestScore(node); 258 TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]]) 259 if ((node != 1)) { 260 TS_ASSERT_EQUALS(selector.bestScore(node), scores[node]) 261 } else { 262 TS_ASSERT_DIFFERS(selector.bestScore(node), scores[node]) 263 } 264 } 265 266 267 const auto xnodes = selector.nodesSortedByBestScore(); 268 TS_ASSERT_EQUALS(xnodes.size(), std::size_t(6)) 269 270 std::vector< std::pair< gum::NodeId, double > > sorted_nodes(6); 271 _order_nodes_(all_scores, best_nodes, sorted_nodes); 272 for (std::size_t i = 0; i < 6; ++i) { 273 TS_ASSERT_EQUALS(xnodes[i], sorted_nodes[i]) 274 } 275 } 276 }; 277 278 279 } /* namespace gum_tests */ 280