1 2 /** 3 * 4 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe 5 * GONZALES(_at_AMU) info_at_agrum_dot_org 6 * 7 * This library is free software: you can redistribute it and/or modify 8 * it under the terms of the GNU Lesser General Public License as published by 9 * the Free Software Foundation, either version 3 of the License, or 10 * (at your option) any later version. 11 * 12 * This library is distributed in the hope that it will be useful, 13 * but WITHOUT ANY WARRANTY; without even the implied warranty of 14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 * GNU Lesser General Public License for more details. 16 * 17 * You should have received a copy of the GNU Lesser General Public License 18 * along with this library. If not, see <http://www.gnu.org/licenses/>. 19 * 20 */ 21 22 23 /** @file 24 * @brief Implementation of gum::learning::ThreeOffTwo and MIIC 25 * 26 * @author Quentin FALCAND, Marvin LASSERRE and Pierre-Henri WUILLEMIN(_at_LIP6) 27 */ 28 29 #include <agrum/tools/core/math/math_utils.h> 30 #include <agrum/tools/core/hashTable.h> 31 #include <agrum/tools/core/heap.h> 32 #include <agrum/tools/core/timer.h> 33 #include <agrum/tools/graphs/mixedGraph.h> 34 #include <agrum/BN/learning/Miic.h> 35 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h> 36 #include <agrum/tools/stattests/correctedMutualInformation.h> 37 38 39 namespace gum { 40 41 namespace learning { 42 43 /// default constructor Miic()44 Miic::Miic() : _maxLog_(100), _size_(0) { GUM_CONSTRUCTOR(Miic); } 45 46 /// default constructor with maxLog Miic(int maxLog)47 Miic::Miic(int maxLog) : _maxLog_(maxLog), _size_(0) { GUM_CONSTRUCTOR(Miic); } 48 49 /// copy constructor Miic(const Miic & from)50 Miic::Miic(const Miic& from) : ApproximationScheme(from), _size_(from._size_) { 51 GUM_CONS_CPY(Miic); 52 } 53 54 /// move constructor Miic(Miic && from)55 Miic::Miic(Miic&& from) : ApproximationScheme(std::move(from)), _size_(from._size_) { 56 GUM_CONS_MOV(Miic); 57 } 58 59 /// destructor ~Miic()60 Miic::~Miic() { GUM_DESTRUCTOR(Miic); } 61 62 /// copy operator operator =(const Miic & from)63 Miic& Miic::operator=(const Miic& from) { 64 ApproximationScheme::operator=(from); 65 return *this; 66 } 67 68 /// move operator operator =(Miic && from)69 Miic& Miic::operator=(Miic&& from) { 70 ApproximationScheme::operator=(std::move(from)); 71 return *this; 72 } 73 74 operator ()(const CondRanking & e1,const CondRanking & e2) const75 bool GreaterPairOn2nd::operator()(const CondRanking& e1, const CondRanking& e2) const { 76 return e1.second > e2.second; 77 } 78 operator ()(const Ranking & e1,const Ranking & e2) const79 bool GreaterAbsPairOn2nd::operator()(const Ranking& e1, const Ranking& e2) const { 80 return std::abs(e1.second) > std::abs(e2.second); 81 } 82 operator ()(const ProbabilisticRanking & e1,const ProbabilisticRanking & e2) const83 bool GreaterTupleOnLast::operator()(const ProbabilisticRanking& e1, 84 const ProbabilisticRanking& e2) const { 85 double p1xz = std::get< 2 >(e1); 86 double p1yz = std::get< 3 >(e1); 87 double p2xz = std::get< 2 >(e2); 88 double p2yz = std::get< 3 >(e2); 89 double I1 = std::get< 1 >(e1); 90 double I2 = std::get< 1 >(e2); 91 // First, we look at the sign of information. 92 // Then, the probability values 93 // and finally the abs value of information. 94 if ((I1 < 0 && I2 < 0) || (I1 >= 0 && I2 >= 0)) { 95 if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) { 96 return std::abs(I1) > std::abs(I2); 97 } else { 98 return std::max(p1xz, p1yz) > std::max(p2xz, p2yz); 99 } 100 } else { 101 return I1 < I2; 102 } 103 } 104 105 /// learns the structure of a MixedGraph learnMixedStructure(CorrectedMutualInformation<> & mutualInformation,MixedGraph graph)106 MixedGraph Miic::learnMixedStructure(CorrectedMutualInformation<>& mutualInformation, 107 MixedGraph graph) { 108 timer_.reset(); 109 current_step_ = 0; 110 111 // clear the vector of latent arcs to be sure 112 _latentCouples_.clear(); 113 114 /// the heap of ranks, with the score, and the NodeIds of x, y and z. 115 Heap< CondRanking, GreaterPairOn2nd > rank; 116 117 /// the variables separation sets 118 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > sep_set; 119 120 initiation_(mutualInformation, graph, sep_set, rank); 121 122 iteration_(mutualInformation, graph, sep_set, rank); 123 124 if (_useMiic_) { 125 orientationMiic_(mutualInformation, graph, sep_set); 126 } else { 127 orientation3off2_(mutualInformation, graph, sep_set); 128 } 129 130 return graph; 131 } 132 133 /* 134 * PHASE 1 : INITIATION 135 * 136 * We go over all edges and test if the variables are independent. If they 137 * are, 138 * the edge is deleted. If not, the best contributor is found. 139 */ initiation_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet,Heap<CondRanking,GreaterPairOn2nd> & rank)140 void Miic::initiation_(CorrectedMutualInformation<>& mutualInformation, 141 MixedGraph& graph, 142 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet, 143 Heap< CondRanking, GreaterPairOn2nd >& rank) { 144 NodeId x, y; 145 EdgeSet edges = graph.edges(); 146 Size steps_init = edges.size(); 147 148 for (const Edge& edge: edges) { 149 x = edge.first(); 150 y = edge.second(); 151 double Ixy = mutualInformation.score(x, y); 152 153 if (Ixy <= 0) { //< K 154 graph.eraseEdge(edge); 155 sepSet.insert(std::make_pair(x, y), _emptySet_); 156 } else { 157 findBestContributor_(x, y, _emptySet_, graph, mutualInformation, rank); 158 } 159 160 ++current_step_; 161 if (onProgress.hasListener()) { 162 GUM_EMIT3(onProgress, (current_step_ * 33) / steps_init, 0., timer_.step()); 163 } 164 } 165 } 166 167 /* 168 * PHASE 2 : ITERATION 169 * 170 * As long as we find important nodes for edges, we go over them to see if 171 * we can assess the independence of the variables. 172 */ iteration_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet,Heap<CondRanking,GreaterPairOn2nd> & rank)173 void Miic::iteration_(CorrectedMutualInformation<>& mutualInformation, 174 MixedGraph& graph, 175 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet, 176 Heap< CondRanking, GreaterPairOn2nd >& rank) { 177 // if no triples to further examine pass 178 CondRanking best; 179 180 Size steps_init = current_step_; 181 Size steps_iter = rank.size(); 182 183 try { 184 while (rank.top().second > 0.5) { 185 best = rank.pop(); 186 187 const NodeId x = std::get< 0 >(*(best.first)); 188 const NodeId y = std::get< 1 >(*(best.first)); 189 const NodeId z = std::get< 2 >(*(best.first)); 190 std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first))); 191 192 ui.push_back(z); 193 const double i_xy_ui = mutualInformation.score(x, y, ui); 194 if (i_xy_ui < 0) { 195 graph.eraseEdge(Edge(x, y)); 196 sepSet.insert(std::make_pair(x, y), std::move(ui)); 197 } else { 198 findBestContributor_(x, y, ui, graph, mutualInformation, rank); 199 } 200 201 delete best.first; 202 203 ++current_step_; 204 if (onProgress.hasListener()) { 205 GUM_EMIT3(onProgress, 206 (current_step_ * 66) / (steps_init + steps_iter), 207 0., 208 timer_.step()); 209 } 210 } 211 } catch (...) {} // here, rank is empty 212 current_step_ = steps_init + steps_iter; 213 if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 66, 0., timer_.step()); } 214 current_step_ = steps_init + steps_iter; 215 } 216 217 /* 218 * PHASE 3 : ORIENTATION 219 * 220 * Try to assess v-structures and propagate them. 221 */ orientation3off2_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)222 void Miic::orientation3off2_( 223 CorrectedMutualInformation<>& mutualInformation, 224 MixedGraph& graph, 225 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) { 226 std::vector< Ranking > triples = unshieldedTriples_(graph, mutualInformation, sepSet); 227 Size steps_orient = triples.size(); 228 Size past_steps = current_step_; 229 230 // marks always correspond to the head of the arc/edge. - is for a forbidden 231 // arc, > for a mandatory arc 232 // we start by adding the mandatory arcs 233 for (auto iter = _initialMarks_.begin(); iter != _initialMarks_.end(); ++iter) { 234 if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') { 235 graph.eraseEdge(Edge(iter.key().first, iter.key().second)); 236 graph.addArc(iter.key().first, iter.key().second); 237 } 238 } 239 240 NodeId i = 0; 241 // list of elements that we shouldn't read again, ie elements that are 242 // eligible to 243 // rule 0 after the first time they are tested, and elements on which rule 1 244 // has been applied 245 while (i < triples.size()) { 246 // if i not in do_not_reread 247 Ranking triple = triples[i]; 248 NodeId x, y, z; 249 x = std::get< 0 >(*triple.first); 250 y = std::get< 1 >(*triple.first); 251 z = std::get< 2 >(*triple.first); 252 253 std::vector< NodeId > ui; 254 std::pair< NodeId, NodeId > key = {x, y}; 255 std::pair< NodeId, NodeId > rev_key = {y, x}; 256 if (sepSet.exists(key)) { 257 ui = sepSet[key]; 258 } else if (sepSet.exists(rev_key)) { 259 ui = sepSet[rev_key]; 260 } 261 double Ixyz_ui = triple.second; 262 bool reset{false}; 263 // try Rule 0 264 if (Ixyz_ui < 0) { 265 // if ( z not in Sep[x,y]) 266 if (std::find(ui.begin(), ui.end(), z) == ui.end()) { 267 if (!graph.existsArc(x, z) && !graph.existsArc(z, x)) { 268 // when we try to add an arc to the graph, we always verify if 269 // we are allowed to do so, ie it is not a forbidden arc an it 270 // does not create a cycle 271 if (!_existsDirectedPath_(graph, z, x) && !isForbidenArc_(x, z)) { 272 reset = true; 273 graph.eraseEdge(Edge(x, z)); 274 graph.addArc(x, z); 275 } else if (_existsDirectedPath_(graph, z, x) && !isForbidenArc_(z, x)) { 276 reset = true; 277 graph.eraseEdge(Edge(x, z)); 278 // if we find a cycle, we force the competing edge 279 graph.addArc(z, x); 280 if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(z, x)) 281 == _latentCouples_.end()) { 282 _latentCouples_.emplace_back(z, x); 283 } 284 } 285 } else if (!graph.existsArc(y, z) && !graph.existsArc(z, y)) { 286 if (!_existsDirectedPath_(graph, z, y) && !isForbidenArc_(x, z)) { 287 reset = true; 288 graph.eraseEdge(Edge(y, z)); 289 graph.addArc(y, z); 290 } else if (_existsDirectedPath_(graph, z, y) && !isForbidenArc_(z, y)) { 291 reset = true; 292 graph.eraseEdge(Edge(y, z)); 293 // if we find a cycle, we force the competing edge 294 graph.addArc(z, y); 295 if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(z, y)) 296 == _latentCouples_.end()) { 297 _latentCouples_.emplace_back(z, y); 298 } 299 } 300 } else { 301 // checking if the anti-directed arc already exists, to register a 302 // latent variable 303 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) { 304 _latentCouples_.emplace_back(z, x); 305 } 306 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) { 307 _latentCouples_.emplace_back(z, y); 308 } 309 } 310 } 311 } else { // try Rule 1 312 if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) { 313 if (!_existsDirectedPath_(graph, y, z) && !isForbidenArc_(z, y)) { 314 reset = true; 315 graph.eraseEdge(Edge(z, y)); 316 graph.addArc(z, y); 317 } else if (_existsDirectedPath_(graph, y, z) && !isForbidenArc_(y, z)) { 318 reset = true; 319 graph.eraseEdge(Edge(z, y)); 320 // if we find a cycle, we force the competing edge 321 graph.addArc(y, z); 322 if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(y, z)) 323 == _latentCouples_.end()) { 324 _latentCouples_.emplace_back(y, z); 325 } 326 } 327 } 328 if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) { 329 if (!_existsDirectedPath_(graph, x, z) && !isForbidenArc_(z, x)) { 330 reset = true; 331 graph.eraseEdge(Edge(z, x)); 332 graph.addArc(z, x); 333 } else if (_existsDirectedPath_(graph, x, z) && !isForbidenArc_(x, z)) { 334 reset = true; 335 graph.eraseEdge(Edge(z, x)); 336 // if we find a cycle, we force the competing edge 337 graph.addArc(x, z); 338 if (std::find(_latentCouples_.begin(), _latentCouples_.end(), Arc(x, z)) 339 == _latentCouples_.end()) { 340 _latentCouples_.emplace_back(x, z); 341 } 342 } 343 } 344 } // if rule 0 or rule 1 345 346 // if what we want to add already exists : pass to the next triplet 347 if (reset) { 348 i = 0; 349 } else { 350 ++i; 351 } 352 if (onProgress.hasListener()) { 353 GUM_EMIT3(onProgress, 354 ((current_step_ + i) * 100) / (past_steps + steps_orient), 355 0., 356 timer_.step()); 357 } 358 } // while 359 360 // erasing the the double headed arcs 361 for (const Arc& arc: _latentCouples_) { 362 graph.eraseArc(Arc(arc.head(), arc.tail())); 363 } 364 } 365 366 /// variant trying to propagate both orientations in a bidirected arc orientationLatents_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)367 void Miic::orientationLatents_( 368 CorrectedMutualInformation<>& mutualInformation, 369 MixedGraph& graph, 370 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) { 371 std::vector< Ranking > triples = unshieldedTriples_(graph, mutualInformation, sepSet); 372 Size steps_orient = triples.size(); 373 Size past_steps = current_step_; 374 375 NodeId i = 0; 376 // list of elements that we shouldnt read again, ie elements that are 377 // eligible to 378 // rule 0 after the first time they are tested, and elements on which rule 1 379 // has been applied 380 while (i < triples.size()) { 381 // if i not in do_not_reread 382 Ranking triple = triples[i]; 383 NodeId x, y, z; 384 x = std::get< 0 >(*triple.first); 385 y = std::get< 1 >(*triple.first); 386 z = std::get< 2 >(*triple.first); 387 388 std::vector< NodeId > ui; 389 std::pair< NodeId, NodeId > key = {x, y}; 390 std::pair< NodeId, NodeId > rev_key = {y, x}; 391 if (sepSet.exists(key)) { 392 ui = sepSet[key]; 393 } else if (sepSet.exists(rev_key)) { 394 ui = sepSet[rev_key]; 395 } 396 double Ixyz_ui = triple.second; 397 // try Rule 0 398 if (Ixyz_ui < 0) { 399 // if ( z not in Sep[x,y]) 400 if (std::find(ui.begin(), ui.end(), z) == ui.end()) { 401 // if what we want to add already exists : pass 402 if ((graph.existsArc(x, z) || graph.existsArc(z, x)) 403 && (graph.existsArc(y, z) || graph.existsArc(z, y))) { 404 ++i; 405 } else { 406 i = 0; 407 graph.eraseEdge(Edge(x, z)); 408 graph.eraseEdge(Edge(y, z)); 409 // checking for cycles 410 if (graph.existsArc(z, x)) { 411 graph.eraseArc(Arc(z, x)); 412 try { 413 std::vector< NodeId > path = graph.directedPath(z, x); 414 // if we find a cycle, we force the competing edge 415 _latentCouples_.emplace_back(z, x); 416 } catch (gum::NotFound) { graph.addArc(x, z); } 417 graph.addArc(z, x); 418 } else { 419 try { 420 std::vector< NodeId > path = graph.directedPath(z, x); 421 // if we find a cycle, we force the competing edge 422 graph.addArc(z, x); 423 _latentCouples_.emplace_back(z, x); 424 } catch (gum::NotFound) { graph.addArc(x, z); } 425 } 426 if (graph.existsArc(z, y)) { 427 graph.eraseArc(Arc(z, y)); 428 try { 429 std::vector< NodeId > path = graph.directedPath(z, y); 430 // if we find a cycle, we force the competing edge 431 _latentCouples_.emplace_back(z, y); 432 } catch (gum::NotFound) { graph.addArc(y, z); } 433 graph.addArc(z, y); 434 } else { 435 try { 436 std::vector< NodeId > path = graph.directedPath(z, y); 437 // if we find a cycle, we force the competing edge 438 graph.addArc(z, y); 439 _latentCouples_.emplace_back(z, y); 440 441 } catch (gum::NotFound) { graph.addArc(y, z); } 442 } 443 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) { 444 _latentCouples_.emplace_back(z, x); 445 } 446 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) { 447 _latentCouples_.emplace_back(z, y); 448 } 449 } 450 } else { 451 ++i; 452 } 453 } else { // try Rule 1 454 bool reset{false}; 455 if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) { 456 reset = true; 457 graph.eraseEdge(Edge(z, y)); 458 try { 459 std::vector< NodeId > path = graph.directedPath(y, z); 460 // if we find a cycle, we force the competing edge 461 graph.addArc(y, z); 462 _latentCouples_.emplace_back(y, z); 463 } catch (gum::NotFound) { graph.addArc(z, y); } 464 } 465 if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) { 466 reset = true; 467 graph.eraseEdge(Edge(z, x)); 468 try { 469 std::vector< NodeId > path = graph.directedPath(x, z); 470 // if we find a cycle, we force the competing edge 471 graph.addArc(x, z); 472 _latentCouples_.emplace_back(x, z); 473 } catch (gum::NotFound) { graph.addArc(z, x); } 474 } 475 476 if (reset) { 477 i = 0; 478 } else { 479 ++i; 480 } 481 } // if rule 0 or rule 1 482 if (onProgress.hasListener()) { 483 GUM_EMIT3(onProgress, 484 ((current_step_ + i) * 100) / (past_steps + steps_orient), 485 0., 486 timer_.step()); 487 } 488 } // while 489 490 // erasing the the double headed arcs 491 for (const Arc& arc: _latentCouples_) { 492 graph.eraseArc(Arc(arc.head(), arc.tail())); 493 } 494 } 495 496 /// varient using the orientation protocol of MIIC orientationMiic_(CorrectedMutualInformation<> & mutualInformation,MixedGraph & graph,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)497 void Miic::orientationMiic_( 498 CorrectedMutualInformation<>& mutualInformation, 499 MixedGraph& graph, 500 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) { 501 // structure to store the orientations marks -, o, or >, 502 // Considers the head of the arc/edge first node -* second node 503 HashTable< std::pair< NodeId, NodeId >, char > marks = _initialMarks_; 504 505 // marks always correspond to the head of the arc/edge. - is for a forbidden 506 // arc, > for a mandatory arc 507 // we start by adding the mandatory arcs 508 for (auto iter = marks.begin(); iter != marks.end(); ++iter) { 509 if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') { 510 graph.eraseEdge(Edge(iter.key().first, iter.key().second)); 511 graph.addArc(iter.key().first, iter.key().second); 512 } 513 } 514 515 std::vector< ProbabilisticRanking > proba_triples 516 = unshieldedTriplesMiic_(graph, mutualInformation, sepSet, marks); 517 518 const Size steps_orient = proba_triples.size(); 519 Size past_steps = current_step_; 520 521 ProbabilisticRanking best; 522 if (steps_orient > 0) { best = proba_triples[0]; } 523 524 while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) { 525 const NodeId x = std::get< 0 >(*std::get< 0 >(best)); 526 const NodeId y = std::get< 1 >(*std::get< 0 >(best)); 527 const NodeId z = std::get< 2 >(*std::get< 0 >(best)); 528 529 const double i3 = std::get< 1 >(best); 530 531 const double p1 = std::get< 2 >(best); 532 const double p2 = std::get< 3 >(best); 533 if (i3 <= 0) { 534 _orientingVstructureMiic_(graph, marks, x, y, z, p1, p2); 535 } else { 536 _propagatingOrientationMiic_(graph, marks, x, y, z, p1, p2); 537 } 538 539 delete std::get< 0 >(best); 540 proba_triples.erase(proba_triples.begin()); 541 // actualisation of the list of triples 542 proba_triples = updateProbaTriples_(graph, proba_triples); 543 544 if (!proba_triples.empty()) best = proba_triples[0]; 545 546 ++current_step_; 547 if (onProgress.hasListener()) { 548 GUM_EMIT3(onProgress, 549 (current_step_ * 100) / (steps_orient + past_steps), 550 0., 551 timer_.step()); 552 } 553 } // while 554 555 // erasing the double headed arcs 556 for (auto iter = _latentCouples_.rbegin(); iter != _latentCouples_.rend(); ++iter) { 557 graph.eraseArc(Arc(iter->head(), iter->tail())); 558 if (_existsDirectedPath_(graph, iter->head(), iter->tail())) { 559 // if we find a cycle, we force the competing edge 560 graph.addArc(iter->head(), iter->tail()); 561 graph.eraseArc(Arc(iter->tail(), iter->head())); 562 *iter = Arc(iter->head(), iter->tail()); 563 } 564 } 565 566 if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 100, 0., timer_.step()); } 567 } 568 569 /// finds the best contributor node for a pair given a conditioning set findBestContributor_(NodeId x,NodeId y,const std::vector<NodeId> & ui,const MixedGraph & graph,CorrectedMutualInformation<> & mutualInformation,Heap<CondRanking,GreaterPairOn2nd> & rank)570 void Miic::findBestContributor_(NodeId x, 571 NodeId y, 572 const std::vector< NodeId >& ui, 573 const MixedGraph& graph, 574 CorrectedMutualInformation<>& mutualInformation, 575 Heap< CondRanking, GreaterPairOn2nd >& rank) { 576 double maxP = -1.0; 577 NodeId maxZ = 0; 578 579 // compute N 580 // __N = I.N(); 581 const double Ixy_ui = mutualInformation.score(x, y, ui); 582 583 for (const NodeId z: graph) { 584 // if z!=x and z!=y and z not in ui 585 if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) { 586 double Pnv; 587 double Pb; 588 589 // Computing Pnv 590 const double Ixyz_ui = mutualInformation.score(x, y, z, ui); 591 double calc_expo1 = -Ixyz_ui * M_LN2; 592 // if exponential are too high or to low, crop them at _maxLog_ 593 if (calc_expo1 > _maxLog_) { 594 Pnv = 0.0; 595 } else if (calc_expo1 < -_maxLog_) { 596 Pnv = 1.0; 597 } else { 598 Pnv = 1 / (1 + std::exp(calc_expo1)); 599 } 600 601 // Computing Pb 602 const double Ixz_ui = mutualInformation.score(x, z, ui); 603 const double Iyz_ui = mutualInformation.score(y, z, ui); 604 605 calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2; 606 double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2; 607 608 // if exponential are too high or to low, crop them at _maxLog_ 609 if (calc_expo1 > _maxLog_ || calc_expo2 > _maxLog_) { 610 Pb = 0.0; 611 } else if (calc_expo1 < -_maxLog_ && calc_expo2 < -_maxLog_) { 612 Pb = 1.0; 613 } else { 614 double expo1, expo2; 615 if (calc_expo1 < -_maxLog_) { 616 expo1 = 0.0; 617 } else { 618 expo1 = std::exp(calc_expo1); 619 } 620 if (calc_expo2 < -_maxLog_) { 621 expo2 = 0.0; 622 } else { 623 expo2 = std::exp(calc_expo2); 624 } 625 Pb = 1 / (1 + expo1 + expo2); 626 } 627 628 // Getting max(min(Pnv, pb)) 629 const double min_pnv_pb = std::min(Pnv, Pb); 630 if (min_pnv_pb > maxP) { 631 maxP = min_pnv_pb; 632 maxZ = z; 633 } 634 } // if z not in (x, y) 635 } // for z in graph.nodes 636 // storing best z in rank_ 637 CondRanking final; 638 auto tup = new CondThreePoints{x, y, maxZ, ui}; 639 final.first = tup; 640 final.second = maxP; 641 rank.insert(final); 642 } 643 644 /// gets the list of unshielded triples in the graph in decreasing value of 645 ///|I'(x, y, z|{ui})| unshieldedTriples_(const MixedGraph & graph,CorrectedMutualInformation<> & mutualInformation,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet)646 std::vector< Ranking > Miic::unshieldedTriples_( 647 const MixedGraph& graph, 648 CorrectedMutualInformation<>& mutualInformation, 649 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) { 650 std::vector< Ranking > triples; 651 for (NodeId z: graph) { 652 for (NodeId x: graph.neighbours(z)) { 653 for (NodeId y: graph.neighbours(z)) { 654 if (y < x && !graph.existsEdge(x, y)) { 655 std::vector< NodeId > ui; 656 std::pair< NodeId, NodeId > key = {x, y}; 657 std::pair< NodeId, NodeId > rev_key = {y, x}; 658 if (sepSet.exists(key)) { 659 ui = sepSet[key]; 660 } else if (sepSet.exists(rev_key)) { 661 ui = sepSet[rev_key]; 662 } 663 // remove z from ui if it's present 664 const auto iter_z_place = std::find(ui.begin(), ui.end(), z); 665 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); } 666 667 double Ixyz_ui = mutualInformation.score(x, y, z, ui); 668 Ranking triple; 669 auto tup = new ThreePoints{x, y, z}; 670 triple.first = tup; 671 triple.second = Ixyz_ui; 672 triples.push_back(triple); 673 } 674 } 675 } 676 } 677 std::sort(triples.begin(), triples.end(), GreaterAbsPairOn2nd()); 678 return triples; 679 } 680 681 /// gets the list of unshielded triples in the graph in decreasing value of 682 ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC unshieldedTriplesMiic_(const MixedGraph & graph,CorrectedMutualInformation<> & mutualInformation,const HashTable<std::pair<NodeId,NodeId>,std::vector<NodeId>> & sepSet,HashTable<std::pair<NodeId,NodeId>,char> & marks)683 std::vector< ProbabilisticRanking > Miic::unshieldedTriplesMiic_( 684 const MixedGraph& graph, 685 CorrectedMutualInformation<>& mutualInformation, 686 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet, 687 HashTable< std::pair< NodeId, NodeId >, char >& marks) { 688 std::vector< ProbabilisticRanking > triples; 689 for (NodeId z: graph) { 690 for (NodeId x: graph.neighbours(z)) { 691 for (NodeId y: graph.neighbours(z)) { 692 if (y < x && !graph.existsEdge(x, y)) { 693 std::vector< NodeId > ui; 694 std::pair< NodeId, NodeId > key = {x, y}; 695 std::pair< NodeId, NodeId > rev_key = {y, x}; 696 if (sepSet.exists(key)) { 697 ui = sepSet[key]; 698 } else if (sepSet.exists(rev_key)) { 699 ui = sepSet[rev_key]; 700 } 701 // remove z from ui if it's present 702 const auto iter_z_place = std::find(ui.begin(), ui.end(), z); 703 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); } 704 705 const double Ixyz_ui = mutualInformation.score(x, y, z, ui); 706 auto tup = new ThreePoints{x, y, z}; 707 ProbabilisticRanking triple{tup, Ixyz_ui, 0.5, 0.5}; 708 triples.push_back(triple); 709 if (!marks.exists({x, z})) { marks.insert({x, z}, 'o'); } 710 if (!marks.exists({z, x})) { marks.insert({z, x}, 'o'); } 711 if (!marks.exists({y, z})) { marks.insert({y, z}, 'o'); } 712 if (!marks.exists({z, y})) { marks.insert({z, y}, 'o'); } 713 } 714 } 715 } 716 } 717 triples = updateProbaTriples_(graph, triples); 718 std::sort(triples.begin(), triples.end(), GreaterTupleOnLast()); 719 return triples; 720 } 721 722 /// Gets the orientation probabilities like MIIC for the orientation phase 723 std::vector< ProbabilisticRanking > updateProbaTriples_(const MixedGraph & graph,std::vector<ProbabilisticRanking> probaTriples)724 Miic::updateProbaTriples_(const MixedGraph& graph, 725 std::vector< ProbabilisticRanking > probaTriples) { 726 for (auto& triple: probaTriples) { 727 NodeId x, y, z; 728 x = std::get< 0 >(*std::get< 0 >(triple)); 729 y = std::get< 1 >(*std::get< 0 >(triple)); 730 z = std::get< 2 >(*std::get< 0 >(triple)); 731 const double Ixyz = std::get< 1 >(triple); 732 double Pxz = std::get< 2 >(triple); 733 double Pyz = std::get< 3 >(triple); 734 735 if (Ixyz <= 0) { 736 const double expo = std::exp(Ixyz); 737 const double P0 = (1 + expo) / (1 + 3 * expo); 738 // distinguish between the initialization and the update process 739 if (Pxz == Pyz && Pyz == 0.5) { 740 std::get< 2 >(triple) = P0; 741 std::get< 3 >(triple) = P0; 742 } else { 743 if (graph.existsArc(x, z) && Pxz >= P0) { 744 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5; 745 } else if (graph.existsArc(y, z) && Pyz >= P0) { 746 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5; 747 } 748 } 749 } else { 750 const double expo = std::exp(-Ixyz); 751 if (graph.existsArc(x, z) && Pxz >= 0.5) { 752 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5; 753 } else if (graph.existsArc(y, z) && Pyz >= 0.5) { 754 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5; 755 } 756 } 757 } 758 std::sort(probaTriples.begin(), probaTriples.end(), GreaterTupleOnLast()); 759 return probaTriples; 760 } 761 762 /// learns the structure of an Bayesian network, ie a DAG, from an Essential 763 /// graph. learnStructure(CorrectedMutualInformation<> & I,MixedGraph initialGraph)764 DAG Miic::learnStructure(CorrectedMutualInformation<>& I, MixedGraph initialGraph) { 765 MixedGraph essentialGraph = learnMixedStructure(I, initialGraph); 766 // orientate remaining edges 767 768 const Sequence< NodeId > order = essentialGraph.topologicalOrder(); 769 770 // first, forbidden arcs force arc in the other direction 771 for (NodeId x: order) { 772 const auto nei_x = essentialGraph.neighbours(x); 773 for (NodeId y: nei_x) 774 if (isForbidenArc_(x, y)) { 775 essentialGraph.eraseEdge(Edge(x, y)); 776 if (isForbidenArc_(y, x)) { 777 GUM_TRACE("Neither arc allowed for edge (" << x << "," << y << ")") 778 } else { 779 GUM_TRACE("Forced orientation : " << y << "->" << x) 780 essentialGraph.addArc(y, x); 781 } 782 } else if (isForbidenArc_(y, x)) { 783 essentialGraph.eraseEdge(Edge(x, y)); 784 GUM_TRACE("Forced orientation : " << x << "->" << y) 785 essentialGraph.addArc(x, y); 786 } 787 } 788 GUM_TRACE(essentialGraph.toDot()); 789 790 // first, propagate existing orientations 791 bool newOrientation = true; 792 while (newOrientation) { 793 newOrientation = false; 794 for (NodeId x: order) { 795 if (!essentialGraph.parents(x).empty()) { 796 newOrientation |= propagatesRemainingOrientableEdges_(essentialGraph, x); 797 } 798 } 799 } 800 GUM_TRACE(essentialGraph.toDot()); 801 propagatesOrientationInChainOfRemainingEdges_(essentialGraph); 802 GUM_TRACE(essentialGraph.toDot()); 803 804 // then decide the orientation for double arcs 805 for (NodeId x: order) 806 for (NodeId y: essentialGraph.parents(x)) 807 if (essentialGraph.parents(y).contains(x)) { 808 GUM_TRACE(" + Resolving double arcs (poorly)") 809 essentialGraph.eraseArc(Arc(y, x)); 810 } 811 812 DAG dag; 813 for (auto node: essentialGraph) { 814 dag.addNodeWithId(node); 815 } 816 for (const Arc& arc: essentialGraph.arcs()) { 817 dag.addArc(arc.tail(), arc.head()); 818 } 819 820 return dag; 821 } 822 isOrientable_(const MixedGraph & graph,NodeId xi,NodeId xj) const823 bool Miic::isOrientable_(const MixedGraph& graph, NodeId xi, NodeId xj) const { 824 // no cycle 825 if (_existsDirectedPath_(graph, xj, xi)) { 826 GUM_TRACE("cycle(" << xi << "-" << xj << ")") 827 return false; 828 } 829 830 // R1 831 if (!(graph.parents(xi) - graph.adjacents(xj)).empty()) { 832 GUM_TRACE("R1(" << xi << "-" << xj << ")") 833 return true; 834 } 835 836 // R2 837 if (_existsDirectedPath_(graph, xi, xj)) { 838 GUM_TRACE("R2(" << xi << "-" << xj << ")") 839 return true; 840 } 841 842 // R3 843 int nbr = 0; 844 for (const auto p: graph.parents(xj)) { 845 if (!graph.mixedOrientedPath(xi, p).empty()) { 846 nbr += 1; 847 if (nbr == 2) { 848 GUM_TRACE("R3(" << xi << "-" << xj << ")") 849 return true; 850 } 851 } 852 } 853 return false; 854 } 855 propagatesOrientationInChainOfRemainingEdges_(MixedGraph & essentialGraph)856 void Miic::propagatesOrientationInChainOfRemainingEdges_(MixedGraph& essentialGraph) { 857 // then decide the orientation for remaining edges 858 while (!essentialGraph.edges().empty()) { 859 const auto& edge = *(essentialGraph.edges().begin()); 860 NodeId root = edge.first(); 861 Size size_children_root = essentialGraph.children(root).size(); 862 NodeSet visited; 863 NodeSet stack{root}; 864 // check the best root for the set of neighbours 865 while (!stack.empty()) { 866 NodeId next = *(stack.begin()); 867 stack.erase(next); 868 if (visited.contains(next)) continue; 869 if (essentialGraph.children(next).size() > size_children_root) { 870 size_children_root = essentialGraph.children(next).size(); 871 root = next; 872 } 873 for (const auto n: essentialGraph.neighbours(next)) 874 if (!stack.contains(n) && !visited.contains(n)) stack.insert(n); 875 visited.insert(next); 876 } 877 // orientation now 878 visited.clear(); 879 stack.clear(); 880 stack.insert(root); 881 while (!stack.empty()) { 882 NodeId next = *(stack.begin()); 883 stack.erase(next); 884 if (visited.contains(next)) continue; 885 const auto nei = essentialGraph.neighbours(next); 886 for (const auto n: nei) { 887 if (!stack.contains(n) && !visited.contains(n)) stack.insert(n); 888 GUM_TRACE(" + amap reasonably orientation for " << n << "->" << next); 889 essentialGraph.eraseEdge(Edge(n, next)); 890 essentialGraph.addArc(n, next); 891 } 892 visited.insert(next); 893 } 894 } 895 } 896 897 /// Propagates the orientation from a node to its neighbours propagatesRemainingOrientableEdges_(MixedGraph & graph,NodeId xj)898 bool Miic::propagatesRemainingOrientableEdges_(MixedGraph& graph, NodeId xj) { 899 bool res = false; 900 const auto neighbours = graph.neighbours(xj); 901 for (auto& xi: neighbours) { 902 bool i_j = isOrientable_(graph, xi, xj); 903 bool j_i = isOrientable_(graph, xj, xi); 904 if (i_j || j_i) { 905 GUM_TRACE(" + Removing edge (" << xi << "," << xj << ")") 906 graph.eraseEdge(Edge(xi, xj)); 907 res = true; 908 } 909 if (i_j) { 910 GUM_TRACE(" + add arc (" << xi << "," << xj << ")") 911 graph.addArc(xi, xj); 912 propagatesRemainingOrientableEdges_(graph, xj); 913 } 914 if (j_i) { 915 GUM_TRACE(" + add arc (" << xi << "," << xj << ")") 916 graph.addArc(xj, xi); 917 propagatesRemainingOrientableEdges_(graph, xi); 918 } 919 if (i_j && j_i) { 920 GUM_TRACE(" + add arc (" << xi << "," << xj << ")") 921 _latentCouples_.emplace_back(xi, xj); 922 } 923 } 924 925 return res; 926 } 927 928 /// get the list of arcs hiding latent variables latentVariables() const929 const std::vector< Arc > Miic::latentVariables() const { return _latentCouples_; } 930 931 /// learns the structure and the parameters of a BN 932 template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR > learnBN(GRAPH_CHANGES_SELECTOR & selector,PARAM_ESTIMATOR & estimator,DAG initial_dag)933 BayesNet< GUM_SCALAR > Miic::learnBN(GRAPH_CHANGES_SELECTOR& selector, 934 PARAM_ESTIMATOR& estimator, 935 DAG initial_dag) { 936 return DAG2BNLearner<>::createBN< GUM_SCALAR >(estimator, 937 learnStructure(selector, initial_dag)); 938 } 939 setMiicBehaviour()940 void Miic::setMiicBehaviour() { this->_useMiic_ = true; } 941 set3of2Behaviour()942 void Miic::set3of2Behaviour() { this->_useMiic_ = false; } 943 addConstraints(HashTable<std::pair<NodeId,NodeId>,char> constraints)944 void Miic::addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints) { 945 this->_initialMarks_ = constraints; 946 } 947 _existsNonTrivialDirectedPath_(const MixedGraph & graph,const NodeId n1,const NodeId n2)948 bool Miic::_existsNonTrivialDirectedPath_(const MixedGraph& graph, 949 const NodeId n1, 950 const NodeId n2) { 951 for (const auto parent: graph.parents(n2)) { 952 if (graph.existsArc(parent, 953 n2)) // if there is a double arc, pass 954 continue; 955 if (parent == n1) // trivial directed path => not recognized 956 continue; 957 if (_existsDirectedPath_(graph, n1, parent)) return true; 958 } 959 return false; 960 } 961 _existsDirectedPath_(const MixedGraph & graph,const NodeId n1,const NodeId n2)962 bool Miic::_existsDirectedPath_(const MixedGraph& graph, const NodeId n1, const NodeId n2) { 963 // not recursive version => use a FIFO for simulating the recursion 964 List< NodeId > nodeFIFO; 965 // mark[node] = successor if visited, else mark[node] does not exist 966 Set< NodeId > mark; 967 968 mark.insert(n2); 969 nodeFIFO.pushBack(n2); 970 971 NodeId current; 972 973 while (!nodeFIFO.empty()) { 974 current = nodeFIFO.front(); 975 nodeFIFO.popFront(); 976 977 // check the parents 978 for (const auto new_one: graph.parents(current)) { 979 if (graph.existsArc(current, 980 new_one)) // if there is a double arc, pass 981 continue; 982 983 if (new_one == n1) { return true; } 984 985 if (mark.exists(new_one)) // if this node is already marked, do not 986 continue; // check it again 987 988 mark.insert(new_one); 989 nodeFIFO.pushBack(new_one); 990 } 991 } 992 993 return false; 994 } 995 _orientingVstructureMiic_(MixedGraph & graph,HashTable<std::pair<NodeId,NodeId>,char> & marks,NodeId x,NodeId y,NodeId z,double p1,double p2)996 void Miic::_orientingVstructureMiic_(MixedGraph& graph, 997 HashTable< std::pair< NodeId, NodeId >, char >& marks, 998 NodeId x, 999 NodeId y, 1000 NodeId z, 1001 double p1, 1002 double p2) { 1003 // v-structure discovery 1004 if (marks[{x, z}] == 'o' && marks[{y, z}] == 'o') { // If x-z-y 1005 if (!_existsNonTrivialDirectedPath_(graph, z, x)) { 1006 graph.eraseEdge(Edge(x, z)); 1007 graph.addArc(x, z); 1008 GUM_TRACE("1.a Removing edge (" << x << "," << z << ")") 1009 GUM_TRACE("1.a Adding arc (" << x << "," << z << ")") 1010 marks[{x, z}] = '>'; 1011 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) { 1012 GUM_TRACE("Adding latent couple (" << z << "," << x << ")") 1013 _latentCouples_.emplace_back(z, x); 1014 } 1015 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1); 1016 } else { 1017 graph.eraseEdge(Edge(x, z)); 1018 GUM_TRACE("1.b Adding arc (" << x << "," << z << ")") 1019 if (!_existsNonTrivialDirectedPath_(graph, x, z)) { 1020 graph.addArc(z, x); 1021 GUM_TRACE("1.b Removing edge (" << x << "," << z << ")") 1022 marks[{z, x}] = '>'; 1023 } 1024 } 1025 1026 if (!_existsNonTrivialDirectedPath_(graph, z, y)) { 1027 graph.eraseEdge(Edge(y, z)); 1028 graph.addArc(y, z); 1029 GUM_TRACE("1.c Removing edge (" << y << "," << z << ")") 1030 GUM_TRACE("1.c Adding arc (" << y << "," << z << ")") 1031 marks[{y, z}] = '>'; 1032 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) { 1033 _latentCouples_.emplace_back(z, y); 1034 } 1035 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2); 1036 } else { 1037 graph.eraseEdge(Edge(y, z)); 1038 GUM_TRACE("1.d Removing edge (" << y << "," << z << ")") 1039 if (!_existsNonTrivialDirectedPath_(graph, y, z)) { 1040 graph.addArc(z, y); 1041 GUM_TRACE("1.d Adding arc (" << z << "," << y << ")") 1042 marks[{z, y}] = '>'; 1043 } 1044 } 1045 } else if (marks[{x, z}] == '>' && marks[{y, z}] == 'o') { // If x->z-y 1046 if (!_existsNonTrivialDirectedPath_(graph, z, y)) { 1047 graph.eraseEdge(Edge(y, z)); 1048 graph.addArc(y, z); 1049 GUM_TRACE("2.a Removing edge (" << y << "," << z << ")") 1050 GUM_TRACE("2.a Adding arc (" << y << "," << z << ")") 1051 marks[{y, z}] = '>'; 1052 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) { 1053 _latentCouples_.emplace_back(z, y); 1054 } 1055 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2); 1056 } else { 1057 graph.eraseEdge(Edge(y, z)); 1058 GUM_TRACE("2.b Removing edge (" << y << "," << z << ")") 1059 if (!_existsNonTrivialDirectedPath_(graph, y, z)) { 1060 graph.addArc(z, y); 1061 GUM_TRACE("2.b Adding arc (" << y << "," << z << ")") 1062 marks[{z, y}] = '>'; 1063 } 1064 } 1065 } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o') { 1066 if (!_existsNonTrivialDirectedPath_(graph, z, x)) { 1067 graph.eraseEdge(Edge(x, z)); 1068 graph.addArc(x, z); 1069 GUM_TRACE("3.a Removing edge (" << x << "," << z << ")") 1070 GUM_TRACE("3.a Adding arc (" << x << "," << z << ")") 1071 marks[{x, z}] = '>'; 1072 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) { 1073 _latentCouples_.emplace_back(z, x); 1074 } 1075 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1); 1076 } else { 1077 graph.eraseEdge(Edge(x, z)); 1078 GUM_TRACE("3.b Removing edge (" << x << "," << z << ")") 1079 if (!_existsNonTrivialDirectedPath_(graph, x, z)) { 1080 graph.addArc(z, x); 1081 GUM_TRACE("3.b Adding arc (" << x << "," << z << ")") 1082 marks[{z, x}] = '>'; 1083 } 1084 } 1085 } 1086 } 1087 1088 _propagatingOrientationMiic_(MixedGraph & graph,HashTable<std::pair<NodeId,NodeId>,char> & marks,NodeId x,NodeId y,NodeId z,double p1,double p2)1089 void Miic::_propagatingOrientationMiic_(MixedGraph& graph, 1090 HashTable< std::pair< NodeId, NodeId >, char >& marks, 1091 NodeId x, 1092 NodeId y, 1093 NodeId z, 1094 double p1, 1095 double p2) { 1096 // orientation propagation 1097 if (marks[{x, z}] == '>' && marks[{y, z}] == 'o' && marks[{z, y}] != '-') { 1098 graph.eraseEdge(Edge(z, y)); 1099 // std::cout << "4. Removing edge (" << z << "," << y << ")" << 1100 // std::endl; 1101 if (!_existsDirectedPath_(graph, y, z) && graph.parents(y).empty()) { 1102 graph.addArc(z, y); 1103 GUM_TRACE("4.a Adding arc (" << z << "," << y << ")") 1104 marks[{z, y}] = '>'; 1105 marks[{y, z}] = '-'; 1106 if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2); 1107 } else if (!_existsDirectedPath_(graph, z, y) && graph.parents(z).empty()) { 1108 graph.addArc(y, z); 1109 GUM_TRACE("4.b Adding arc (" << y << "," << z << ")") 1110 marks[{z, y}] = '-'; 1111 marks[{y, z}] = '>'; 1112 _latentCouples_.emplace_back(y, z); 1113 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2); 1114 } else if (!_existsDirectedPath_(graph, y, z)) { 1115 graph.addArc(z, y); 1116 GUM_TRACE("4.c Adding arc (" << z << "," << y << ")") 1117 marks[{z, y}] = '>'; 1118 marks[{y, z}] = '-'; 1119 if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2); 1120 } else if (!_existsDirectedPath_(graph, z, y)) { 1121 graph.addArc(y, z); 1122 GUM_TRACE("4.d Adding arc (" << y << "," << z << ")") 1123 _latentCouples_.emplace_back(y, z); 1124 marks[{z, y}] = '-'; 1125 marks[{y, z}] = '>'; 1126 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2); 1127 } 1128 } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o' && marks[{z, x}] != '-') { 1129 graph.eraseEdge(Edge(z, x)); 1130 GUM_TRACE("5. Removing edge (" << z << "," << x << ")") 1131 if (!_existsDirectedPath_(graph, x, z) && graph.parents(x).empty()) { 1132 graph.addArc(z, x); 1133 GUM_TRACE("5.a Adding arc (" << z << "," << x << ")") 1134 marks[{z, x}] = '>'; 1135 marks[{x, z}] = '-'; 1136 if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1); 1137 } else if (!_existsDirectedPath_(graph, z, x) && graph.parents(z).empty()) { 1138 graph.addArc(x, z); 1139 GUM_TRACE("5.b Adding arc (" << x << "," << z << ")") 1140 marks[{z, x}] = '-'; 1141 marks[{x, z}] = '>'; 1142 _latentCouples_.emplace_back(x, z); 1143 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1); 1144 } else if (!_existsDirectedPath_(graph, x, z)) { 1145 graph.addArc(z, x); 1146 GUM_TRACE("5.c Adding arc (" << z << "," << x << ")") 1147 marks[{z, x}] = '>'; 1148 marks[{x, z}] = '-'; 1149 if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1); 1150 } else if (!_existsDirectedPath_(graph, z, x)) { 1151 graph.addArc(x, z); 1152 GUM_TRACE("5.d Adding arc (" << x << "," << z << ")") 1153 marks[{z, x}] = '-'; 1154 marks[{x, z}] = '>'; 1155 _latentCouples_.emplace_back(x, z); 1156 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1); 1157 } 1158 } 1159 } 1160 _isNotLatentCouple_(const NodeId x,const NodeId y)1161 bool Miic::_isNotLatentCouple_(const NodeId x, const NodeId y) { 1162 const auto& lbeg = _latentCouples_.begin(); 1163 const auto& lend = _latentCouples_.end(); 1164 1165 return (std::find(lbeg, lend, Arc(x, y)) == lend) 1166 && (std::find(lbeg, lend, Arc(y, x)) == lend); 1167 } 1168 isForbidenArc_(NodeId x,NodeId y) const1169 bool Miic::isForbidenArc_(NodeId x, NodeId y) const { 1170 return (_initialMarks_.exists({x, y}) && _initialMarks_[{x, y}] == '-'); 1171 } 1172 } /* namespace learning */ 1173 1174 } /* namespace gum */ 1175