1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** 23 * @file 24 * @brief Implementation of Variable Elimination for inference in 25 * Bayesian networks. 26 * 27 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6) 28 */ 29 30 #ifndef DOXYGEN_SHOULD_SKIP_THIS 31 32 # include <agrum/BN/inference/variableElimination.h> 33 34 # include <agrum/BN/algorithms/BayesBall.h> 35 # include <agrum/BN/algorithms/dSeparation.h> 36 # include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h> 37 # include <agrum/tools/multidim/instantiation.h> 38 # include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h> 39 # include <agrum/tools/multidim/utils/operators/multiDimProjection.h> 40 41 42 namespace gum { 43 44 45 // default constructor 46 template < typename GUM_SCALAR > VariableElimination(const IBayesNet<GUM_SCALAR> * BN,RelevantPotentialsFinderType relevant_type,FindBarrenNodesType barren_type)47 INLINE VariableElimination< GUM_SCALAR >::VariableElimination( 48 const IBayesNet< GUM_SCALAR >* BN, 49 RelevantPotentialsFinderType relevant_type, 50 FindBarrenNodesType barren_type) : 51 JointTargetedInference< GUM_SCALAR >(BN) { 52 // sets the 
relevant potential and the barren nodes finding algorithm 53 setRelevantPotentialsFinderType(relevant_type); 54 setFindBarrenNodesType(barren_type); 55 56 // create a default triangulation (the user can change it afterwards) 57 _triangulation_ = new DefaultTriangulation; 58 59 // for debugging purposessetRequiredInference 60 GUM_CONSTRUCTOR(VariableElimination); 61 } 62 63 64 // destructor 65 template < typename GUM_SCALAR > ~VariableElimination()66 INLINE VariableElimination< GUM_SCALAR >::~VariableElimination() { 67 // remove the junction tree and the triangulation algorithm 68 if (_JT_ != nullptr) delete _JT_; 69 delete _triangulation_; 70 if (_target_posterior_ != nullptr) delete _target_posterior_; 71 72 // for debugging purposes 73 GUM_DESTRUCTOR(VariableElimination); 74 } 75 76 77 /// set a new triangulation algorithm 78 template < typename GUM_SCALAR > setTriangulation(const Triangulation & new_triangulation)79 void VariableElimination< GUM_SCALAR >::setTriangulation(const Triangulation& new_triangulation) { 80 delete _triangulation_; 81 _triangulation_ = new_triangulation.newFactory(); 82 } 83 84 85 /// returns the current join tree used 86 template < typename GUM_SCALAR > junctionTree(NodeId id)87 INLINE const JunctionTree* VariableElimination< GUM_SCALAR >::junctionTree(NodeId id) { 88 _createNewJT_(NodeSet{id}); 89 90 return _JT_; 91 } 92 93 94 /// sets the operator for performing the projections 95 template < typename GUM_SCALAR > _setProjectionFunction_(Potential<GUM_SCALAR> * (* proj)(const Potential<GUM_SCALAR> &,const Set<const DiscreteVariable * > &))96 INLINE void VariableElimination< GUM_SCALAR >::_setProjectionFunction_(Potential< GUM_SCALAR >* ( 97 *proj)(const Potential< GUM_SCALAR >&, const Set< const DiscreteVariable* >&)) { 98 _projection_op_ = proj; 99 } 100 101 102 /// sets the operator for performing the combinations 103 template < typename GUM_SCALAR > _setCombinationFunction_(Potential<GUM_SCALAR> * (* comb)(const 
Potential<GUM_SCALAR> &,const Potential<GUM_SCALAR> &))104 INLINE void VariableElimination< GUM_SCALAR >::_setCombinationFunction_(Potential< GUM_SCALAR >* ( 105 *comb)(const Potential< GUM_SCALAR >&, const Potential< GUM_SCALAR >&)) { 106 _combination_op_ = comb; 107 } 108 109 110 /// sets how we determine the relevant potentials to combine 111 template < typename GUM_SCALAR > setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)112 void VariableElimination< GUM_SCALAR >::setRelevantPotentialsFinderType( 113 RelevantPotentialsFinderType type) { 114 if (type != _find_relevant_potential_type_) { 115 switch (type) { 116 case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS: 117 _findRelevantPotentials_ 118 = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation2_; 119 break; 120 121 case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES: 122 _findRelevantPotentials_ 123 = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation_; 124 break; 125 126 case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009: 127 _findRelevantPotentials_ 128 = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation3_; 129 break; 130 131 case RelevantPotentialsFinderType::FIND_ALL: 132 _findRelevantPotentials_ 133 = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsGetAll_; 134 break; 135 136 default: 137 GUM_ERROR(InvalidArgument, 138 "setRelevantPotentialsFinderType for type " << (unsigned int)type 139 << " is not implemented yet"); 140 } 141 142 _find_relevant_potential_type_ = type; 143 } 144 } 145 146 147 /// sets how we determine barren nodes 148 template < typename GUM_SCALAR > setFindBarrenNodesType(FindBarrenNodesType type)149 void VariableElimination< GUM_SCALAR >::setFindBarrenNodesType(FindBarrenNodesType type) { 150 if (type != _barren_nodes_type_) { 151 // WARNING: if a new type is added here, method _createJT_ should 152 // certainly 153 // be updated as well, in particular its step 
2. 154 switch (type) { 155 case FindBarrenNodesType::FIND_BARREN_NODES: 156 case FindBarrenNodesType::FIND_NO_BARREN_NODES: 157 break; 158 159 default: 160 GUM_ERROR(InvalidArgument, 161 "setFindBarrenNodesType for type " << (unsigned int)type 162 << " is not implemented yet"); 163 } 164 165 _barren_nodes_type_ = type; 166 } 167 } 168 169 170 /// fired when a new evidence is inserted 171 template < typename GUM_SCALAR > onEvidenceAdded_(const NodeId,bool)172 INLINE void VariableElimination< GUM_SCALAR >::onEvidenceAdded_(const NodeId, bool) {} 173 174 175 /// fired when an evidence is removed 176 template < typename GUM_SCALAR > onEvidenceErased_(const NodeId,bool)177 INLINE void VariableElimination< GUM_SCALAR >::onEvidenceErased_(const NodeId, bool) {} 178 179 180 /// fired when all the evidence are erased 181 template < typename GUM_SCALAR > onAllEvidenceErased_(bool)182 void VariableElimination< GUM_SCALAR >::onAllEvidenceErased_(bool) {} 183 184 185 /// fired when an evidence is changed 186 template < typename GUM_SCALAR > onEvidenceChanged_(const NodeId,bool)187 INLINE void VariableElimination< GUM_SCALAR >::onEvidenceChanged_(const NodeId, bool) {} 188 189 190 /// fired after a new target is inserted 191 template < typename GUM_SCALAR > onMarginalTargetAdded_(const NodeId)192 INLINE void VariableElimination< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId) {} 193 194 195 /// fired before a target is removed 196 template < typename GUM_SCALAR > onMarginalTargetErased_(const NodeId)197 INLINE void VariableElimination< GUM_SCALAR >::onMarginalTargetErased_(const NodeId) {} 198 199 /// fired after a new Bayes net has been assigned to the engine 200 template < typename GUM_SCALAR > onModelChanged_(const GraphicalModel * bn)201 INLINE void VariableElimination< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {} 202 203 /// fired after a new set target is inserted 204 template < typename GUM_SCALAR > onJointTargetAdded_(const NodeSet &)205 INLINE void 
VariableElimination< GUM_SCALAR >::onJointTargetAdded_(const NodeSet&) {} 206 207 208 /// fired before a set target is removed 209 template < typename GUM_SCALAR > onJointTargetErased_(const NodeSet &)210 INLINE void VariableElimination< GUM_SCALAR >::onJointTargetErased_(const NodeSet&) {} 211 212 213 /// fired after all the nodes of the BN are added as single targets 214 template < typename GUM_SCALAR > onAllMarginalTargetsAdded_()215 INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsAdded_() {} 216 217 218 /// fired before a all the single_targets are removed 219 template < typename GUM_SCALAR > onAllMarginalTargetsErased_()220 INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsErased_() {} 221 222 223 /// fired before a all the joint_targets are removed 224 template < typename GUM_SCALAR > onAllJointTargetsErased_()225 INLINE void VariableElimination< GUM_SCALAR >::onAllJointTargetsErased_() {} 226 227 228 /// fired before a all the single and joint_targets are removed 229 template < typename GUM_SCALAR > onAllTargetsErased_()230 INLINE void VariableElimination< GUM_SCALAR >::onAllTargetsErased_() {} 231 232 233 /// create a new junction tree as well as its related data structures 234 template < typename GUM_SCALAR > _createNewJT_(const NodeSet & targets)235 void VariableElimination< GUM_SCALAR >::_createNewJT_(const NodeSet& targets) { 236 // to create the JT, we first create the moral graph of the BN in the 237 // following way in order to take into account the barren nodes and the 238 // nodes that received evidence: 239 // 1/ we create an undirected graph containing only the nodes and no edge 240 // 2/ if we take into account barren nodes, remove them from the graph 241 // 3/ if we take d-separation into account, remove the d-separated nodes 242 // 4/ add edges so that each node and its parents in the BN form a clique 243 // 5/ add edges so that the targets form a clique of the moral graph 244 // 6/ remove the nodes that 
received hard evidence (by step 4/, their 245 // parents are linked by edges, which is necessary for inference) 246 // 247 // At the end of step 6/, we have our moral graph and we can triangulate it 248 // to get the new junction tree 249 250 // 1/ create an undirected graph containing only the nodes and no edge 251 const auto& bn = this->BN(); 252 _graph_.clear(); 253 for (auto node: bn.dag()) 254 _graph_.addNodeWithId(node); 255 256 // 2/ if we wish to exploit barren nodes, we shall remove them from the BN 257 // to do so: we identify all the nodes that are not targets and have 258 // received no evidence and such that their descendants are neither targets 259 // nor evidence nodes. Such nodes can be safely discarded from the BN 260 // without altering the inference output 261 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) { 262 // check that all the nodes are not targets, otherwise, there is no 263 // barren node 264 if (targets.size() != bn.size()) { 265 BarrenNodesFinder finder(&(bn.dag())); 266 finder.setTargets(&targets); 267 268 NodeSet evidence_nodes; 269 for (const auto& pair: this->evidence()) { 270 evidence_nodes.insert(pair.first); 271 } 272 finder.setEvidence(&evidence_nodes); 273 274 NodeSet barren_nodes = finder.barrenNodes(); 275 276 // remove the barren nodes from the moral graph 277 for (const auto node: barren_nodes) { 278 _graph_.eraseNode(node); 279 } 280 } 281 } 282 283 // 3/ if we wish to exploit d-separation, remove all the nodes that are 284 // d-separated from our targets 285 { 286 NodeSet requisite_nodes; 287 bool dsep_analysis = false; 288 switch (_find_relevant_potential_type_) { 289 case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS: 290 case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES: { 291 BayesBall::requisiteNodes(bn.dag(), 292 targets, 293 this->hardEvidenceNodes(), 294 this->softEvidenceNodes(), 295 requisite_nodes); 296 dsep_analysis = true; 297 } break; 298 299 case 
RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009: { 300 dSeparation dsep; 301 dsep.requisiteNodes(bn.dag(), 302 targets, 303 this->hardEvidenceNodes(), 304 this->softEvidenceNodes(), 305 requisite_nodes); 306 dsep_analysis = true; 307 } break; 308 309 case RelevantPotentialsFinderType::FIND_ALL: 310 break; 311 312 default: 313 GUM_ERROR(FatalError, "not implemented yet") 314 } 315 316 // remove all the nodes that are not requisite 317 if (dsep_analysis) { 318 for (auto iter = _graph_.beginSafe(); iter != _graph_.endSafe(); ++iter) { 319 if (!requisite_nodes.contains(*iter) && !this->hardEvidenceNodes().contains(*iter)) { 320 _graph_.eraseNode(*iter); 321 } 322 } 323 } 324 } 325 326 // 4/ add edges so that each node and its parents in the BN form a clique 327 for (const auto node: _graph_) { 328 const NodeSet& parents = bn.parents(node); 329 for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) { 330 // before adding an edge between node and its parent, check that the 331 // parent belong to the graph. Actually, when d-separated nodes are 332 // removed, it may be the case that the parents of hard evidence nodes 333 // are removed. But the latter still exist in the graph. 334 if (_graph_.existsNode(*iter1)) _graph_.addEdge(*iter1, node); 335 336 auto iter2 = iter1; 337 for (++iter2; iter2 != parents.cend(); ++iter2) { 338 // before adding an edge, check that both extremities belong to 339 // the graph. Actually, when d-separated nodes are removed, it may 340 // be the case that the parents of hard evidence nodes are removed. 341 // But the latter still exist in the graph. 
342 if (_graph_.existsNode(*iter1) && _graph_.existsNode(*iter2)) 343 _graph_.addEdge(*iter1, *iter2); 344 } 345 } 346 } 347 348 // 5/ if targets contains several nodes, we shall add new edges into the 349 // moral graph in order to ensure that there exists a clique containing 350 // thier joint distribution 351 for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) { 352 auto iter2 = iter1; 353 for (++iter2; iter2 != targets.cend(); ++iter2) { 354 _graph_.addEdge(*iter1, *iter2); 355 } 356 } 357 358 // 6/ remove all the nodes that received hard evidence 359 for (const auto node: this->hardEvidenceNodes()) { 360 _graph_.eraseNode(node); 361 } 362 363 364 // now, we can compute the new junction tree. 365 if (_JT_ != nullptr) delete _JT_; 366 _triangulation_->setGraph(&_graph_, &(this->domainSizes())); 367 const JunctionTree& triang_jt = _triangulation_->junctionTree(); 368 _JT_ = new CliqueGraph(triang_jt); 369 370 // indicate, for each node of the moral graph a clique in _JT_ that can 371 // contain its conditional probability table 372 _node_to_clique_.clear(); 373 _clique_potentials_.clear(); 374 NodeSet emptyset; 375 for (auto clique: *_JT_) 376 _clique_potentials_.insert(clique, emptyset); 377 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder(); 378 NodeProperty< Size > elim_order(Size(JT_elim_order.size())); 379 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i) 380 elim_order.insert(JT_elim_order[i], NodeId(i)); 381 const DAG& dag = bn.dag(); 382 for (const auto node: _graph_) { 383 // get the variables in the potential of node (and its parents) 384 NodeId first_eliminated_node = node; 385 Size elim_number = elim_order[first_eliminated_node]; 386 387 for (const auto parent: dag.parents(node)) { 388 if (_graph_.existsNode(parent) && (elim_order[parent] < elim_number)) { 389 elim_number = elim_order[parent]; 390 first_eliminated_node = parent; 391 } 392 } 393 394 // first_eliminated_node 
contains the first var (node or one of its 395 // parents) eliminated => the clique created during its elimination 396 // contains node and all of its parents => it can contain the potential 397 // assigned to the node in the BN 398 NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node); 399 _node_to_clique_.insert(node, clique); 400 _clique_potentials_[clique].insert(node); 401 } 402 403 // do the same for the nodes that received evidence. Here, we only store 404 // the nodes whose at least one parent belongs to _graph_ (otherwise 405 // their CPT is just a constant real number). 406 for (const auto node: this->hardEvidenceNodes()) { 407 // get the set of parents of the node that belong to _graph_ 408 NodeSet pars(dag.parents(node).size()); 409 for (const auto par: dag.parents(node)) 410 if (_graph_.exists(par)) pars.insert(par); 411 412 if (!pars.empty()) { 413 NodeId first_eliminated_node = *(pars.begin()); 414 Size elim_number = elim_order[first_eliminated_node]; 415 416 for (const auto parent: pars) { 417 if (elim_order[parent] < elim_number) { 418 elim_number = elim_order[parent]; 419 first_eliminated_node = parent; 420 } 421 } 422 423 // first_eliminated_node contains the first var (node or one of its 424 // parents) eliminated => the clique created during its elimination 425 // contains node and all of its parents => it can contain the potential 426 // assigned to the node in the BN 427 NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node); 428 _node_to_clique_.insert(node, clique); 429 _clique_potentials_[clique].insert(node); 430 } 431 } 432 433 434 // indicate a clique that contains all the nodes of targets 435 _targets2clique_ = std::numeric_limits< NodeId >::max(); 436 { 437 // remove from set all the nodes that received hard evidence (since they 438 // do not belong to the join tree) 439 NodeSet nodeset = targets; 440 for (const auto node: this->hardEvidenceNodes()) 441 if (nodeset.contains(node)) 
nodeset.erase(node); 442 443 if (!nodeset.empty()) { 444 NodeId first_eliminated_node = *(nodeset.begin()); 445 Size elim_number = elim_order[first_eliminated_node]; 446 for (const auto node: nodeset) { 447 if (elim_order[node] < elim_number) { 448 elim_number = elim_order[node]; 449 first_eliminated_node = node; 450 } 451 } 452 _targets2clique_ = _triangulation_->createdJunctionTreeClique(first_eliminated_node); 453 } 454 } 455 } 456 457 458 /// prepare the inference structures w.r.t. new targets, soft/hard evidence 459 template < typename GUM_SCALAR > updateOutdatedStructure_()460 void VariableElimination< GUM_SCALAR >::updateOutdatedStructure_() {} 461 462 463 /// update the potentials stored in the cliques and invalidate outdated 464 /// messages 465 template < typename GUM_SCALAR > updateOutdatedPotentials_()466 void VariableElimination< GUM_SCALAR >::updateOutdatedPotentials_() {} 467 468 469 // find the potentials d-connected to a set of variables 470 template < typename GUM_SCALAR > _findRelevantPotentialsGetAll_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)471 void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsGetAll_( 472 Set< const Potential< GUM_SCALAR >* >& pot_list, 473 Set< const DiscreteVariable* >& kept_vars) {} 474 475 476 // find the potentials d-connected to a set of variables 477 template < typename GUM_SCALAR > _findRelevantPotentialsWithdSeparation_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)478 void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation_( 479 Set< const Potential< GUM_SCALAR >* >& pot_list, 480 Set< const DiscreteVariable* >& kept_vars) { 481 // find the node ids of the kept variables 482 NodeSet kept_ids; 483 const auto& bn = this->BN(); 484 for (const auto var: kept_vars) { 485 kept_ids.insert(bn.nodeId(*var)); 486 } 487 488 // determine the set of potentials d-connected with the kept variables 489 
NodeSet requisite_nodes; 490 BayesBall::requisiteNodes(bn.dag(), 491 kept_ids, 492 this->hardEvidenceNodes(), 493 this->softEvidenceNodes(), 494 requisite_nodes); 495 for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) { 496 const Sequence< const DiscreteVariable* >& vars = (**iter).variablesSequence(); 497 bool found = false; 498 for (auto var: vars) { 499 if (requisite_nodes.exists(bn.nodeId(*var))) { 500 found = true; 501 break; 502 } 503 } 504 505 if (!found) { pot_list.erase(iter); } 506 } 507 } 508 509 510 // find the potentials d-connected to a set of variables 511 template < typename GUM_SCALAR > _findRelevantPotentialsWithdSeparation2_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)512 void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation2_( 513 Set< const Potential< GUM_SCALAR >* >& pot_list, 514 Set< const DiscreteVariable* >& kept_vars) { 515 // find the node ids of the kept variables 516 NodeSet kept_ids; 517 const auto& bn = this->BN(); 518 for (const auto var: kept_vars) { 519 kept_ids.insert(bn.nodeId(*var)); 520 } 521 522 // determine the set of potentials d-connected with the kept variables 523 BayesBall::relevantPotentials(bn, 524 kept_ids, 525 this->hardEvidenceNodes(), 526 this->softEvidenceNodes(), 527 pot_list); 528 } 529 530 531 // find the potentials d-connected to a set of variables 532 template < typename GUM_SCALAR > _findRelevantPotentialsWithdSeparation3_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)533 void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation3_( 534 Set< const Potential< GUM_SCALAR >* >& pot_list, 535 Set< const DiscreteVariable* >& kept_vars) { 536 // find the node ids of the kept variables 537 NodeSet kept_ids; 538 const auto& bn = this->BN(); 539 for (const auto var: kept_vars) { 540 kept_ids.insert(bn.nodeId(*var)); 541 } 542 543 // determine the set of 
potentials d-connected with the kept variables 544 dSeparation dsep; 545 dsep.relevantPotentials(bn, 546 kept_ids, 547 this->hardEvidenceNodes(), 548 this->softEvidenceNodes(), 549 pot_list); 550 } 551 552 553 // find the potentials d-connected to a set of variables 554 template < typename GUM_SCALAR > _findRelevantPotentialsXX_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)555 void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsXX_( 556 Set< const Potential< GUM_SCALAR >* >& pot_list, 557 Set< const DiscreteVariable* >& kept_vars) { 558 switch (_find_relevant_potential_type_) { 559 case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS: 560 _findRelevantPotentialsWithdSeparation2_(pot_list, kept_vars); 561 break; 562 563 case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES: 564 _findRelevantPotentialsWithdSeparation_(pot_list, kept_vars); 565 break; 566 567 case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009: 568 _findRelevantPotentialsWithdSeparation3_(pot_list, kept_vars); 569 break; 570 571 case RelevantPotentialsFinderType::FIND_ALL: 572 _findRelevantPotentialsGetAll_(pot_list, kept_vars); 573 break; 574 575 default: 576 GUM_ERROR(FatalError, "not implemented yet") 577 } 578 } 579 580 581 // remove barren variables 582 template < typename GUM_SCALAR > _removeBarrenVariables_(_PotentialSet_ & pot_list,Set<const DiscreteVariable * > & del_vars)583 Set< const Potential< GUM_SCALAR >* > VariableElimination< GUM_SCALAR >::_removeBarrenVariables_( 584 _PotentialSet_& pot_list, 585 Set< const DiscreteVariable* >& del_vars) { 586 // remove from del_vars the variables that received some evidence: 587 // only those that did not received evidence can be barren variables 588 Set< const DiscreteVariable* > the_del_vars = del_vars; 589 for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe(); ++iter) { 590 NodeId id = this->BN().nodeId(**iter); 591 if 
(this->hardEvidenceNodes().exists(id) || this->softEvidenceNodes().exists(id)) { 592 the_del_vars.erase(iter); 593 } 594 } 595 596 // assign to each random variable the set of potentials that contain it 597 HashTable< const DiscreteVariable*, _PotentialSet_ > var2pots; 598 _PotentialSet_ empty_pot_set; 599 for (const auto pot: pot_list) { 600 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence(); 601 for (const auto var: vars) { 602 if (the_del_vars.exists(var)) { 603 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); } 604 var2pots[var].insert(pot); 605 } 606 } 607 } 608 609 // each variable with only one potential is a barren variable 610 // assign to each potential with barren nodes its set of barren variables 611 HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > > pot2barren_var; 612 Set< const DiscreteVariable* > empty_var_set; 613 for (auto elt: var2pots) { 614 if (elt.second.size() == 1) { // here we have a barren variable 615 const Potential< GUM_SCALAR >* pot = *(elt.second.begin()); 616 if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); } 617 pot2barren_var[pot].insert(elt.first); // insert the barren variable 618 } 619 } 620 621 // for each potential with barren variables, marginalize them. 
622 // if the potential has only barren variables, simply remove them from the 623 // set of potentials, else just project the potential 624 MultiDimProjection< GUM_SCALAR, Potential > projector(VENewprojPotential); 625 _PotentialSet_ projected_pots; 626 for (auto elt: pot2barren_var) { 627 // remove the current potential from pot_list as, anyway, we will change 628 // it 629 const Potential< GUM_SCALAR >* pot = elt.first; 630 pot_list.erase(pot); 631 632 // check whether we need to add a projected new potential or not (i.e., 633 // whether there exist non-barren variables or not) 634 if (pot->variablesSequence().size() != elt.second.size()) { 635 auto new_pot = projector.project(*pot, elt.second); 636 pot_list.insert(new_pot); 637 projected_pots.insert(new_pot); 638 } 639 } 640 641 return projected_pots; 642 } 643 644 645 // performs the collect phase of Lazy Propagation 646 template < typename GUM_SCALAR > 647 std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > > _collectMessage_(NodeId id,NodeId from)648 VariableElimination< GUM_SCALAR >::_collectMessage_(NodeId id, NodeId from) { 649 // collect messages from all the neighbors 650 std::pair< _PotentialSet_, _PotentialSet_ > collect_messages; 651 for (const auto other: _JT_->neighbours(id)) { 652 if (other != from) { 653 std::pair< _PotentialSet_, _PotentialSet_ > message(_collectMessage_(other, id)); 654 collect_messages.first += message.first; 655 collect_messages.second += message.second; 656 } 657 } 658 659 // combine the collect messages with those of id's clique 660 return _produceMessage_(id, from, std::move(collect_messages)); 661 } 662 663 664 // get the CPT + evidence of a node projected w.r.t. 
hard evidence 665 template < typename GUM_SCALAR > 666 std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > > _NodePotentials_(NodeId node)667 VariableElimination< GUM_SCALAR >::_NodePotentials_(NodeId node) { 668 std::pair< _PotentialSet_, _PotentialSet_ > res; 669 const auto& bn = this->BN(); 670 671 // get the CPT's of the node 672 // beware: all the potentials that are defined over some nodes 673 // including hard evidence must be projected so that these nodes are 674 // removed from the potential 675 // also beware that the CPT of a hard evidence node may be defined over 676 // parents that do not belong to _graph_ and that are not hard evidence. 677 // In this case, those parents have been removed by d-separation and it is 678 // easy to show that, in this case all the parents have been removed, so 679 // that the CPT does not need to be taken into account 680 const auto& evidence = this->evidence(); 681 const auto& hard_evidence = this->hardEvidence(); 682 if (_graph_.exists(node) || this->hardEvidenceNodes().contains(node)) { 683 const Potential< GUM_SCALAR >& cpt = bn.cpt(node); 684 const auto& variables = cpt.variablesSequence(); 685 686 // check if the parents of a hard evidence node do not belong to _graph_ 687 // and are not themselves hard evidence, discard the CPT, it is useless 688 // for inference 689 if (this->hardEvidenceNodes().contains(node)) { 690 for (const auto var: variables) { 691 NodeId xnode = bn.nodeId(*var); 692 if (!this->hardEvidenceNodes().contains(xnode) && !_graph_.existsNode(xnode)) return res; 693 } 694 } 695 696 // get the list of nodes with hard evidence in cpt 697 NodeSet hard_nodes; 698 for (const auto var: variables) { 699 NodeId xnode = bn.nodeId(*var); 700 if (this->hardEvidenceNodes().contains(xnode)) hard_nodes.insert(xnode); 701 } 702 703 // if hard_nodes contains hard evidence nodes, perform a projection 704 // and insert the result into the appropriate clique, else insert 705 // 
        // directly cpt into the clique
        if (hard_nodes.empty()) {
          // no hard evidence on the cpt's variables: use the cpt as is
          // (res.first = potentials to combine; not owned, so not in res.second)
          res.first.insert(&cpt);
        } else {
          // marginalize out the hard evidence nodes: if the cpt is defined
          // only over nodes that received hard evidence, do not consider it
          // as a potential anymore
          if (hard_nodes.size() != variables.size()) {
            // perform the projection with a combine and project instance
            Set< const DiscreteVariable* > hard_variables;
            _PotentialSet_                 marg_cpt_set{&cpt};
            for (const auto xnode: hard_nodes) {
              // multiply the cpt by the hard evidence potentials, then
              // project the product onto the non-evidence variables
              marg_cpt_set.insert(evidence[xnode]);
              hard_variables.insert(&(bn.variable(xnode)));
            }
            // perform the combination of those potentials and their projection
            MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
               _combination_op_,
               VENewprojPotential);
            _PotentialSet_ new_cpt_list
               = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

            // there should be only one potential in new_cpt_list
            if (new_cpt_list.size() != 1) {
              // remove the CPT created to avoid memory leaks
              for (auto pot: new_cpt_list) {
                // only the potentials allocated by combineAndProject are ours
                if (!marg_cpt_set.contains(pot)) delete pot;
              }
              GUM_ERROR(FatalError,
                        "the projection of a potential containing "
                           << "hard evidence is empty!");
            }
            // the projected cpt is a freshly allocated potential: record it in
            // res.second so that the caller knows it must eventually delete it
            const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
            res.first.insert(projected_cpt);
            res.second.insert(projected_cpt);
          }
        }

        // if the node received some soft evidence, add it
        if (evidence.exists(node) && !hard_evidence.exists(node)) {
          res.first.insert(this->evidence()[node]);
        }
      }

      return res;
    }


    // creates the message sent by clique from_id to clique to_id
    //
    // The returned pair follows the file-wide ownership convention:
    // - .first  : all the potentials relevant to the message (borrowed + owned)
    // - .second : the subset of .first that was allocated during inference and
    //             that the receiver must eventually delete
    // When no edge from_id--to_id exists in the join tree (notably when
    // from_id == to_id, i.e., the endpoint of a collect), the potentials are
    // returned without being marginalized.
    template < typename GUM_SCALAR >
    std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >
       VariableElimination< GUM_SCALAR >::_produceMessage_(
          NodeId from_id,
          NodeId to_id,
          std::pair< Set< const Potential< GUM_SCALAR >* >,
                     Set< const Potential< GUM_SCALAR >* > >&& incoming_messages) {
      // get the messages sent by adjacent nodes to from_id
      std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >
         pot_list(std::move(incoming_messages));

      // get the potentials of the clique
      for (const auto node: _clique_potentials_[from_id]) {
        auto new_pots = _NodePotentials_(node);
        pot_list.first += new_pots.first;
        pot_list.second += new_pots.second;
      }

      // if from_id = to_id: this is the endpoint of a collect
      if (!_JT_->existsEdge(from_id, to_id)) {
        return pot_list;
      } else {
        // get the set of variables that need be removed from the potentials:
        // those in the clique of from_id but not in the separator toward to_id
        const NodeSet&                 from_clique = _JT_->clique(from_id);
        const NodeSet&                 separator   = _JT_->separator(from_id, to_id);
        Set< const DiscreteVariable* > del_vars(from_clique.size());
        Set< const DiscreteVariable* > kept_vars(separator.size());
        const auto&                    bn = this->BN();

        for (const auto node: from_clique) {
          if (!separator.contains(node)) {
            del_vars.insert(&(bn.variable(node)));
          } else {
            kept_vars.insert(&(bn.variable(node)));
          }
        }

        // pot_list now contains all the potentials to multiply and marginalize
        // => combine the messages
        _PotentialSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);

        /*
          for the moment, remove this test: due to some optimizations, some
          potentials might have all their cells greater than 1.

        // remove all the potentials that are equal to ones (as probability
        // matrix multiplications are tensorial, such potentials are useless)
        for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
             ++iter) {
          const auto pot = *iter;
          if (pot->variablesSequence().size() == 1) {
            bool is_all_ones = true;
            for (Instantiation inst(*pot); !inst.end(); ++inst) {
              if ((*pot)[inst] < _one_minus_epsilon_) {
                is_all_ones = false;
                break;
              }
            }
            if (is_all_ones) {
              if (!pot_list.first.exists(pot)) delete pot;
              new_pot_list.erase(iter);
              continue;
            }
          }
        }
        */

        // remove the unnecessary temporary messages: temporaries that did not
        // survive the marginalization are no longer referenced by anyone
        // (safe iterators allow erasing while iterating)
        for (auto iter = pot_list.second.beginSafe(); iter != pot_list.second.endSafe(); ++iter) {
          if (!new_pot_list.contains(*iter)) {
            delete *iter;
            pot_list.second.erase(iter);
          }
        }

        // keep track of all the newly created potentials: anything produced by
        // _marginalizeOut_ that was not an input potential is owned by us
        for (const auto pot: new_pot_list) {
          if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
        }

        // return the new set of potentials
        return std::pair< _PotentialSet_, _PotentialSet_ >(std::move(new_pot_list),
                                                           std::move(pot_list.second));
      }
    }


    // remove variables del_vars from the list of potentials pot_list
    //
    // Returns the set of potentials resulting from combining the relevant
    // potentials of pot_list and projecting the result onto kept_vars.
    // Potentials freshly allocated by this method are NOT tracked here: the
    // caller distinguishes them by checking membership in its own input set.
    template < typename GUM_SCALAR >
    Set< const Potential< GUM_SCALAR >* > VariableElimination< GUM_SCALAR >::_marginalizeOut_(
       Set< const Potential< GUM_SCALAR >* > pot_list,
       Set< const DiscreteVariable* >&       del_vars,
       Set< const DiscreteVariable* >&       kept_vars) {
      // use d-separation analysis to check which potentials shall be combined
      _findRelevantPotentialsXX_(pot_list, kept_vars);

      // remove the potentials corresponding to barren variables if we want
      // to exploit barren nodes
_PotentialSet_ barren_projected_potentials; 853 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) { 854 barren_projected_potentials = _removeBarrenVariables_(pot_list, del_vars); 855 } 856 857 // create a combine and project operator that will perform the 858 // marginalization 859 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(_combination_op_, 860 _projection_op_); 861 _PotentialSet_ new_pot_list = combine_and_project.combineAndProject(pot_list, del_vars); 862 863 // remove all the potentials that were created due to projections of 864 // barren nodes and that are not part of the new_pot_list: these 865 // potentials were just temporary potentials 866 for (auto iter = barren_projected_potentials.beginSafe(); 867 iter != barren_projected_potentials.endSafe(); 868 ++iter) { 869 if (!new_pot_list.exists(*iter)) delete *iter; 870 } 871 872 // remove all the potentials that have no dimension 873 for (auto iter_pot = new_pot_list.beginSafe(); iter_pot != new_pot_list.endSafe(); ++iter_pot) { 874 if ((*iter_pot)->variablesSequence().size() == 0) { 875 // as we have already marginalized out variables that received evidence, 876 // it may be the case that, after combining and projecting, some 877 // potentials might be empty. In this case, we shall keep their 878 // constant and remove them from memory 879 // # TODO: keep the constants! 
880 delete *iter_pot; 881 new_pot_list.erase(iter_pot); 882 } 883 } 884 885 return new_pot_list; 886 } 887 888 889 // performs a whole inference 890 template < typename GUM_SCALAR > makeInference_()891 INLINE void VariableElimination< GUM_SCALAR >::makeInference_() {} 892 893 894 /// returns a fresh potential equal to P(1st arg,evidence) 895 template < typename GUM_SCALAR > 896 Potential< GUM_SCALAR >* unnormalizedJointPosterior_(NodeId id)897 VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) { 898 const auto& bn = this->BN(); 899 900 // hard evidence do not belong to the join tree 901 // # TODO: check for sets of inconsistent hard evidence 902 if (this->hardEvidenceNodes().contains(id)) { 903 return new Potential< GUM_SCALAR >(*(this->evidence()[id])); 904 } 905 906 // if we still need to perform some inference task, do it 907 _createNewJT_(NodeSet{id}); 908 NodeId clique_of_id = _node_to_clique_[id]; 909 auto pot_list = _collectMessage_(clique_of_id, clique_of_id); 910 911 // get the set of variables that need be removed from the potentials 912 const NodeSet& nodes = _JT_->clique(clique_of_id); 913 Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))}; 914 Set< const DiscreteVariable* > del_vars(nodes.size()); 915 for (const auto node: nodes) { 916 if (node != id) del_vars.insert(&(bn.variable(node))); 917 } 918 919 // pot_list now contains all the potentials to multiply and marginalize 920 // => combine the messages 921 _PotentialSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars); 922 Potential< GUM_SCALAR >* joint = nullptr; 923 924 if (new_pot_list.size() == 1) { 925 joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin())); 926 // if joint already existed, create a copy, so that we can put it into 927 // the _target_posterior_ property 928 if (pot_list.first.exists(joint)) { 929 joint = new Potential< GUM_SCALAR >(*joint); 930 } else { 931 // remove the joint from new_pot_list so that it will 
not be 932 // removed just after the else block 933 new_pot_list.clear(); 934 } 935 } else { 936 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_); 937 joint = fast_combination.combine(new_pot_list); 938 } 939 940 // remove the potentials that were created in new_pot_list 941 for (auto pot: new_pot_list) 942 if (!pot_list.first.exists(pot)) delete pot; 943 944 // remove all the temporary potentials created in pot_list 945 for (auto pot: pot_list.second) 946 delete pot; 947 948 // check that the joint posterior is different from a 0 vector: this would 949 // indicate that some hard evidence are not compatible (their joint 950 // probability is equal to 0) 951 bool nonzero_found = false; 952 for (Instantiation inst(*joint); !inst.end(); ++inst) { 953 if ((*joint)[inst]) { 954 nonzero_found = true; 955 break; 956 } 957 } 958 if (!nonzero_found) { 959 // remove joint from memory to avoid memory leaks 960 delete joint; 961 GUM_ERROR(IncompatibleEvidence, 962 "some evidence entered into the Bayes " 963 "net are incompatible (their joint proba = 0)"); 964 } 965 966 return joint; 967 } 968 969 970 /// returns the posterior of a given variable 971 template < typename GUM_SCALAR > posterior_(NodeId id)972 const Potential< GUM_SCALAR >& VariableElimination< GUM_SCALAR >::posterior_(NodeId id) { 973 // compute the joint posterior and normalize 974 auto joint = unnormalizedJointPosterior_(id); 975 if (joint->sum() != 1) // hard test for ReadOnly CPT (as aggregator) 976 joint->normalize(); 977 978 if (_target_posterior_ != nullptr) delete _target_posterior_; 979 _target_posterior_ = joint; 980 981 return *joint; 982 } 983 984 985 // returns the marginal a posteriori proba of a given node 986 template < typename GUM_SCALAR > 987 Potential< GUM_SCALAR >* unnormalizedJointPosterior_(const NodeSet & set)988 VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(const NodeSet& set) { 989 // hard evidence do not belong to the join tree, so 
      // extract the nodes
      // from targets that are not hard evidence
      NodeSet targets = set, hard_ev_nodes;
      for (const auto node: this->hardEvidenceNodes()) {
        if (targets.contains(node)) {
          targets.erase(node);
          hard_ev_nodes.insert(node);
        }
      }

      // if all the nodes have received hard evidence, then compute the
      // joint posterior directly by multiplying the hard evidence potentials
      const auto& evidence = this->evidence();
      if (targets.empty()) {
        _PotentialSet_ pot_list;
        for (const auto node: set) {
          pot_list.insert(evidence[node]);
        }
        if (pot_list.size() == 1) {
          // a single evidence potential: return a copy of it
          return new Potential< GUM_SCALAR >(**(pot_list.begin()));
        } else {
          MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
          return fast_combination.combine(pot_list);
        }
      }

      // if we still need to perform some inference task, do it
      _createNewJT_(set);
      auto pot_list = _collectMessage_(_targets2clique_, _targets2clique_);

      // get the set of variables that need be removed from the potentials:
      // everything in the target clique that is not a (soft/unobserved) target
      const NodeSet&                 nodes = _JT_->clique(_targets2clique_);
      Set< const DiscreteVariable* > del_vars(nodes.size());
      Set< const DiscreteVariable* > kept_vars(targets.size());
      const auto&                    bn = this->BN();
      for (const auto node: nodes) {
        if (!targets.contains(node)) {
          del_vars.insert(&(bn.variable(node)));
        } else {
          kept_vars.insert(&(bn.variable(node)));
        }
      }

      // pot_list now contains all the potentials to multiply and marginalize
      // => combine the messages
      _PotentialSet_           new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
      Potential< GUM_SCALAR >* joint        = nullptr;

      if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
        joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
        // if pot already existed, create a copy, so that we can put it into
        // the _target_posteriors_ property
        if (pot_list.first.exists(joint)) {
          joint = new Potential< GUM_SCALAR >(*joint);
        } else {
          // remove the joint from new_pot_list so that it will not be
          // removed just after the next else block
          new_pot_list.clear();
        }
      } else {
        // combine all the potentials in new_pot_list with all the hard evidence
        // of the nodes in set
        _PotentialSet_ new_new_pot_list = new_pot_list;
        for (const auto node: hard_ev_nodes) {
          new_new_pot_list.insert(evidence[node]);
        }
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
        joint = fast_combination.combine(new_new_pot_list);
      }

      // remove the potentials that were created in new_pot_list
      for (auto pot: new_pot_list)
        if (!pot_list.first.exists(pot)) delete pot;

      // remove all the temporary potentials created in pot_list
      for (auto pot: pot_list.second)
        delete pot;

      // check that the joint posterior is different from a 0 vector: this would
      // indicate that some hard evidence are not compatible
      bool nonzero_found = false;
      for (Instantiation inst(*joint); !inst.end(); ++inst) {
        if ((*joint)[inst]) {
          nonzero_found = true;
          break;
        }
      }
      if (!nonzero_found) {
        // remove joint from memory to avoid memory leaks
        delete joint;
        GUM_ERROR(IncompatibleEvidence,
                  "some evidence entered into the Bayes "
                  "net are incompatible (their joint proba = 0)");
      }

      return joint;
    }


    /// returns the posterior of a given set of variables
    ///
    /// The result is cached in _target_posterior_ (the previous cached
    /// posterior, if any, is deleted) and a reference to it is returned.
    template < typename GUM_SCALAR >
    const Potential< GUM_SCALAR >&
       VariableElimination< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
      // compute the joint posterior and normalize
      auto joint = unnormalizedJointPosterior_(set);
      joint->normalize();

      if (_target_posterior_ != nullptr) delete _target_posterior_;
      _target_posterior_ = joint;

      return *joint;
    }


    /// returns the posterior of a given set of variables
    ///
    /// declared_target is unused here: Variable Elimination recomputes each
    /// query from scratch, so the wanted target can be answered directly.
    template < typename GUM_SCALAR >
    const Potential< GUM_SCALAR >&
       VariableElimination< GUM_SCALAR >::jointPosterior_(const NodeSet& wanted_target,
                                                          const NodeSet& declared_target) {
      return jointPosterior_(wanted_target);
    }


} /* namespace gum */

#endif   // DOXYGEN_SHOULD_SKIP_THIS