1 
2 /**
3  *
4  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe
5  * GONZALES(_at_AMU) info_at_agrum_dot_org
6  *
7  *  This library is free software: you can redistribute it and/or modify
8  *  it under the terms of the GNU Lesser General Public License as published by
9  *  the Free Software Foundation, either version 3 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This library is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU Lesser General Public License for more details.
16  *
17  *  You should have received a copy of the GNU Lesser General Public License
18  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
19  *
20  */
21 
22 
23 /** @file
24  * @brief Implementation of gum::learning::ThreeOffTwo and MIIC
25  *
26  * @author Quentin FALCAND, Marvin LASSERRE and Pierre-Henri WUILLEMIN(_at_LIP6)
27  */
28 
#include <algorithm>
#include <memory>

#include <agrum/tools/core/math/math_utils.h>
#include <agrum/tools/core/hashTable.h>
#include <agrum/tools/core/heap.h>
#include <agrum/tools/core/timer.h>
#include <agrum/tools/graphs/mixedGraph.h>
#include <agrum/BN/learning/Miic.h>
#include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
#include <agrum/tools/stattests/correctedMutualInformation.h>
37 
38 
39 namespace gum {
40 
41   namespace learning {
42 
43     /// default constructor
Miic()44     Miic::Miic() : _maxLog_(100), _size_(0) { GUM_CONSTRUCTOR(Miic); }
45 
46     /// default constructor with maxLog
Miic(int maxLog)47     Miic::Miic(int maxLog) : _maxLog_(maxLog), _size_(0) { GUM_CONSTRUCTOR(Miic); }
48 
49     /// copy constructor
Miic(const Miic & from)50     Miic::Miic(const Miic& from) : ApproximationScheme(from), _size_(from._size_) {
51       GUM_CONS_CPY(Miic);
52     }
53 
54     /// move constructor
Miic(Miic && from)55     Miic::Miic(Miic&& from) : ApproximationScheme(std::move(from)), _size_(from._size_) {
56       GUM_CONS_MOV(Miic);
57     }
58 
59     /// destructor
~Miic()60     Miic::~Miic() { GUM_DESTRUCTOR(Miic); }
61 
62     /// copy operator
operator =(const Miic & from)63     Miic& Miic::operator=(const Miic& from) {
64       ApproximationScheme::operator=(from);
65       return *this;
66     }
67 
68     /// move operator
operator =(Miic && from)69     Miic& Miic::operator=(Miic&& from) {
70       ApproximationScheme::operator=(std::move(from));
71       return *this;
72     }
73 
74 
operator ()(const CondRanking & e1,const CondRanking & e2) const75     bool GreaterPairOn2nd::operator()(const CondRanking& e1, const CondRanking& e2) const {
76       return e1.second > e2.second;
77     }
78 
operator ()(const Ranking & e1,const Ranking & e2) const79     bool GreaterAbsPairOn2nd::operator()(const Ranking& e1, const Ranking& e2) const {
80       return std::abs(e1.second) > std::abs(e2.second);
81     }
82 
operator ()(const ProbabilisticRanking & e1,const ProbabilisticRanking & e2) const83     bool GreaterTupleOnLast::operator()(const ProbabilisticRanking& e1,
84                                         const ProbabilisticRanking& e2) const {
85       double p1xz = std::get< 2 >(e1);
86       double p1yz = std::get< 3 >(e1);
87       double p2xz = std::get< 2 >(e2);
88       double p2yz = std::get< 3 >(e2);
89       double I1   = std::get< 1 >(e1);
90       double I2   = std::get< 1 >(e2);
91       // First, we look at the sign of information.
92       // Then, the probability values
93       // and finally the abs value of information.
94       if ((I1 < 0 && I2 < 0) || (I1 >= 0 && I2 >= 0)) {
95         if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
96           return std::abs(I1) > std::abs(I2);
97         } else {
98           return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
99         }
100       } else {
101         return I1 < I2;
102       }
103     }
104 
105     /// learns the structure of a MixedGraph
learnMixedStructure(CorrectedMutualInformation<> & mutualInformation,MixedGraph graph)106     MixedGraph Miic::learnMixedStructure(CorrectedMutualInformation<>& mutualInformation,
107                                          MixedGraph                    graph) {
108       timer_.reset();
109       current_step_ = 0;
110 
111       // clear the vector of latent arcs to be sure
112       _latentCouples_.clear();
113 
114       /// the heap of ranks, with the score, and the NodeIds of x, y and z.
115       Heap< CondRanking, GreaterPairOn2nd > rank;
116 
117       /// the variables separation sets
118       HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > sep_set;
119 
120       initiation_(mutualInformation, graph, sep_set, rank);
121 
122       iteration_(mutualInformation, graph, sep_set, rank);
123 
124       if (_useMiic_) {
125         orientationMiic_(mutualInformation, graph, sep_set);
126       } else {
127         orientation3off2_(mutualInformation, graph, sep_set);
128       }
129 
130       return graph;
131     }
132 
133     /*
134      * PHASE 1 : INITIATION
135      *
136      * We go over all edges and test if the variables are independent. If they
137      * are,
138      * the edge is deleted. If not, the best contributor is found.
139      */
initiation_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet,Heap<CondRanking,GreaterPairOn2nd> & rank)140     void Miic::initiation_(CorrectedMutualInformation<>& mutualInformation,
141                            MixedGraph&                   graph,
142                            HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
143                            Heap< CondRanking, GreaterPairOn2nd >&                           rank) {
144       NodeId  x, y;
145       EdgeSet edges      = graph.edges();
146       Size    steps_init = edges.size();
147 
148       for (const Edge& edge: edges) {
149         x          = edge.first();
150         y          = edge.second();
151         double Ixy = mutualInformation.score(x, y);
152 
153         if (Ixy <= 0) {   //< K
154           graph.eraseEdge(edge);
155           sepSet.insert(std::make_pair(x, y), _emptySet_);
156         } else {
157           findBestContributor_(x, y, _emptySet_, graph, mutualInformation, rank);
158         }
159 
160         ++current_step_;
161         if (onProgress.hasListener()) {
162           GUM_EMIT3(onProgress, (current_step_ * 33) / steps_init, 0., timer_.step());
163         }
164       }
165     }
166 
167     /*
168      * PHASE 2 : ITERATION
169      *
170      * As long as we find important nodes for edges, we go over them to see if
171      * we can assess the independence of the variables.
172      */
iteration_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet,Heap<CondRanking,GreaterPairOn2nd> & rank)173     void Miic::iteration_(CorrectedMutualInformation<>& mutualInformation,
174                           MixedGraph&                   graph,
175                           HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
176                           Heap< CondRanking, GreaterPairOn2nd >&                           rank) {
177       // if no triples to further examine pass
178       CondRanking best;
179 
180       Size steps_init = current_step_;
181       Size steps_iter = rank.size();
182 
183       try {
184         while (rank.top().second > 0.5) {
185           best = rank.pop();
186 
187           const NodeId          x  = std::get< 0 >(*(best.first));
188           const NodeId          y  = std::get< 1 >(*(best.first));
189           const NodeId          z  = std::get< 2 >(*(best.first));
190           std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
191 
192           ui.push_back(z);
193           const double i_xy_ui = mutualInformation.score(x, y, ui);
194           if (i_xy_ui < 0) {
195             graph.eraseEdge(Edge(x, y));
196             sepSet.insert(std::make_pair(x, y), std::move(ui));
197           } else {
198             findBestContributor_(x, y, ui, graph, mutualInformation, rank);
199           }
200 
201           delete best.first;
202 
203           ++current_step_;
204           if (onProgress.hasListener()) {
205             GUM_EMIT3(onProgress,
206                       (current_step_ * 66) / (steps_init + steps_iter),
207                       0.,
208                       timer_.step());
209           }
210         }
211       } catch (...) {}   // here, rank is empty
212       current_step_ = steps_init + steps_iter;
213       if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 66, 0., timer_.step()); }
214       current_step_ = steps_init + steps_iter;
215     }
216 
217     /*
218      * PHASE 3 : ORIENTATION
219      *
220      * Try to assess v-structures and propagate them.
221      */
orientation3off2_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)222     void Miic::orientation3off2_(
223        CorrectedMutualInformation<>&                                          mutualInformation,
224        MixedGraph&                                                            graph,
225        const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
226       std::vector< Ranking > triples      = unshieldedTriples_(graph, mutualInformation, sepSet);
227       Size                   steps_orient = triples.size();
228       Size                   past_steps   = current_step_;
229 
230       // marks always correspond to the head of the arc/edge. - is for a forbidden
231       // arc, > for a mandatory arc
232       // we start by adding the mandatory arcs
233       for (auto iter = _initialMarks_.begin(); iter != _initialMarks_.end(); ++iter) {
234         if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
235           graph.eraseEdge(Edge(iter.key().first, iter.key().second));
236           graph.addArc(iter.key().first, iter.key().second);
237         }
238       }
239 
240       NodeId i = 0;
241       // list of elements that we shouldn't read again, ie elements that are
242       // eligible to
243       // rule 0 after the first time they are tested, and elements on which rule 1
244       // has been applied
245       while (i < triples.size()) {
246         // if i not in do_not_reread
247         Ranking triple = triples[i];
248         NodeId  x, y, z;
249         x = std::get< 0 >(*triple.first);
250         y = std::get< 1 >(*triple.first);
251         z = std::get< 2 >(*triple.first);
252 
253         std::vector< NodeId >       ui;
254         std::pair< NodeId, NodeId > key     = {x, y};
255         std::pair< NodeId, NodeId > rev_key = {y, x};
256         if (sepSet.exists(key)) {
257           ui = sepSet[key];
258         } else if (sepSet.exists(rev_key)) {
259           ui = sepSet[rev_key];
260         }
261         double Ixyz_ui = triple.second;
262         bool   reset{false};
263         // try Rule 0
264         if (Ixyz_ui < 0) {
265           // if ( z not in Sep[x,y])
266           if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
267             if (!graph.existsArc(x, z) && !graph.existsArc(z, x)) {
268               // when we try to add an arc to the graph, we always verify if
269               // we are allowed to do so, ie it is not a forbidden arc an it
270               // does not create a cycle
271               if (!_existsDirectedPath_(graph, z, x) && !isForbidenArc_(x, z)) {
272                 reset = true;
273                 graph.eraseEdge(Edge(x, z));
274                 graph.addArc(x, z);
275               } else if (_existsDirectedPath_(graph, z, x) && !isForbidenArc_(z, x)) {
276                 reset = true;
277                 graph.eraseEdge(Edge(x, z));
278                 // if we find a cycle, we force the competing edge
279                 graph.addArc(z, x);
280                 if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(z, x))
281                     == _latentCouples_.end()) {
282                   _latentCouples_.emplace_back(z, x);
283                 }
284               }
285             } else if (!graph.existsArc(y, z) && !graph.existsArc(z, y)) {
286               if (!_existsDirectedPath_(graph, z, y) && !isForbidenArc_(x, z)) {
287                 reset = true;
288                 graph.eraseEdge(Edge(y, z));
289                 graph.addArc(y, z);
290               } else if (_existsDirectedPath_(graph, z, y) && !isForbidenArc_(z, y)) {
291                 reset = true;
292                 graph.eraseEdge(Edge(y, z));
293                 // if we find a cycle, we force the competing edge
294                 graph.addArc(z, y);
295                 if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(z, y))
296                     == _latentCouples_.end()) {
297                   _latentCouples_.emplace_back(z, y);
298                 }
299               }
300             } else {
301               // checking if the anti-directed arc already exists, to register a
302               // latent variable
303               if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
304                 _latentCouples_.emplace_back(z, x);
305               }
306               if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
307                 _latentCouples_.emplace_back(z, y);
308               }
309             }
310           }
311         } else {   // try Rule 1
312           if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
313             if (!_existsDirectedPath_(graph, y, z) && !isForbidenArc_(z, y)) {
314               reset = true;
315               graph.eraseEdge(Edge(z, y));
316               graph.addArc(z, y);
317             } else if (_existsDirectedPath_(graph, y, z) && !isForbidenArc_(y, z)) {
318               reset = true;
319               graph.eraseEdge(Edge(z, y));
320               // if we find a cycle, we force the competing edge
321               graph.addArc(y, z);
322               if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(y, z))
323                   == _latentCouples_.end()) {
324                 _latentCouples_.emplace_back(y, z);
325               }
326             }
327           }
328           if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
329             if (!_existsDirectedPath_(graph, x, z) && !isForbidenArc_(z, x)) {
330               reset = true;
331               graph.eraseEdge(Edge(z, x));
332               graph.addArc(z, x);
333             } else if (_existsDirectedPath_(graph, x, z) && !isForbidenArc_(x, z)) {
334               reset = true;
335               graph.eraseEdge(Edge(z, x));
336               // if we find a cycle, we force the competing edge
337               graph.addArc(x, z);
338               if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(x, z))
339                   == _latentCouples_.end()) {
340                 _latentCouples_.emplace_back(x, z);
341               }
342             }
343           }
344         }   // if rule 0 or rule 1
345 
346         // if what we want to add already exists : pass to the next triplet
347         if (reset) {
348           i = 0;
349         } else {
350           ++i;
351         }
352         if (onProgress.hasListener()) {
353           GUM_EMIT3(onProgress,
354                     ((current_step_ + i) * 100) / (past_steps + steps_orient),
355                     0.,
356                     timer_.step());
357         }
358       }   // while
359 
360       // erasing the the double headed arcs
361       for (const Arc& arc: _latentCouples_) {
362         graph.eraseArc(Arc(arc.head(), arc.tail()));
363       }
364     }
365 
366     /// variant trying to propagate both orientations in a bidirected arc
orientationLatents_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)367     void Miic::orientationLatents_(
368        CorrectedMutualInformation<>&                                          mutualInformation,
369        MixedGraph&                                                            graph,
370        const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
371       std::vector< Ranking > triples      = unshieldedTriples_(graph, mutualInformation, sepSet);
372       Size                   steps_orient = triples.size();
373       Size                   past_steps   = current_step_;
374 
375       NodeId i = 0;
376       // list of elements that we shouldnt read again, ie elements that are
377       // eligible to
378       // rule 0 after the first time they are tested, and elements on which rule 1
379       // has been applied
380       while (i < triples.size()) {
381         // if i not in do_not_reread
382         Ranking triple = triples[i];
383         NodeId  x, y, z;
384         x = std::get< 0 >(*triple.first);
385         y = std::get< 1 >(*triple.first);
386         z = std::get< 2 >(*triple.first);
387 
388         std::vector< NodeId >       ui;
389         std::pair< NodeId, NodeId > key     = {x, y};
390         std::pair< NodeId, NodeId > rev_key = {y, x};
391         if (sepSet.exists(key)) {
392           ui = sepSet[key];
393         } else if (sepSet.exists(rev_key)) {
394           ui = sepSet[rev_key];
395         }
396         double Ixyz_ui = triple.second;
397         // try Rule 0
398         if (Ixyz_ui < 0) {
399           // if ( z not in Sep[x,y])
400           if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
401             // if what we want to add already exists : pass
402             if ((graph.existsArc(x, z) || graph.existsArc(z, x))
403                 && (graph.existsArc(y, z) || graph.existsArc(z, y))) {
404               ++i;
405             } else {
406               i = 0;
407               graph.eraseEdge(Edge(x, z));
408               graph.eraseEdge(Edge(y, z));
409               // checking for cycles
410               if (graph.existsArc(z, x)) {
411                 graph.eraseArc(Arc(z, x));
412                 try {
413                   std::vector< NodeId > path = graph.directedPath(z, x);
414                   // if we find a cycle, we force the competing edge
415                   _latentCouples_.emplace_back(z, x);
416                 } catch (gum::NotFound) { graph.addArc(x, z); }
417                 graph.addArc(z, x);
418               } else {
419                 try {
420                   std::vector< NodeId > path = graph.directedPath(z, x);
421                   // if we find a cycle, we force the competing edge
422                   graph.addArc(z, x);
423                   _latentCouples_.emplace_back(z, x);
424                 } catch (gum::NotFound) { graph.addArc(x, z); }
425               }
426               if (graph.existsArc(z, y)) {
427                 graph.eraseArc(Arc(z, y));
428                 try {
429                   std::vector< NodeId > path = graph.directedPath(z, y);
430                   // if we find a cycle, we force the competing edge
431                   _latentCouples_.emplace_back(z, y);
432                 } catch (gum::NotFound) { graph.addArc(y, z); }
433                 graph.addArc(z, y);
434               } else {
435                 try {
436                   std::vector< NodeId > path = graph.directedPath(z, y);
437                   // if we find a cycle, we force the competing edge
438                   graph.addArc(z, y);
439                   _latentCouples_.emplace_back(z, y);
440 
441                 } catch (gum::NotFound) { graph.addArc(y, z); }
442               }
443               if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
444                 _latentCouples_.emplace_back(z, x);
445               }
446               if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
447                 _latentCouples_.emplace_back(z, y);
448               }
449             }
450           } else {
451             ++i;
452           }
453         } else {   // try Rule 1
454           bool reset{false};
455           if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
456             reset = true;
457             graph.eraseEdge(Edge(z, y));
458             try {
459               std::vector< NodeId > path = graph.directedPath(y, z);
460               // if we find a cycle, we force the competing edge
461               graph.addArc(y, z);
462               _latentCouples_.emplace_back(y, z);
463             } catch (gum::NotFound) { graph.addArc(z, y); }
464           }
465           if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
466             reset = true;
467             graph.eraseEdge(Edge(z, x));
468             try {
469               std::vector< NodeId > path = graph.directedPath(x, z);
470               // if we find a cycle, we force the competing edge
471               graph.addArc(x, z);
472               _latentCouples_.emplace_back(x, z);
473             } catch (gum::NotFound) { graph.addArc(z, x); }
474           }
475 
476           if (reset) {
477             i = 0;
478           } else {
479             ++i;
480           }
481         }   // if rule 0 or rule 1
482         if (onProgress.hasListener()) {
483           GUM_EMIT3(onProgress,
484                     ((current_step_ + i) * 100) / (past_steps + steps_orient),
485                     0.,
486                     timer_.step());
487         }
488       }   // while
489 
490       // erasing the the double headed arcs
491       for (const Arc& arc: _latentCouples_) {
492         graph.eraseArc(Arc(arc.head(), arc.tail()));
493       }
494     }
495 
496     /// varient using the orientation protocol of MIIC
orientationMiic_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)497     void Miic::orientationMiic_(
498        CorrectedMutualInformation<>&                                          mutualInformation,
499        MixedGraph&                                                            graph,
500        const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
501       // structure to store the orientations marks -, o, or >,
502       // Considers the head of the arc/edge first node -* second node
503       HashTable< std::pair< NodeId, NodeId >, char > marks = _initialMarks_;
504 
505       // marks always correspond to the head of the arc/edge. - is for a forbidden
506       // arc, > for a mandatory arc
507       // we start by adding the mandatory arcs
508       for (auto iter = marks.begin(); iter != marks.end(); ++iter) {
509         if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
510           graph.eraseEdge(Edge(iter.key().first, iter.key().second));
511           graph.addArc(iter.key().first, iter.key().second);
512         }
513       }
514 
515       std::vector< ProbabilisticRanking > proba_triples
516          = unshieldedTriplesMiic_(graph, mutualInformation, sepSet, marks);
517 
518       const Size steps_orient = proba_triples.size();
519       Size       past_steps   = current_step_;
520 
521       ProbabilisticRanking best;
522       if (steps_orient > 0) { best = proba_triples[0]; }
523 
524       while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
525         const NodeId x = std::get< 0 >(*std::get< 0 >(best));
526         const NodeId y = std::get< 1 >(*std::get< 0 >(best));
527         const NodeId z = std::get< 2 >(*std::get< 0 >(best));
528 
529         const double i3 = std::get< 1 >(best);
530 
531         const double p1 = std::get< 2 >(best);
532         const double p2 = std::get< 3 >(best);
533         if (i3 <= 0) {
534           _orientingVstructureMiic_(graph, marks, x, y, z, p1, p2);
535         } else {
536           _propagatingOrientationMiic_(graph, marks, x, y, z, p1, p2);
537         }
538 
539         delete std::get< 0 >(best);
540         proba_triples.erase(proba_triples.begin());
541         // actualisation of the list of triples
542         proba_triples = updateProbaTriples_(graph, proba_triples);
543 
544         if (!proba_triples.empty()) best = proba_triples[0];
545 
546         ++current_step_;
547         if (onProgress.hasListener()) {
548           GUM_EMIT3(onProgress,
549                     (current_step_ * 100) / (steps_orient + past_steps),
550                     0.,
551                     timer_.step());
552         }
553       }   // while
554 
555       // erasing the double headed arcs
556       for (auto iter = _latentCouples_.rbegin(); iter != _latentCouples_.rend(); ++iter) {
557         graph.eraseArc(Arc(iter->head(), iter->tail()));
558         if (_existsDirectedPath_(graph, iter->head(), iter->tail())) {
559           // if we find a cycle, we force the competing edge
560           graph.addArc(iter->head(), iter->tail());
561           graph.eraseArc(Arc(iter->tail(), iter->head()));
562           *iter = Arc(iter->head(), iter->tail());
563         }
564       }
565 
566       if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 100, 0., timer_.step()); }
567     }
568 
569     /// finds the best contributor node for a pair given a conditioning set
findBestContributor_(NodeId x,NodeId y,const std::vector<NodeId> & ui,const MixedGraph & graph,CorrectedMutualInformation<> & mutualInformation,Heap<CondRanking,GreaterPairOn2nd> & rank)570     void Miic::findBestContributor_(NodeId                                 x,
571                                     NodeId                                 y,
572                                     const std::vector< NodeId >&           ui,
573                                     const MixedGraph&                      graph,
574                                     CorrectedMutualInformation<>&          mutualInformation,
575                                     Heap< CondRanking, GreaterPairOn2nd >& rank) {
576       double maxP = -1.0;
577       NodeId maxZ = 0;
578 
579       // compute N
580       // __N = I.N();
581       const double Ixy_ui = mutualInformation.score(x, y, ui);
582 
583       for (const NodeId z: graph) {
584         // if z!=x and z!=y and z not in ui
585         if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
586           double Pnv;
587           double Pb;
588 
589           // Computing Pnv
590           const double Ixyz_ui    = mutualInformation.score(x, y, z, ui);
591           double       calc_expo1 = -Ixyz_ui * M_LN2;
592           // if exponential are too high or to low, crop them at _maxLog_
593           if (calc_expo1 > _maxLog_) {
594             Pnv = 0.0;
595           } else if (calc_expo1 < -_maxLog_) {
596             Pnv = 1.0;
597           } else {
598             Pnv = 1 / (1 + std::exp(calc_expo1));
599           }
600 
601           // Computing Pb
602           const double Ixz_ui = mutualInformation.score(x, z, ui);
603           const double Iyz_ui = mutualInformation.score(y, z, ui);
604 
605           calc_expo1        = -(Ixz_ui - Ixy_ui) * M_LN2;
606           double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
607 
608           // if exponential are too high or to low, crop them at  _maxLog_
609           if (calc_expo1 > _maxLog_ || calc_expo2 > _maxLog_) {
610             Pb = 0.0;
611           } else if (calc_expo1 < -_maxLog_ && calc_expo2 < -_maxLog_) {
612             Pb = 1.0;
613           } else {
614             double expo1, expo2;
615             if (calc_expo1 < -_maxLog_) {
616               expo1 = 0.0;
617             } else {
618               expo1 = std::exp(calc_expo1);
619             }
620             if (calc_expo2 < -_maxLog_) {
621               expo2 = 0.0;
622             } else {
623               expo2 = std::exp(calc_expo2);
624             }
625             Pb = 1 / (1 + expo1 + expo2);
626           }
627 
628           // Getting max(min(Pnv, pb))
629           const double min_pnv_pb = std::min(Pnv, Pb);
630           if (min_pnv_pb > maxP) {
631             maxP = min_pnv_pb;
632             maxZ = z;
633           }
634         }   // if z not in (x, y)
635       }     // for z in graph.nodes
636       // storing best z in rank_
637       CondRanking final;
638       auto        tup = new CondThreePoints{x, y, maxZ, ui};
639       final.first     = tup;
640       final.second    = maxP;
641       rank.insert(final);
642     }
643 
644     /// gets the list of unshielded triples in the graph in decreasing value of
645     ///|I'(x, y, z|{ui})|
unshieldedTriples_(const MixedGraph & graph,CorrectedMutualInformation<> & mutualInformation,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)646     std::vector< Ranking > Miic::unshieldedTriples_(
647        const MixedGraph&                                                      graph,
648        CorrectedMutualInformation<>&                                          mutualInformation,
649        const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
650       std::vector< Ranking > triples;
651       for (NodeId z: graph) {
652         for (NodeId x: graph.neighbours(z)) {
653           for (NodeId y: graph.neighbours(z)) {
654             if (y < x && !graph.existsEdge(x, y)) {
655               std::vector< NodeId >       ui;
656               std::pair< NodeId, NodeId > key     = {x, y};
657               std::pair< NodeId, NodeId > rev_key = {y, x};
658               if (sepSet.exists(key)) {
659                 ui = sepSet[key];
660               } else if (sepSet.exists(rev_key)) {
661                 ui = sepSet[rev_key];
662               }
663               // remove z from ui if it's present
664               const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
665               if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
666 
667               double  Ixyz_ui = mutualInformation.score(x, y, z, ui);
668               Ranking triple;
669               auto    tup   = new ThreePoints{x, y, z};
670               triple.first  = tup;
671               triple.second = Ixyz_ui;
672               triples.push_back(triple);
673             }
674           }
675         }
676       }
677       std::sort(triples.begin(), triples.end(), GreaterAbsPairOn2nd());
678       return triples;
679     }
680 
681     /// gets the list of unshielded triples in the graph in decreasing value of
682     ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
unshieldedTriplesMiic_(const MixedGraph & graph,CorrectedMutualInformation<> & mutualInformation,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet,HashTable<std::pair<NodeId,NodeId>,char> & marks)683     std::vector< ProbabilisticRanking > Miic::unshieldedTriplesMiic_(
684        const MixedGraph&                                                      graph,
685        CorrectedMutualInformation<>&                                          mutualInformation,
686        const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
687        HashTable< std::pair< NodeId, NodeId >, char >&                        marks) {
688       std::vector< ProbabilisticRanking > triples;
689       for (NodeId z: graph) {
690         for (NodeId x: graph.neighbours(z)) {
691           for (NodeId y: graph.neighbours(z)) {
692             if (y < x && !graph.existsEdge(x, y)) {
693               std::vector< NodeId >       ui;
694               std::pair< NodeId, NodeId > key     = {x, y};
695               std::pair< NodeId, NodeId > rev_key = {y, x};
696               if (sepSet.exists(key)) {
697                 ui = sepSet[key];
698               } else if (sepSet.exists(rev_key)) {
699                 ui = sepSet[rev_key];
700               }
701               // remove z from ui if it's present
702               const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
703               if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
704 
705               const double         Ixyz_ui = mutualInformation.score(x, y, z, ui);
706               auto                 tup     = new ThreePoints{x, y, z};
707               ProbabilisticRanking triple{tup, Ixyz_ui, 0.5, 0.5};
708               triples.push_back(triple);
709               if (!marks.exists({x, z})) { marks.insert({x, z}, 'o'); }
710               if (!marks.exists({z, x})) { marks.insert({z, x}, 'o'); }
711               if (!marks.exists({y, z})) { marks.insert({y, z}, 'o'); }
712               if (!marks.exists({z, y})) { marks.insert({z, y}, 'o'); }
713             }
714           }
715         }
716       }
717       triples = updateProbaTriples_(graph, triples);
718       std::sort(triples.begin(), triples.end(), GreaterTupleOnLast());
719       return triples;
720     }
721 
722     /// Gets the orientation probabilities like MIIC for the orientation phase
723     std::vector< ProbabilisticRanking >
updateProbaTriples_(const MixedGraph & graph,std::vector<ProbabilisticRanking> probaTriples)724        Miic::updateProbaTriples_(const MixedGraph&                   graph,
725                                  std::vector< ProbabilisticRanking > probaTriples) {
726       for (auto& triple: probaTriples) {
727         NodeId x, y, z;
728         x                 = std::get< 0 >(*std::get< 0 >(triple));
729         y                 = std::get< 1 >(*std::get< 0 >(triple));
730         z                 = std::get< 2 >(*std::get< 0 >(triple));
731         const double Ixyz = std::get< 1 >(triple);
732         double       Pxz  = std::get< 2 >(triple);
733         double       Pyz  = std::get< 3 >(triple);
734 
735         if (Ixyz <= 0) {
736           const double expo = std::exp(Ixyz);
737           const double P0   = (1 + expo) / (1 + 3 * expo);
738           // distinguish between the initialization and the update process
739           if (Pxz == Pyz && Pyz == 0.5) {
740             std::get< 2 >(triple) = P0;
741             std::get< 3 >(triple) = P0;
742           } else {
743             if (graph.existsArc(x, z) && Pxz >= P0) {
744               std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
745             } else if (graph.existsArc(y, z) && Pyz >= P0) {
746               std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
747             }
748           }
749         } else {
750           const double expo = std::exp(-Ixyz);
751           if (graph.existsArc(x, z) && Pxz >= 0.5) {
752             std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
753           } else if (graph.existsArc(y, z) && Pyz >= 0.5) {
754             std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
755           }
756         }
757       }
758       std::sort(probaTriples.begin(), probaTriples.end(), GreaterTupleOnLast());
759       return probaTriples;
760     }
761 
762     /// learns the structure of an Bayesian network, ie a DAG, from an Essential
763     /// graph.
learnStructure(CorrectedMutualInformation<> & I,MixedGraph initialGraph)764     DAG Miic::learnStructure(CorrectedMutualInformation<>& I, MixedGraph initialGraph) {
765       MixedGraph essentialGraph = learnMixedStructure(I, initialGraph);
766       // orientate remaining edges
767 
768       const Sequence< NodeId > order = essentialGraph.topologicalOrder();
769 
770       // first, forbidden arcs force arc in the other direction
771       for (NodeId x: order) {
772         const auto nei_x = essentialGraph.neighbours(x);
773         for (NodeId y: nei_x)
774           if (isForbidenArc_(x, y)) {
775             essentialGraph.eraseEdge(Edge(x, y));
776             if (isForbidenArc_(y, x)) {
777               GUM_TRACE("Neither arc allowed for edge (" << x << "," << y << ")")
778             } else {
779               GUM_TRACE("Forced orientation : " << y << "->" << x)
780               essentialGraph.addArc(y, x);
781             }
782           } else if (isForbidenArc_(y, x)) {
783             essentialGraph.eraseEdge(Edge(x, y));
784             GUM_TRACE("Forced orientation : " << x << "->" << y)
785             essentialGraph.addArc(x, y);
786           }
787       }
788       GUM_TRACE(essentialGraph.toDot());
789 
790       // first, propagate existing orientations
791       bool newOrientation = true;
792       while (newOrientation) {
793         newOrientation = false;
794         for (NodeId x: order) {
795           if (!essentialGraph.parents(x).empty()) {
796             newOrientation |= propagatesRemainingOrientableEdges_(essentialGraph, x);
797           }
798         }
799       }
800       GUM_TRACE(essentialGraph.toDot());
801       propagatesOrientationInChainOfRemainingEdges_(essentialGraph);
802       GUM_TRACE(essentialGraph.toDot());
803 
804       // then decide the orientation for double arcs
805       for (NodeId x: order)
806         for (NodeId y: essentialGraph.parents(x))
807           if (essentialGraph.parents(y).contains(x)) {
808             GUM_TRACE(" + Resolving double arcs (poorly)")
809             essentialGraph.eraseArc(Arc(y, x));
810           }
811 
812       DAG dag;
813       for (auto node: essentialGraph) {
814         dag.addNodeWithId(node);
815       }
816       for (const Arc& arc: essentialGraph.arcs()) {
817         dag.addArc(arc.tail(), arc.head());
818       }
819 
820       return dag;
821     }
822 
isOrientable_(const MixedGraph & graph,NodeId xi,NodeId xj) const823     bool Miic::isOrientable_(const MixedGraph& graph, NodeId xi, NodeId xj) const {
824       // no cycle
825       if (_existsDirectedPath_(graph, xj, xi)) {
826         GUM_TRACE("cycle(" << xi << "-" << xj << ")")
827         return false;
828       }
829 
830       // R1
831       if (!(graph.parents(xi) - graph.adjacents(xj)).empty()) {
832         GUM_TRACE("R1(" << xi << "-" << xj << ")")
833         return true;
834       }
835 
836       // R2
837       if (_existsDirectedPath_(graph, xi, xj)) {
838         GUM_TRACE("R2(" << xi << "-" << xj << ")")
839         return true;
840       }
841 
842       // R3
843       int nbr = 0;
844       for (const auto p: graph.parents(xj)) {
845         if (!graph.mixedOrientedPath(xi, p).empty()) {
846           nbr += 1;
847           if (nbr == 2) {
848             GUM_TRACE("R3(" << xi << "-" << xj << ")")
849             return true;
850           }
851         }
852       }
853       return false;
854     }
855 
propagatesOrientationInChainOfRemainingEdges_(MixedGraph & essentialGraph)856     void Miic::propagatesOrientationInChainOfRemainingEdges_(MixedGraph& essentialGraph) {
857       // then decide the orientation for remaining edges
858       while (!essentialGraph.edges().empty()) {
859         const auto& edge               = *(essentialGraph.edges().begin());
860         NodeId      root               = edge.first();
861         Size        size_children_root = essentialGraph.children(root).size();
862         NodeSet     visited;
863         NodeSet     stack{root};
864         // check the best root for the set of neighbours
865         while (!stack.empty()) {
866           NodeId next = *(stack.begin());
867           stack.erase(next);
868           if (visited.contains(next)) continue;
869           if (essentialGraph.children(next).size() > size_children_root) {
870             size_children_root = essentialGraph.children(next).size();
871             root               = next;
872           }
873           for (const auto n: essentialGraph.neighbours(next))
874             if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
875           visited.insert(next);
876         }
877         // orientation now
878         visited.clear();
879         stack.clear();
880         stack.insert(root);
881         while (!stack.empty()) {
882           NodeId next = *(stack.begin());
883           stack.erase(next);
884           if (visited.contains(next)) continue;
885           const auto nei = essentialGraph.neighbours(next);
886           for (const auto n: nei) {
887             if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
888             GUM_TRACE(" + amap reasonably orientation for " << n << "->" << next);
889             essentialGraph.eraseEdge(Edge(n, next));
890             essentialGraph.addArc(n, next);
891           }
892           visited.insert(next);
893         }
894       }
895     }
896 
897     /// Propagates the orientation from a node to its neighbours
propagatesRemainingOrientableEdges_(MixedGraph & graph,NodeId xj)898     bool Miic::propagatesRemainingOrientableEdges_(MixedGraph& graph, NodeId xj) {
899       bool       res        = false;
900       const auto neighbours = graph.neighbours(xj);
901       for (auto& xi: neighbours) {
902         bool i_j = isOrientable_(graph, xi, xj);
903         bool j_i = isOrientable_(graph, xj, xi);
904         if (i_j || j_i) {
905           GUM_TRACE(" + Removing edge (" << xi << "," << xj << ")")
906           graph.eraseEdge(Edge(xi, xj));
907           res = true;
908         }
909         if (i_j) {
910           GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
911           graph.addArc(xi, xj);
912           propagatesRemainingOrientableEdges_(graph, xj);
913         }
914         if (j_i) {
915           GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
916           graph.addArc(xj, xi);
917           propagatesRemainingOrientableEdges_(graph, xi);
918         }
919         if (i_j && j_i) {
920           GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
921           _latentCouples_.emplace_back(xi, xj);
922         }
923       }
924 
925       return res;
926     }
927 
928     /// get the list of arcs hiding latent variables
latentVariables() const929     const std::vector< Arc > Miic::latentVariables() const { return _latentCouples_; }
930 
    /// learns the structure and the parameters of a BN
    ///
    /// Learns a DAG with learnStructure(selector, initial_dag), then passes it
    /// together with the parameter estimator to DAG2BNLearner<>::createBN to build
    /// the BayesNet< GUM_SCALAR >.
    /// NOTE(review): the only learnStructure overload visible in this file takes
    /// (CorrectedMutualInformation<>&, MixedGraph). This template can only be
    /// instantiated if GRAPH_CHANGES_SELECTOR / DAG convert to those types, or if
    /// another overload exists elsewhere -- confirm before using it.
    template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR >
    BayesNet< GUM_SCALAR > Miic::learnBN(GRAPH_CHANGES_SELECTOR& selector,
                                         PARAM_ESTIMATOR&        estimator,
                                         DAG                     initial_dag) {
      return DAG2BNLearner<>::createBN< GUM_SCALAR >(estimator,
                                                     learnStructure(selector, initial_dag));
    }
