1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
#include <gumtest/AgrumTestSuite.h>
#include <gumtest/testsuite_utils.h>
#include <algorithm>
#include <iostream>
#include <limits>
#include <sstream>
#include <utility>
#include <vector>

#include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
#include <agrum/tools/database/DBTranslatorSet.h>
#include <agrum/BN/learning/aprioris/aprioriSmoothing.h>
#include <agrum/BN/learning/scores_and_tests/scoreK2.h>
#include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h>

#include <agrum/BN/learning/constraints/structuralConstraintDAG.h>
#include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h>
36 
37 namespace gum_tests {
38 
39   class GraphChangesSelector4DiGraphTestSuite: public CxxTest::TestSuite {
40     private:
_order_nodes_(const std::vector<std::vector<double>> & all_scores,const std::vector<gum::NodeId> & best_nodes,std::vector<std::pair<gum::NodeId,double>> & sorted_nodes)41     void _order_nodes_(const std::vector< std::vector< double > >&      all_scores,
42                        const std::vector< gum::NodeId >&                best_nodes,
43                        std::vector< std::pair< gum::NodeId, double > >& sorted_nodes) {
44       const std::size_t size = best_nodes.size();
45       for (std::size_t i = std::size_t(0); i < size; ++i) {
46         sorted_nodes[i].first  = gum::NodeId(i);
47         sorted_nodes[i].second = all_scores[i][best_nodes[i]];
48       }
49 
50       std::sort(
51          sorted_nodes.begin(),
52          sorted_nodes.end(),
53          [](const std::pair< gum::NodeId, double >& a,
54             const std::pair< gum::NodeId, double >& b) -> bool { return a.second > b.second; });
55     }
56 
_compute_scores_(gum::learning::ScoreK2<> & score,const gum::DAG & graph,std::vector<std::vector<double>> & all_scores,std::vector<gum::NodeId> & best_nodes,gum::NodeId & best_node)57     void _compute_scores_(gum::learning::ScoreK2<>&             score,
58                           const gum::DAG&                       graph,
59                           std::vector< std::vector< double > >& all_scores,
60                           std::vector< gum::NodeId >&           best_nodes,
61                           gum::NodeId&                          best_node) {
62       const std::size_t size = best_nodes.size();
63 
64       for (std::size_t i = 0; i < size; ++i) {
65         for (std::size_t j = 0; j < size; ++j) {
66           const auto& parents = graph.parents(i);
67           if (i != j) {
68             std::vector< gum::NodeId > pars;
69             for (const auto par: parents)
70               pars.push_back(par);
71             all_scores[i][j] = -score.score(i, pars);
72 
73             if (!parents.exists(j)) {
74               pars.push_back(gum::NodeId(j));
75               all_scores[i][j] += score.score(i, pars);
76             } else {
77               for (auto& par: pars) {
78                 if (par == gum::NodeId(j)) {
79                   par = *(pars.rbegin());
80                   pars.pop_back();
81                   break;
82                 }
83               }
84               all_scores[i][j] += score.score(i, pars);
85             }
86           } else {
87             all_scores[i][j] = std::numeric_limits< double >::lowest();
88           }
89         }
90       }
91 
92       double best_xscore = std::numeric_limits< double >::lowest();
93       best_node          = 0;
94       for (std::size_t i = 0; i < size; ++i) {
95         double best   = all_scores[i][0];
96         best_nodes[i] = 0;
97         for (std::size_t j = 1; j < size; ++j) {
98           if (all_scores[i][j] > best) {
99             best          = all_scores[i][j];
100             best_nodes[i] = j;
101             if (best_xscore < best) {
102               best_xscore = best;
103               best_node   = i;
104             }
105           }
106         }
107       }
108     }
109 
110     public:
test_K2()111     void test_K2() {
112       // create the translator set
113       gum::LabelizedVariable var("X1", "", 0);
114       var.addLabel("0");
115       var.addLabel("1");
116       var.addLabel("2");
117 
118       gum::learning::DBTranslatorSet<> trans_set;
119       {
120         const std::vector< std::string >                miss;
121         gum::learning::DBTranslator4LabelizedVariable<> translator(var, miss);
122         std::vector< std::string >                      names{"A", "B", "C", "D", "E", "F"};
123 
124         for (std::size_t i = std::size_t(0); i < names.size(); ++i) {
125           translator.setVariableName(names[i]);
126           trans_set.insertTranslator(translator, i);
127         }
128       }
129 
130       // create the database
131       gum::learning::DatabaseTable<> database(trans_set);
132       std::vector< std::string >     row0{"0", "1", "0", "2", "1", "1"};
133       std::vector< std::string >     row1{"1", "2", "0", "1", "2", "2"};
134       std::vector< std::string >     row2{"2", "1", "0", "1", "1", "0"};
135       std::vector< std::string >     row3{"1", "0", "0", "0", "0", "0"};
136       std::vector< std::string >     row4{"0", "0", "0", "1", "1", "1"};
137       for (int i = 0; i < 1000; ++i)
138         database.insertRow(row0);
139       for (int i = 0; i < 50; ++i)
140         database.insertRow(row1);
141       for (int i = 0; i < 75; ++i)
142         database.insertRow(row2);
143       for (int i = 0; i < 75; ++i)
144         database.insertRow(row3);
145       for (int i = 0; i < 200; ++i)
146         database.insertRow(row4);
147 
148       // create the parser
149       gum::learning::DBRowGeneratorSet<>    genset;
150       gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
151 
152       gum::learning::AprioriSmoothing<> apriori(database);
153       gum::learning::ScoreK2<>          score(parser, apriori);
154 
155       gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintDiGraph >
156          struct_constraint;
157 
158       gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) > op_set(
159          struct_constraint);
160 
161       gum::learning::GraphChangesSelector4DiGraph< decltype(struct_constraint), decltype(op_set) >
162          selector(score, struct_constraint, op_set);
163 
164       gum::DAG graph;
165       selector.setGraph(graph);
166 
167       TS_ASSERT(!selector.empty())
168       for (const auto node: graph) {
169         TS_ASSERT(!selector.empty(node))
170       }
171 
172       selector.setGraph(graph);
173 
174       TS_ASSERT(!selector.empty())
175       for (const auto node: graph) {
176         TS_ASSERT(!selector.empty(node))
177       }
178 
179       gum::learning::GraphChange change(gum::learning::GraphChangeType::ARC_DELETION, 0, 1);
180       TS_ASSERT(!selector.isChangeValid(change))
181 
182       for (const auto node: graph) {
183         const auto& change = selector.bestChange(node);
184         TS_ASSERT_EQUALS(change.type(), gum::learning::GraphChangeType::ARC_ADDITION)
185       }
186       TS_ASSERT_EQUALS(selector.bestChange().type(), gum::learning::GraphChangeType::ARC_ADDITION)
187 
188       std::vector< std::vector< double > > all_scores(6, std::vector< double >(6));
189       std::vector< gum::NodeId >           best_nodes(6);
190       gum::NodeId                          best_node;
191       _compute_scores_(score, graph, all_scores, best_nodes, best_node);
192 
193       for (const auto node: graph) {
194         const auto& change = selector.bestChange(node);
195         TS_ASSERT_EQUALS(change.type(), gum::learning::GraphChangeType::ARC_ADDITION)
196         if (change.node1() == node) {
197           TS_ASSERT_EQUALS(change.node2(), best_nodes[node])
198         } else {
199           TS_ASSERT_EQUALS(change.node1(), best_nodes[node])
200         }
201       }
202 
203       const double                best_score = selector.bestScore();
204       gum::NodeProperty< double > scores;
205       for (const auto node: graph) {
206         const double sc = selector.bestScore(node);
207         scores.insert(node, sc);
208         TS_ASSERT(sc <= best_score)
209         TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]])
210       }
211       TS_ASSERT_EQUALS(best_score, all_scores[best_node][best_nodes[best_node]])
212 
213       gum::learning::GraphChange change2(gum::learning::GraphChangeType::ARC_ADDITION, 3, 1);
214       graph.addArc(change2.node1(), change2.node2());
215       selector.applyChangeWithoutScoreUpdate(change2);
216       selector.updateScoresAfterAppliedChanges();
217 
218       _compute_scores_(score, graph, all_scores, best_nodes, best_node);
219 
220       for (const auto node: graph) {
221         const double sc = selector.bestScore(node);
222         TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]])
223         if (node != 1) {
224           TS_ASSERT_EQUALS(sc, scores[node])
225         } else {
226           TS_ASSERT_DIFFERS(sc, scores[node])
227         }
228       }
229 
230       scores[1] = selector.bestScore(1);
231       scores[3] = selector.bestScore(3);
232       gum::learning::GraphChange change3(gum::learning::GraphChangeType::ARC_ADDITION, 3, 2);
233       graph.addArc(change3.node1(), change3.node2());
234       selector.applyChange(change3);
235 
236       _compute_scores_(score, graph, all_scores, best_nodes, best_node);
237 
238       for (const auto node: graph) {
239         const double sc = selector.bestScore(node);
240         TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]])
241         if ((node != 2)) {
242           TS_ASSERT_EQUALS(selector.bestScore(node), scores[node])
243         } else {
244           TS_ASSERT_DIFFERS(selector.bestScore(node), scores[node])
245         }
246       }
247 
248       scores[2] = selector.bestScore(2);
249       scores[3] = selector.bestScore(3);
250       gum::learning::GraphChange change4(gum::learning::GraphChangeType::ARC_DELETION, 3, 1);
251       graph.eraseArc(gum::Arc(change4.node1(), change4.node2()));
252       selector.applyChange(change4);
253 
254       _compute_scores_(score, graph, all_scores, best_nodes, best_node);
255 
256       for (const auto node: graph) {
257         const double sc = selector.bestScore(node);
258         TS_ASSERT_EQUALS(sc, all_scores[node][best_nodes[node]])
259         if ((node != 1)) {
260           TS_ASSERT_EQUALS(selector.bestScore(node), scores[node])
261         } else {
262           TS_ASSERT_DIFFERS(selector.bestScore(node), scores[node])
263         }
264       }
265 
266 
267       const auto xnodes = selector.nodesSortedByBestScore();
268       TS_ASSERT_EQUALS(xnodes.size(), std::size_t(6))
269 
270       std::vector< std::pair< gum::NodeId, double > > sorted_nodes(6);
271       _order_nodes_(all_scores, best_nodes, sorted_nodes);
272       for (std::size_t i = 0; i < 6; ++i) {
273         TS_ASSERT_EQUALS(xnodes[i], sorted_nodes[i])
274       }
275     }
276   };
277 
278 
279 } /* namespace gum_tests */
280