939 
setMiicBehaviour()940     void Miic::setMiicBehaviour() { this->_useMiic_ = true; }
941 
set3of2Behaviour()942     void Miic::set3of2Behaviour() { this->_useMiic_ = false; }
943 
addConstraints(HashTable<std::pair<NodeId,NodeId>,char> constraints)944     void Miic::addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints) {
945       this->_initialMarks_ = constraints;
946     }
947 
_existsNonTrivialDirectedPath_(const MixedGraph & graph,const NodeId n1,const NodeId n2)948     bool Miic::_existsNonTrivialDirectedPath_(const MixedGraph& graph,
949                                               const NodeId      n1,
950                                               const NodeId      n2) {
951       for (const auto parent: graph.parents(n2)) {
952         if (graph.existsArc(parent,
953                             n2))   // if there is a double arc, pass
954           continue;
955         if (parent == n1)   // trivial directed path => not recognized
956           continue;
957         if (_existsDirectedPath_(graph, n1, parent)) return true;
958       }
959       return false;
960     }
961 
_existsDirectedPath_(const MixedGraph & graph,const NodeId n1,const NodeId n2)962     bool Miic::_existsDirectedPath_(const MixedGraph& graph, const NodeId n1, const NodeId n2) {
963       // not recursive version => use a FIFO for simulating the recursion
964       List< NodeId > nodeFIFO;
965       // mark[node] = successor if visited, else mark[node] does not exist
966       Set< NodeId > mark;
967 
968       mark.insert(n2);
969       nodeFIFO.pushBack(n2);
970 
971       NodeId current;
972 
973       while (!nodeFIFO.empty()) {
974         current = nodeFIFO.front();
975         nodeFIFO.popFront();
976 
977         // check the parents
978         for (const auto new_one: graph.parents(current)) {
979           if (graph.existsArc(current,
980                               new_one))   // if there is a double arc, pass
981             continue;
982 
983           if (new_one == n1) { return true; }
984 
985           if (mark.exists(new_one))   // if this node is already marked, do not
986             continue;                 // check it again
987 
988           mark.insert(new_one);
989           nodeFIFO.pushBack(new_one);
990         }
991       }
992 
993       return false;
994     }
995 
    /// Orients the unshielded triple x - z - y towards z (v-structure x -> z <- y)
    /// according to the current endpoint marks.
    ///
    /// Handled cases (any other combination of marks leaves everything untouched):
    ///  1. marks (x,z) and (y,z) are both 'o' (x - z - y): x - z, then y - z,
    ///     are processed;
    ///  2. mark (x,z) is '>' and (y,z) is 'o' (x -> z - y): y - z is processed;
    ///  3. mark (y,z) is '>' and (x,z) is 'o' (x - z <- y): x - z is processed.
    /// Processing an edge a - z (a being x or y):
    ///  - if no non-trivial directed path z ~> a exists, the edge becomes the arc
    ///    a -> z and mark (a,z) is set to '>'. If the reverse arc z -> a also
    ///    exists, (z, a) is appended to _latentCouples_ unless already recorded.
    ///    The arc probability (p1 for x, p2 for y) is stored in _arcProbas_
    ///    unless one is already present;
    ///  - otherwise the edge is erased and, when no non-trivial directed path
    ///    a ~> z exists, replaced by the arc z -> a with mark (z,a) set to '>'.
    ///
    /// @param graph the mixed graph being oriented (modified in place)
    /// @param marks endpoint marks, read for (x,z) and (y,z) and updated here
    ///        (NOTE(review): these keys are assumed to exist -- unshieldedTriplesMiic_
    ///        initialises them to 'o'; confirm for other callers)
    /// @param x first extremity of the triple
    /// @param y second extremity of the triple
    /// @param z middle node of the triple
    /// @param p1 probability associated with the arc x -> z
    /// @param p2 probability associated with the arc y -> z
    void Miic::_orientingVstructureMiic_(MixedGraph&                                     graph,
                                         HashTable< std::pair< NodeId, NodeId >, char >& marks,
                                         NodeId                                          x,
                                         NodeId                                          y,
                                         NodeId                                          z,
                                         double                                          p1,
                                         double                                          p2) {
      // v-structure discovery
      if (marks[{x, z}] == 'o' && marks[{y, z}] == 'o') {   // If x-z-y
        // x side: x -> z unless z already reaches x
        if (!_existsNonTrivialDirectedPath_(graph, z, x)) {
          graph.eraseEdge(Edge(x, z));
          graph.addArc(x, z);
          GUM_TRACE("1.a Removing edge (" << x << "," << z << ")")
          GUM_TRACE("1.a Adding arc (" << x << "," << z << ")")
          marks[{x, z}] = '>';
          // z -> x also present: the couple is a double arc
          if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
            GUM_TRACE("Adding latent couple (" << z << "," << x << ")")
            _latentCouples_.emplace_back(z, x);
          }
          if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
        } else {
          // NOTE(review): the two "1.b" trace messages look swapped: "Adding arc"
          // is printed right after the edge removal, "Removing edge" after
          // addArc(z, x).
          graph.eraseEdge(Edge(x, z));
          GUM_TRACE("1.b Adding arc (" << x << "," << z << ")")
          if (!_existsNonTrivialDirectedPath_(graph, x, z)) {
            graph.addArc(z, x);
            GUM_TRACE("1.b Removing edge (" << x << "," << z << ")")
            marks[{z, x}] = '>';
          }
        }

        // y side: y -> z unless z already reaches y
        if (!_existsNonTrivialDirectedPath_(graph, z, y)) {
          graph.eraseEdge(Edge(y, z));
          graph.addArc(y, z);
          GUM_TRACE("1.c Removing edge (" << y << "," << z << ")")
          GUM_TRACE("1.c Adding arc (" << y << "," << z << ")")
          marks[{y, z}] = '>';
          if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
            _latentCouples_.emplace_back(z, y);
          }
          if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
        } else {
          graph.eraseEdge(Edge(y, z));
          GUM_TRACE("1.d Removing edge (" << y << "," << z << ")")
          if (!_existsNonTrivialDirectedPath_(graph, y, z)) {
            graph.addArc(z, y);
            GUM_TRACE("1.d Adding arc (" << z << "," << y << ")")
            marks[{z, y}] = '>';
          }
        }
      } else if (marks[{x, z}] == '>' && marks[{y, z}] == 'o') {   // If x->z-y
        if (!_existsNonTrivialDirectedPath_(graph, z, y)) {
          graph.eraseEdge(Edge(y, z));
          graph.addArc(y, z);
          GUM_TRACE("2.a Removing edge (" << y << "," << z << ")")
          GUM_TRACE("2.a Adding arc (" << y << "," << z << ")")
          marks[{y, z}] = '>';
          if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
            _latentCouples_.emplace_back(z, y);
          }
          if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
        } else {
          graph.eraseEdge(Edge(y, z));
          GUM_TRACE("2.b Removing edge (" << y << "," << z << ")")
          if (!_existsNonTrivialDirectedPath_(graph, y, z)) {
            graph.addArc(z, y);
            // NOTE(review): this trace prints (y, z) but the arc added is z -> y
            GUM_TRACE("2.b Adding arc (" << y << "," << z << ")")
            marks[{z, y}] = '>';
          }
        }
      } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o') {   // If x-z<-y
        if (!_existsNonTrivialDirectedPath_(graph, z, x)) {
          graph.eraseEdge(Edge(x, z));
          graph.addArc(x, z);
          GUM_TRACE("3.a Removing edge (" << x << "," << z << ")")
          GUM_TRACE("3.a Adding arc (" << x << "," << z << ")")
          marks[{x, z}] = '>';
          if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
            _latentCouples_.emplace_back(z, x);
          }
          if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
        } else {
          graph.eraseEdge(Edge(x, z));
          GUM_TRACE("3.b Removing edge (" << x << "," << z << ")")
          if (!_existsNonTrivialDirectedPath_(graph, x, z)) {
            graph.addArc(z, x);
            // NOTE(review): this trace prints (x, z) but the arc added is z -> x
            GUM_TRACE("3.b Adding arc (" << x << "," << z << ")")
            marks[{z, x}] = '>';
          }
        }
      }
    }
1087 
1088 
    /// Propagates an orientation through the unshielded triple x - z - y.
    ///
    /// Case 4: mark (x,z) is '>', mark (y,z) is 'o' and mark (z,y) is not '-'.
    ///   The edge z - y is erased and replaced by the first applicable arc:
    ///     4.a z -> y if there is no directed path y ~> z and y has no parent;
    ///     4.b y -> z if there is no directed path z ~> y and z has no parent;
    ///     4.c z -> y if there is no directed path y ~> z;
    ///     4.d y -> z if there is no directed path z ~> y.
    ///   In 4.b / 4.d, (y, z) is also appended to _latentCouples_. If no rule
    ///   applies, the edge simply stays removed.
    /// Case 5: same as case 4 with the roles of x and y exchanged (rules 5.a-5.d).
    /// For the added arc tail -> head, mark (tail,head) is set to '>' and mark
    /// (head,tail) to '-'. Its probability (p2 on the y side, p1 on the x side)
    /// is stored in _arcProbas_ unless one is already present.
    ///
    /// @param graph the mixed graph being oriented (modified in place)
    /// @param marks endpoint marks, read and updated here (NOTE(review): the keys
    ///        used are assumed to exist; confirm for every caller)
    /// @param x first extremity of the triple
    /// @param y second extremity of the triple
    /// @param z middle node of the triple
    /// @param p1 probability associated with the arcs between x and z
    /// @param p2 probability associated with the arcs between y and z
    void Miic::_propagatingOrientationMiic_(MixedGraph&                                     graph,
                                            HashTable< std::pair< NodeId, NodeId >, char >& marks,
                                            NodeId                                          x,
                                            NodeId                                          y,
                                            NodeId                                          z,
                                            double                                          p1,
                                            double                                          p2) {
      // orientation propagation
      if (marks[{x, z}] == '>' && marks[{y, z}] == 'o' && marks[{z, y}] != '-') {
        graph.eraseEdge(Edge(z, y));
        // std::cout << "4. Removing edge (" << z << "," << y << ")" <<
        // std::endl;
        // rules tried in order; y and z without parents are preferred first
        if (!_existsDirectedPath_(graph, y, z) && graph.parents(y).empty()) {
          graph.addArc(z, y);
          GUM_TRACE("4.a Adding arc (" << z << "," << y << ")")
          marks[{z, y}] = '>';
          marks[{y, z}] = '-';
          if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
        } else if (!_existsDirectedPath_(graph, z, y) && graph.parents(z).empty()) {
          graph.addArc(y, z);
          GUM_TRACE("4.b Adding arc (" << y << "," << z << ")")
          marks[{z, y}] = '-';
          marks[{y, z}] = '>';
          // NOTE(review): recorded without the _isNotLatentCouple_ check used in
          // _orientingVstructureMiic_, so duplicates are possible -- confirm intent
          _latentCouples_.emplace_back(y, z);
          if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
        } else if (!_existsDirectedPath_(graph, y, z)) {
          graph.addArc(z, y);
          GUM_TRACE("4.c Adding arc (" << z << "," << y << ")")
          marks[{z, y}] = '>';
          marks[{y, z}] = '-';
          if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
        } else if (!_existsDirectedPath_(graph, z, y)) {
          graph.addArc(y, z);
          GUM_TRACE("4.d Adding arc (" << y << "," << z << ")")
          _latentCouples_.emplace_back(y, z);
          marks[{z, y}] = '-';
          marks[{y, z}] = '>';
          if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
        }
      } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o' && marks[{z, x}] != '-') {
        graph.eraseEdge(Edge(z, x));
        GUM_TRACE("5. Removing edge (" << z << "," << x << ")")
        if (!_existsDirectedPath_(graph, x, z) && graph.parents(x).empty()) {
          graph.addArc(z, x);
          GUM_TRACE("5.a Adding arc (" << z << "," << x << ")")
          marks[{z, x}] = '>';
          marks[{x, z}] = '-';
          if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
        } else if (!_existsDirectedPath_(graph, z, x) && graph.parents(z).empty()) {
          graph.addArc(x, z);
          GUM_TRACE("5.b Adding arc (" << x << "," << z << ")")
          marks[{z, x}] = '-';
          marks[{x, z}] = '>';
          _latentCouples_.emplace_back(x, z);
          if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
        } else if (!_existsDirectedPath_(graph, x, z)) {
          graph.addArc(z, x);
          GUM_TRACE("5.c Adding arc (" << z << "," << x << ")")
          marks[{z, x}] = '>';
          marks[{x, z}] = '-';
          if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
        } else if (!_existsDirectedPath_(graph, z, x)) {
          graph.addArc(x, z);
          GUM_TRACE("5.d Adding arc (" << x << "," << z << ")")
          marks[{z, x}] = '-';
          marks[{x, z}] = '>';
          _latentCouples_.emplace_back(x, z);
          if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
        }
      }
    }
1160 
_isNotLatentCouple_(const NodeId x,const NodeId y)1161     bool Miic::_isNotLatentCouple_(const NodeId x, const NodeId y) {
1162       const auto& lbeg = _latentCouples_.begin();
1163       const auto& lend = _latentCouples_.end();
1164 
1165       return (std::find(lbeg, lend, Arc(x, y)) == lend)
1166           && (std::find(lbeg, lend, Arc(y, x)) == lend);
1167     }
1168 
isForbidenArc_(NodeId x,NodeId y) const1169     bool Miic::isForbidenArc_(NodeId x, NodeId y) const {
1170       return (_initialMarks_.exists({x, y}) && _initialMarks_[{x, y}] == '-');
1171     }
1172   } /* namespace learning */
1173 
1174 } /* namespace gum */
1175