1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of Variable Elimination for inference in
25  * Bayesian networks.
26  *
27  * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6)
28  */
29 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 
32 #  include <agrum/BN/inference/variableElimination.h>
33 
34 #  include <agrum/BN/algorithms/BayesBall.h>
35 #  include <agrum/BN/algorithms/dSeparation.h>
36 #  include <agrum/tools/graphs/algorithms/binaryJoinTreeConverterDefault.h>
37 #  include <agrum/tools/multidim/instantiation.h>
38 #  include <agrum/tools/multidim/utils/operators/multiDimCombineAndProjectDefault.h>
39 #  include <agrum/tools/multidim/utils/operators/multiDimProjection.h>
40 
41 
42 namespace gum {
43 
44 
  // default constructor
  // @param BN the Bayesian network on which inference will be performed
  // @param relevant_type the strategy used to prune the set of potentials
  //        d-connected to the queried variables
  // @param barren_type the strategy used to detect and discard barren nodes
  template < typename GUM_SCALAR >
  INLINE VariableElimination< GUM_SCALAR >::VariableElimination(
     const IBayesNet< GUM_SCALAR >* BN,
     RelevantPotentialsFinderType   relevant_type,
     FindBarrenNodesType            barren_type) :
      JointTargetedInference< GUM_SCALAR >(BN) {
    // sets the relevant potential and the barren nodes finding algorithm
    setRelevantPotentialsFinderType(relevant_type);
    setFindBarrenNodesType(barren_type);

    // create a default triangulation (the user can change it afterwards
    // through setTriangulation)
    _triangulation_ = new DefaultTriangulation;

    // for debugging purposes
    GUM_CONSTRUCTOR(VariableElimination);
  }
62 
63 
  // destructor: releases all the data structures owned by the engine
  template < typename GUM_SCALAR >
  INLINE VariableElimination< GUM_SCALAR >::~VariableElimination() {
    // remove the junction tree and the triangulation algorithm
    // ( _JT_ and  _target_posterior_ may be nullptr when no inference
    //  has been performed yet, hence the tests)
    if (_JT_ != nullptr) delete _JT_;
    delete _triangulation_;
    if (_target_posterior_ != nullptr) delete _target_posterior_;

    // for debugging purposes
    GUM_DESTRUCTOR(VariableElimination);
  }
75 
76 
77   /// set a new triangulation algorithm
78   template < typename GUM_SCALAR >
setTriangulation(const Triangulation & new_triangulation)79   void VariableElimination< GUM_SCALAR >::setTriangulation(const Triangulation& new_triangulation) {
80     delete _triangulation_;
81     _triangulation_ = new_triangulation.newFactory();
82   }
83 
84 
  /// returns the current join tree used
  /// @param id the single target node for which the junction tree is built;
  ///        the tree is recomputed from scratch for the target set {id}
  template < typename GUM_SCALAR >
  INLINE const JunctionTree* VariableElimination< GUM_SCALAR >::junctionTree(NodeId id) {
    // rebuild  _JT_ w.r.t. the sole target id (VE recreates its junction
    // tree for each query)
    _createNewJT_(NodeSet{id});

    return _JT_;
  }
92 
93 
  /// sets the operator for performing the projections
  /// @param proj a function that, given a potential and a set of variables,
  ///        returns a new potential with those variables marginalized out
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::_setProjectionFunction_(Potential< GUM_SCALAR >* (
     *proj)(const Potential< GUM_SCALAR >&, const Set< const DiscreteVariable* >&)) {
    _projection_op_ = proj;
  }
100 
101 
  /// sets the operator for performing the combinations
  /// @param comb a function that, given two potentials, returns a new
  ///        potential equal to their combination (typically their product)
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::_setCombinationFunction_(Potential< GUM_SCALAR >* (
     *comb)(const Potential< GUM_SCALAR >&, const Potential< GUM_SCALAR >&)) {
    _combination_op_ = comb;
  }
108 
109 
110   /// sets how we determine the relevant potentials to combine
111   template < typename GUM_SCALAR >
setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)112   void VariableElimination< GUM_SCALAR >::setRelevantPotentialsFinderType(
113      RelevantPotentialsFinderType type) {
114     if (type != _find_relevant_potential_type_) {
115       switch (type) {
116         case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
117           _findRelevantPotentials_
118              = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation2_;
119           break;
120 
121         case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
122           _findRelevantPotentials_
123              = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation_;
124           break;
125 
126         case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
127           _findRelevantPotentials_
128              = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation3_;
129           break;
130 
131         case RelevantPotentialsFinderType::FIND_ALL:
132           _findRelevantPotentials_
133              = &VariableElimination< GUM_SCALAR >::_findRelevantPotentialsGetAll_;
134           break;
135 
136         default:
137           GUM_ERROR(InvalidArgument,
138                     "setRelevantPotentialsFinderType for type " << (unsigned int)type
139                                                                 << " is not implemented yet");
140       }
141 
142       _find_relevant_potential_type_ = type;
143     }
144   }
145 
146 
147   /// sets how we determine barren nodes
148   template < typename GUM_SCALAR >
setFindBarrenNodesType(FindBarrenNodesType type)149   void VariableElimination< GUM_SCALAR >::setFindBarrenNodesType(FindBarrenNodesType type) {
150     if (type != _barren_nodes_type_) {
151       // WARNING: if a new type is added here, method  _createJT_ should
152       // certainly
153       // be updated as well, in particular its step 2.
154       switch (type) {
155         case FindBarrenNodesType::FIND_BARREN_NODES:
156         case FindBarrenNodesType::FIND_NO_BARREN_NODES:
157           break;
158 
159         default:
160           GUM_ERROR(InvalidArgument,
161                     "setFindBarrenNodesType for type " << (unsigned int)type
162                                                        << " is not implemented yet");
163       }
164 
165       _barren_nodes_type_ = type;
166     }
167   }
168 
169 
  /// fired when a new evidence is inserted
  // intentionally empty: VE rebuilds all its data structures at each query,
  // so there is nothing to invalidate incrementally
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onEvidenceAdded_(const NodeId, bool) {}
173 
174 
  /// fired when an evidence is removed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onEvidenceErased_(const NodeId, bool) {}
178 
179 
  /// fired when all the evidence are erased
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  void VariableElimination< GUM_SCALAR >::onAllEvidenceErased_(bool) {}
183 
184 
  /// fired when an evidence is changed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onEvidenceChanged_(const NodeId, bool) {}
188 
189 
  /// fired after a new target is inserted
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId) {}
193 
194 
  /// fired before a target is removed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onMarginalTargetErased_(const NodeId) {}
198 
  /// fired after a new Bayes net has been assigned to the engine
  // intentionally empty: VE recomputes everything per query (parameter bn is
  // deliberately unused)
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {}
202 
  /// fired after a new set target is inserted
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onJointTargetAdded_(const NodeSet&) {}
206 
207 
  /// fired before a set target is removed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onJointTargetErased_(const NodeSet&) {}
211 
212 
  /// fired after all the nodes of the BN are added as single targets
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}
216 
217 
  /// fired before a all the single_targets are removed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsErased_() {}
221 
222 
  /// fired before a all the joint_targets are removed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onAllJointTargetsErased_() {}
226 
227 
  /// fired before a all the single and joint_targets are removed
  // intentionally empty: VE recomputes everything per query
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::onAllTargetsErased_() {}
231 
232 
  /// create a new junction tree as well as its related data structures
  /// @param targets the set of query nodes; they are forced to lie within a
  ///        single clique of the resulting junction tree
  // Side effects: rebuilds  _graph_,  _JT_,  _node_to_clique_,
  //  _clique_potentials_ and  _targets2clique_.
  template < typename GUM_SCALAR >
  void VariableElimination< GUM_SCALAR >::_createNewJT_(const NodeSet& targets) {
    // to create the JT, we first create the moral graph of the BN in the
    // following way in order to take into account the barren nodes and the
    // nodes that received evidence:
    // 1/ we create an undirected graph containing only the nodes and no edge
    // 2/ if we take into account barren nodes, remove them from the graph
    // 3/ if we take d-separation into account, remove the d-separated nodes
    // 4/ add edges so that each node and its parents in the BN form a clique
    // 5/ add edges so that the targets form a clique of the moral graph
    // 6/ remove the nodes that received hard evidence (by step 4/, their
    //    parents are linked by edges, which is necessary for inference)
    //
    // At the end of step 6/, we have our moral graph and we can triangulate it
    // to get the new junction tree

    // 1/ create an undirected graph containing only the nodes and no edge
    const auto& bn = this->BN();
    _graph_.clear();
    for (auto node: bn.dag())
      _graph_.addNodeWithId(node);

    // 2/ if we wish to exploit barren nodes, we shall remove them from the BN
    // to do so: we identify all the nodes that are not targets and have
    // received no evidence and such that their descendants are neither targets
    // nor evidence nodes. Such nodes can be safely discarded from the BN
    // without altering the inference output
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      // check that all the nodes are not targets, otherwise, there is no
      // barren node
      if (targets.size() != bn.size()) {
        BarrenNodesFinder finder(&(bn.dag()));
        finder.setTargets(&targets);

        // both soft and hard evidence nodes block the "barren" property
        NodeSet evidence_nodes;
        for (const auto& pair: this->evidence()) {
          evidence_nodes.insert(pair.first);
        }
        finder.setEvidence(&evidence_nodes);

        NodeSet barren_nodes = finder.barrenNodes();

        // remove the barren nodes from the moral graph
        for (const auto node: barren_nodes) {
          _graph_.eraseNode(node);
        }
      }
    }

    // 3/ if we wish to exploit d-separation, remove all the nodes that are
    // d-separated from our targets
    {
      NodeSet requisite_nodes;
      bool    dsep_analysis = false;
      switch (_find_relevant_potential_type_) {
        case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
        case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES: {
          BayesBall::requisiteNodes(bn.dag(),
                                    targets,
                                    this->hardEvidenceNodes(),
                                    this->softEvidenceNodes(),
                                    requisite_nodes);
          dsep_analysis = true;
        } break;

        case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009: {
          dSeparation dsep;
          dsep.requisiteNodes(bn.dag(),
                              targets,
                              this->hardEvidenceNodes(),
                              this->softEvidenceNodes(),
                              requisite_nodes);
          dsep_analysis = true;
        } break;

        case RelevantPotentialsFinderType::FIND_ALL:
          // keep every node: no d-separation analysis performed
          break;

        default:
          GUM_ERROR(FatalError, "not implemented yet")
      }

      // remove all the nodes that are not requisite
      // (hard evidence nodes are kept here: they are needed by step 4/ and
      // only removed at step 6/)
      if (dsep_analysis) {
        for (auto iter = _graph_.beginSafe(); iter != _graph_.endSafe(); ++iter) {
          if (!requisite_nodes.contains(*iter) && !this->hardEvidenceNodes().contains(*iter)) {
            _graph_.eraseNode(*iter);
          }
        }
      }
    }

    // 4/ add edges so that each node and its parents in the BN form a clique
    for (const auto node: _graph_) {
      const NodeSet& parents = bn.parents(node);
      for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
        // before adding an edge between node and its parent, check that the
        // parent belong to the graph. Actually, when d-separated nodes are
        // removed, it may be the case that the parents of hard evidence nodes
        // are removed. But the latter still exist in the graph.
        if (_graph_.existsNode(*iter1)) _graph_.addEdge(*iter1, node);

        auto iter2 = iter1;
        for (++iter2; iter2 != parents.cend(); ++iter2) {
          // before adding an edge, check that both extremities belong to
          // the graph. Actually, when d-separated nodes are removed, it may
          // be the case that the parents of hard evidence nodes are removed.
          // But the latter still exist in the graph.
          if (_graph_.existsNode(*iter1) && _graph_.existsNode(*iter2))
            _graph_.addEdge(*iter1, *iter2);
        }
      }
    }

    // 5/ if targets contains several nodes, we shall add new edges into the
    // moral graph in order to ensure that there exists a clique containing
    // their joint distribution
    for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
      auto iter2 = iter1;
      for (++iter2; iter2 != targets.cend(); ++iter2) {
        _graph_.addEdge(*iter1, *iter2);
      }
    }

    // 6/ remove all the nodes that received hard evidence
    for (const auto node: this->hardEvidenceNodes()) {
      _graph_.eraseNode(node);
    }


    // now, we can compute the new junction tree.
    if (_JT_ != nullptr) delete _JT_;
    _triangulation_->setGraph(&_graph_, &(this->domainSizes()));
    const JunctionTree& triang_jt = _triangulation_->junctionTree();
    _JT_                          = new CliqueGraph(triang_jt);

    // indicate, for each node of the moral graph a clique in  _JT_ that can
    // contain its conditional probability table
    _node_to_clique_.clear();
    _clique_potentials_.clear();
    NodeSet emptyset;
    for (auto clique: *_JT_)
      _clique_potentials_.insert(clique, emptyset);
    // elim_order maps each node id to its rank in the elimination ordering
    const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
    NodeProperty< Size >         elim_order(Size(JT_elim_order.size()));
    for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
      elim_order.insert(JT_elim_order[i], NodeId(i));
    const DAG& dag = bn.dag();
    for (const auto node: _graph_) {
      // get the variables in the potential of node (and its parents)
      NodeId first_eliminated_node = node;
      Size   elim_number           = elim_order[first_eliminated_node];

      // find, among node and its parents still in  _graph_, the one that is
      // eliminated first by the triangulation
      for (const auto parent: dag.parents(node)) {
        if (_graph_.existsNode(parent) && (elim_order[parent] < elim_number)) {
          elim_number           = elim_order[parent];
          first_eliminated_node = parent;
        }
      }

      // first_eliminated_node contains the first var (node or one of its
      // parents) eliminated => the clique created during its elimination
      // contains node and all of its parents => it can contain the potential
      // assigned to the node in the BN
      NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
      _node_to_clique_.insert(node, clique);
      _clique_potentials_[clique].insert(node);
    }

    // do the same for the nodes that received evidence. Here, we only store
    // the nodes whose at least one parent belongs to  _graph_ (otherwise
    // their CPT is just a constant real number).
    for (const auto node: this->hardEvidenceNodes()) {
      // get the set of parents of the node that belong to  _graph_
      NodeSet pars(dag.parents(node).size());
      for (const auto par: dag.parents(node))
        if (_graph_.exists(par)) pars.insert(par);

      if (!pars.empty()) {
        NodeId first_eliminated_node = *(pars.begin());
        Size   elim_number           = elim_order[first_eliminated_node];

        for (const auto parent: pars) {
          if (elim_order[parent] < elim_number) {
            elim_number           = elim_order[parent];
            first_eliminated_node = parent;
          }
        }

        // first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // contains node and all of its parents => it can contain the potential
        // assigned to the node in the BN
        NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
        _node_to_clique_.insert(node, clique);
        _clique_potentials_[clique].insert(node);
      }
    }


    // indicate a clique that contains all the nodes of targets
    // (max() acts as a "no clique" sentinel when all targets received hard
    // evidence)
    _targets2clique_ = std::numeric_limits< NodeId >::max();
    {
      // remove from set all the nodes that received hard evidence (since they
      // do not belong to the join tree)
      NodeSet nodeset = targets;
      for (const auto node: this->hardEvidenceNodes())
        if (nodeset.contains(node)) nodeset.erase(node);

      if (!nodeset.empty()) {
        NodeId first_eliminated_node = *(nodeset.begin());
        Size   elim_number           = elim_order[first_eliminated_node];
        for (const auto node: nodeset) {
          if (elim_order[node] < elim_number) {
            elim_number           = elim_order[node];
            first_eliminated_node = node;
          }
        }
        _targets2clique_ = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
      }
    }
  }
456 
457 
  /// prepare the inference structures w.r.t. new targets, soft/hard evidence
  // intentionally empty: VE rebuilds its junction tree from scratch inside
  // each query, so no incremental structure update is needed
  template < typename GUM_SCALAR >
  void VariableElimination< GUM_SCALAR >::updateOutdatedStructure_() {}
461 
462 
  /// update the potentials stored in the cliques and invalidate outdated
  /// messages
  // intentionally empty: VE recomputes all clique potentials per query
  template < typename GUM_SCALAR >
  void VariableElimination< GUM_SCALAR >::updateOutdatedPotentials_() {}
467 
468 
  // find the potentials d-connected to a set of variables
  // FIND_ALL strategy: deliberately a no-op, i.e., every potential in
  // pot_list is considered relevant and nothing is pruned
  template < typename GUM_SCALAR >
  void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsGetAll_(
     Set< const Potential< GUM_SCALAR >* >& pot_list,
     Set< const DiscreteVariable* >&        kept_vars) {}
474 
475 
476   // find the potentials d-connected to a set of variables
477   template < typename GUM_SCALAR >
_findRelevantPotentialsWithdSeparation_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)478   void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation_(
479      Set< const Potential< GUM_SCALAR >* >& pot_list,
480      Set< const DiscreteVariable* >&        kept_vars) {
481     // find the node ids of the kept variables
482     NodeSet     kept_ids;
483     const auto& bn = this->BN();
484     for (const auto var: kept_vars) {
485       kept_ids.insert(bn.nodeId(*var));
486     }
487 
488     // determine the set of potentials d-connected with the kept variables
489     NodeSet requisite_nodes;
490     BayesBall::requisiteNodes(bn.dag(),
491                               kept_ids,
492                               this->hardEvidenceNodes(),
493                               this->softEvidenceNodes(),
494                               requisite_nodes);
495     for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
496       const Sequence< const DiscreteVariable* >& vars  = (**iter).variablesSequence();
497       bool                                       found = false;
498       for (auto var: vars) {
499         if (requisite_nodes.exists(bn.nodeId(*var))) {
500           found = true;
501           break;
502         }
503       }
504 
505       if (!found) { pot_list.erase(iter); }
506     }
507   }
508 
509 
510   // find the potentials d-connected to a set of variables
511   template < typename GUM_SCALAR >
_findRelevantPotentialsWithdSeparation2_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)512   void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation2_(
513      Set< const Potential< GUM_SCALAR >* >& pot_list,
514      Set< const DiscreteVariable* >&        kept_vars) {
515     // find the node ids of the kept variables
516     NodeSet     kept_ids;
517     const auto& bn = this->BN();
518     for (const auto var: kept_vars) {
519       kept_ids.insert(bn.nodeId(*var));
520     }
521 
522     // determine the set of potentials d-connected with the kept variables
523     BayesBall::relevantPotentials(bn,
524                                   kept_ids,
525                                   this->hardEvidenceNodes(),
526                                   this->softEvidenceNodes(),
527                                   pot_list);
528   }
529 
530 
531   // find the potentials d-connected to a set of variables
532   template < typename GUM_SCALAR >
_findRelevantPotentialsWithdSeparation3_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)533   void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsWithdSeparation3_(
534      Set< const Potential< GUM_SCALAR >* >& pot_list,
535      Set< const DiscreteVariable* >&        kept_vars) {
536     // find the node ids of the kept variables
537     NodeSet     kept_ids;
538     const auto& bn = this->BN();
539     for (const auto var: kept_vars) {
540       kept_ids.insert(bn.nodeId(*var));
541     }
542 
543     // determine the set of potentials d-connected with the kept variables
544     dSeparation dsep;
545     dsep.relevantPotentials(bn,
546                             kept_ids,
547                             this->hardEvidenceNodes(),
548                             this->softEvidenceNodes(),
549                             pot_list);
550   }
551 
552 
553   // find the potentials d-connected to a set of variables
554   template < typename GUM_SCALAR >
_findRelevantPotentialsXX_(Set<const Potential<GUM_SCALAR> * > & pot_list,Set<const DiscreteVariable * > & kept_vars)555   void VariableElimination< GUM_SCALAR >::_findRelevantPotentialsXX_(
556      Set< const Potential< GUM_SCALAR >* >& pot_list,
557      Set< const DiscreteVariable* >&        kept_vars) {
558     switch (_find_relevant_potential_type_) {
559       case RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS:
560         _findRelevantPotentialsWithdSeparation2_(pot_list, kept_vars);
561         break;
562 
563       case RelevantPotentialsFinderType::DSEP_BAYESBALL_NODES:
564         _findRelevantPotentialsWithdSeparation_(pot_list, kept_vars);
565         break;
566 
567       case RelevantPotentialsFinderType::DSEP_KOLLER_FRIEDMAN_2009:
568         _findRelevantPotentialsWithdSeparation3_(pot_list, kept_vars);
569         break;
570 
571       case RelevantPotentialsFinderType::FIND_ALL:
572         _findRelevantPotentialsGetAll_(pot_list, kept_vars);
573         break;
574 
575       default:
576         GUM_ERROR(FatalError, "not implemented yet")
577     }
578   }
579 
580 
  // remove barren variables
  // @param pot_list the current set of potentials; potentials containing
  //        barren variables are removed/replaced in place
  // @param del_vars the variables scheduled for elimination
  // @return the set of newly created (projected) potentials, so the caller
  //         can later free them
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* > VariableElimination< GUM_SCALAR >::_removeBarrenVariables_(
     _PotentialSet_&                 pot_list,
     Set< const DiscreteVariable* >& del_vars) {
    // remove from del_vars the variables that received some evidence:
    // only those that did not received evidence can be barren variables
    Set< const DiscreteVariable* > the_del_vars = del_vars;
    for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe(); ++iter) {
      NodeId id = this->BN().nodeId(**iter);
      if (this->hardEvidenceNodes().exists(id) || this->softEvidenceNodes().exists(id)) {
        the_del_vars.erase(iter);
      }
    }

    // assign to each random variable the set of potentials that contain it
    HashTable< const DiscreteVariable*, _PotentialSet_ > var2pots;
    _PotentialSet_                                       empty_pot_set;
    for (const auto pot: pot_list) {
      const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
      for (const auto var: vars) {
        if (the_del_vars.exists(var)) {
          if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
          var2pots[var].insert(pot);
        }
      }
    }

    // each variable with only one potential is a barren variable
    // assign to each potential with barren nodes its set of barren variables
    HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > > pot2barren_var;
    Set< const DiscreteVariable* >                                              empty_var_set;
    for (auto elt: var2pots) {
      if (elt.second.size() == 1) {   // here we have a barren variable
        const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
        if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
        pot2barren_var[pot].insert(elt.first);   // insert the barren variable
      }
    }

    // for each potential with barren variables, marginalize them.
    // if the potential has only barren variables, simply remove them from the
    // set of potentials, else just project the potential
    MultiDimProjection< GUM_SCALAR, Potential > projector(VENewprojPotential);
    _PotentialSet_                              projected_pots;
    for (auto elt: pot2barren_var) {
      // remove the current potential from pot_list as, anyway, we will change
      // it
      const Potential< GUM_SCALAR >* pot = elt.first;
      pot_list.erase(pot);

      // check whether we need to add a projected new potential or not (i.e.,
      // whether there exist non-barren variables or not)
      if (pot->variablesSequence().size() != elt.second.size()) {
        auto new_pot = projector.project(*pot, elt.second);
        pot_list.insert(new_pot);
        projected_pots.insert(new_pot);
      }
    }

    // the caller owns (and must eventually delete) these projected potentials
    return projected_pots;
  }
643 
644 
645   // performs the collect phase of Lazy Propagation
646   template < typename GUM_SCALAR >
647   std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >
_collectMessage_(NodeId id,NodeId from)648      VariableElimination< GUM_SCALAR >::_collectMessage_(NodeId id, NodeId from) {
649     // collect messages from all the neighbors
650     std::pair< _PotentialSet_, _PotentialSet_ > collect_messages;
651     for (const auto other: _JT_->neighbours(id)) {
652       if (other != from) {
653         std::pair< _PotentialSet_, _PotentialSet_ > message(_collectMessage_(other, id));
654         collect_messages.first += message.first;
655         collect_messages.second += message.second;
656       }
657     }
658 
659     // combine the collect messages with those of id's clique
660     return _produceMessage_(id, from, std::move(collect_messages));
661   }
662 
663 
  // get the CPT + evidence of a node projected w.r.t. hard evidence
  //
  // Returns a pair of potential sets:
  // - first  : all the potentials relevant for the node (its CPT, possibly
  //            projected w.r.t. hard evidence, plus its soft evidence)
  // - second : the subset of 'first' that was freshly allocated here and
  //            that the caller is therefore responsible for deallocating
  template < typename GUM_SCALAR >
  std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >
     VariableElimination< GUM_SCALAR >::_NodePotentials_(NodeId node) {
    std::pair< _PotentialSet_, _PotentialSet_ > res;
    const auto&                                 bn = this->BN();

    // get the CPT's of the node
    // beware: all the potentials that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the potential
    // also beware that the CPT of a hard evidence node may be defined over
    // parents that do not belong to  _graph_ and that are not hard evidence.
    // In this case, those parents have been removed by d-separation and it is
    // easy to show that, in this case all the parents have been removed, so
    // that the CPT does not need to be taken into account
    const auto& evidence      = this->evidence();
    const auto& hard_evidence = this->hardEvidence();
    if (_graph_.exists(node) || this->hardEvidenceNodes().contains(node)) {
      const Potential< GUM_SCALAR >& cpt       = bn.cpt(node);
      const auto&                    variables = cpt.variablesSequence();

      // check if the parents of a hard evidence node do not belong to  _graph_
      // and are not themselves hard evidence, discard the CPT, it is useless
      // for inference
      if (this->hardEvidenceNodes().contains(node)) {
        for (const auto var: variables) {
          NodeId xnode = bn.nodeId(*var);
          // returning here leaves res empty: the CPT is discarded altogether
          if (!this->hardEvidenceNodes().contains(xnode) && !_graph_.existsNode(xnode)) return res;
        }
      }

      // get the list of nodes with hard evidence in cpt
      NodeSet hard_nodes;
      for (const auto var: variables) {
        NodeId xnode = bn.nodeId(*var);
        if (this->hardEvidenceNodes().contains(xnode)) hard_nodes.insert(xnode);
      }

      // if hard_nodes contains hard evidence nodes, perform a projection
      // and insert the result into the appropriate clique, else insert
      // directly cpt into the clique
      if (hard_nodes.empty()) {
        // no hard evidence in the CPT: use it as is (borrowed, not owned)
        res.first.insert(&cpt);
      } else {
        // marginalize out the hard evidence nodes: if the cpt is defined
        // only over nodes that received hard evidence, do not consider it
        // as a potential anymore
        if (hard_nodes.size() != variables.size()) {
          // perform the projection with a combine and project instance
          Set< const DiscreteVariable* > hard_variables;
          _PotentialSet_                 marg_cpt_set{&cpt};
          for (const auto xnode: hard_nodes) {
            // combining with the evidence potential of a hard evidence node
            // selects the observed value before the projection removes it
            marg_cpt_set.insert(evidence[xnode]);
            hard_variables.insert(&(bn.variable(xnode)));
          }
          // perform the combination of those potentials and their projection
          MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
             _combination_op_,
             VENewprojPotential);
          _PotentialSet_ new_cpt_list
             = combine_and_project.combineAndProject(marg_cpt_set, hard_variables);

          // there should be only one potential in new_cpt_list
          if (new_cpt_list.size() != 1) {
            // remove the CPT created to avoid memory leaks
            for (auto pot: new_cpt_list) {
              if (!marg_cpt_set.contains(pot)) delete pot;
            }
            GUM_ERROR(FatalError,
                      "the projection of a potential containing "
                         << "hard evidence is empty!");
          }
          const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
          res.first.insert(projected_cpt);
          // the projected CPT was allocated here: record it for deallocation
          res.second.insert(projected_cpt);
        }
      }

      // if the node received some soft evidence, add it
      if (evidence.exists(node) && !hard_evidence.exists(node)) {
        res.first.insert(this->evidence()[node]);
      }
    }

    return res;
  }
751 
752 
  // creates the message sent by clique from_id to clique to_id
  //
  // The returned pair contains:
  // - first  : the set of potentials constituting the message
  // - second : the subset of 'first' that was allocated while producing the
  //            message and that the receiver must eventually deallocate
  template < typename GUM_SCALAR >
  std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >
     VariableElimination< GUM_SCALAR >::_produceMessage_(
        NodeId from_id,
        NodeId to_id,
        std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >&&
           incoming_messages) {
    // get the messages sent by adjacent nodes to from_id
    std::pair< Set< const Potential< GUM_SCALAR >* >, Set< const Potential< GUM_SCALAR >* > >
       pot_list(std::move(incoming_messages));

    // get the potentials of the clique
    for (const auto node: _clique_potentials_[from_id]) {
      auto new_pots = _NodePotentials_(node);
      pot_list.first += new_pots.first;
      pot_list.second += new_pots.second;
    }

    // if from_id = to_id: this is the endpoint of a collect
    if (!_JT_->existsEdge(from_id, to_id)) {
      // no separator to project onto: forward all the potentials unchanged
      return pot_list;
    } else {
      // get the set of variables that need be removed from the potentials
      const NodeSet&                 from_clique = _JT_->clique(from_id);
      const NodeSet&                 separator   = _JT_->separator(from_id, to_id);
      Set< const DiscreteVariable* > del_vars(from_clique.size());
      Set< const DiscreteVariable* > kept_vars(separator.size());
      const auto&                    bn = this->BN();

      // variables in the clique but not in the separator are summed out;
      // the separator's variables are kept
      for (const auto node: from_clique) {
        if (!separator.contains(node)) {
          del_vars.insert(&(bn.variable(node)));
        } else {
          kept_vars.insert(&(bn.variable(node)));
        }
      }

      // pot_list now contains all the potentials to multiply and marginalize
      // => combine the messages
      _PotentialSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);

      /*
        for the moment, remove this test: due to some optimizations, some
        potentials might have all their cells greater than 1.

      // remove all the potentials that are equal to ones (as probability
      // matrix multiplications are tensorial, such potentials are useless)
      for (auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
           ++iter) {
        const auto pot = *iter;
        if (pot->variablesSequence().size() == 1) {
          bool is_all_ones = true;
          for (Instantiation inst(*pot); !inst.end(); ++inst) {
            if ((*pot)[inst] <  _one_minus_epsilon_) {
              is_all_ones = false;
              break;
            }
          }
          if (is_all_ones) {
            if (!pot_list.first.exists(pot)) delete pot;
            new_pot_list.erase(iter);
            continue;
          }
        }
      }
      */

      // remove the unnecessary temporary messages
      // (temporaries that did not survive the marginalization are dead here;
      // safe iterators allow erasing while iterating)
      for (auto iter = pot_list.second.beginSafe(); iter != pot_list.second.endSafe(); ++iter) {
        if (!new_pot_list.contains(*iter)) {
          delete *iter;
          pot_list.second.erase(iter);
        }
      }

      // keep track of all the newly created potentials
      for (const auto pot: new_pot_list) {
        if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
      }

      // return the new set of potentials
      return std::pair< _PotentialSet_, _PotentialSet_ >(std::move(new_pot_list),
                                                         std::move(pot_list.second));
    }
  }
839 
840 
  // remove variables del_vars from the list of potentials pot_list
  //
  // @param pot_list the potentials to combine and project (taken by value:
  //        the relevance analysis may filter it locally)
  // @param del_vars the variables to sum out
  // @param kept_vars the variables remaining after marginalization (only
  //        used by the relevant-potentials analysis)
  // @return the resulting potentials; those not already present in pot_list
  //         were allocated here and belong to the caller
  template < typename GUM_SCALAR >
  Set< const Potential< GUM_SCALAR >* > VariableElimination< GUM_SCALAR >::_marginalizeOut_(
     Set< const Potential< GUM_SCALAR >* > pot_list,
     Set< const DiscreteVariable* >&       del_vars,
     Set< const DiscreteVariable* >&       kept_vars) {
    // use d-separation analysis to check which potentials shall be combined
    _findRelevantPotentialsXX_(pot_list, kept_vars);

    // remove the potentials corresponding to barren variables if we want
    // to exploit barren nodes
    _PotentialSet_ barren_projected_potentials;
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_potentials = _removeBarrenVariables_(pot_list, del_vars);
    }

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(_combination_op_,
                                                                                  _projection_op_);
    _PotentialSet_ new_pot_list = combine_and_project.combineAndProject(pot_list, del_vars);

    // remove all the potentials that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // potentials were just temporary potentials
    for (auto iter = barren_projected_potentials.beginSafe();
         iter != barren_projected_potentials.endSafe();
         ++iter) {
      if (!new_pot_list.exists(*iter)) delete *iter;
    }

    // remove all the potentials that have no dimension
    for (auto iter_pot = new_pot_list.beginSafe(); iter_pot != new_pot_list.endSafe(); ++iter_pot) {
      if ((*iter_pot)->variablesSequence().size() == 0) {
        // as we have already marginalized out variables that received evidence,
        // it may be the case that, after combining and projecting, some
        // potentials might be empty. In this case, we shall keep their
        // constant and remove them from memory
        // # TODO: keep the constants!
        delete *iter_pot;
        new_pot_list.erase(iter_pot);
      }
    }

    return new_pot_list;
  }
887 
888 
  // performs a whole inference
  // Nothing to do here: Variable Elimination is a per-query algorithm; the
  // actual work (junction-tree creation and message collection) is triggered
  // by unnormalizedJointPosterior_ when a posterior is requested.
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::makeInference_() {}
892 
893 
  /// returns a fresh potential equal to P(1st arg,evidence)
  /// The caller takes ownership of the returned potential.
  /// @throws IncompatibleEvidence if the computed joint is a null vector
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) {
    const auto& bn = this->BN();

    // hard evidence do not belong to the join tree
    // # TODO: check for sets of inconsistent hard evidence
    if (this->hardEvidenceNodes().contains(id)) {
      // the posterior of a hard evidence node is its evidence potential:
      // return a copy the caller can own
      return new Potential< GUM_SCALAR >(*(this->evidence()[id]));
    }

    // if we still need to perform some inference task, do it
    _createNewJT_(NodeSet{id});
    NodeId clique_of_id = _node_to_clique_[id];
    auto   pot_list     = _collectMessage_(clique_of_id, clique_of_id);

    // get the set of variables that need be removed from the potentials
    const NodeSet&                 nodes = _JT_->clique(clique_of_id);
    Set< const DiscreteVariable* > kept_vars{&(bn.variable(id))};
    Set< const DiscreteVariable* > del_vars(nodes.size());
    for (const auto node: nodes) {
      if (node != id) del_vars.insert(&(bn.variable(node)));
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_           new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint        = nullptr;

    if (new_pot_list.size() == 1) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if joint already existed, create a copy, so that we can put it into
      // the  _target_posterior_ property
      if (pot_list.first.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // removed just after the else block
        new_pot_list.clear();
      }
    } else {
      // several potentials remain: combine them into a single fresh joint
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
      joint = fast_combination.combine(new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (auto pot: new_pot_list)
      if (!pot_list.first.exists(pot)) delete pot;

    // remove all the temporary potentials created in pot_list
    for (auto pot: pot_list.second)
      delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence are not compatible (their joint
    // probability is equal to 0)
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }

    return joint;
  }
968 
969 
970   /// returns the posterior of a given variable
971   template < typename GUM_SCALAR >
posterior_(NodeId id)972   const Potential< GUM_SCALAR >& VariableElimination< GUM_SCALAR >::posterior_(NodeId id) {
973     // compute the joint posterior and normalize
974     auto joint = unnormalizedJointPosterior_(id);
975     if (joint->sum() != 1)   // hard test for ReadOnly CPT (as aggregator)
976       joint->normalize();
977 
978     if (_target_posterior_ != nullptr) delete _target_posterior_;
979     _target_posterior_ = joint;
980 
981     return *joint;
982   }
983 
984 
  // returns the marginal a posteriori proba of a given node
  // The caller takes ownership of the returned potential.
  // @throws IncompatibleEvidence if the computed joint is a null vector
  template < typename GUM_SCALAR >
  Potential< GUM_SCALAR >*
     VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(const NodeSet& set) {
    // hard evidence do not belong to the join tree, so extract the nodes
    // from targets that are not hard evidence
    NodeSet targets = set, hard_ev_nodes;
    for (const auto node: this->hardEvidenceNodes()) {
      if (targets.contains(node)) {
        targets.erase(node);
        hard_ev_nodes.insert(node);
      }
    }

    // if all the nodes have received hard evidence, then compute the
    // joint posterior directly by multiplying the hard evidence potentials
    const auto& evidence = this->evidence();
    if (targets.empty()) {
      _PotentialSet_ pot_list;
      for (const auto node: set) {
        pot_list.insert(evidence[node]);
      }
      if (pot_list.size() == 1) {
        // a single evidence potential: return a copy the caller can own
        return new Potential< GUM_SCALAR >(**(pot_list.begin()));
      } else {
        MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
        return fast_combination.combine(pot_list);
      }
    }

    // if we still need to perform some inference task, do it
    _createNewJT_(set);
    auto pot_list = _collectMessage_(_targets2clique_, _targets2clique_);

    // get the set of variables that need be removed from the potentials
    const NodeSet&                 nodes = _JT_->clique(_targets2clique_);
    Set< const DiscreteVariable* > del_vars(nodes.size());
    Set< const DiscreteVariable* > kept_vars(targets.size());
    const auto&                    bn = this->BN();
    for (const auto node: nodes) {
      if (!targets.contains(node)) {
        del_vars.insert(&(bn.variable(node)));
      } else {
        kept_vars.insert(&(bn.variable(node)));
      }
    }

    // pot_list now contains all the potentials to multiply and marginalize
    // => combine the messages
    _PotentialSet_           new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
    Potential< GUM_SCALAR >* joint        = nullptr;

    if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
      joint = const_cast< Potential< GUM_SCALAR >* >(*(new_pot_list.begin()));
      // if pot already existed, create a copy, so that we can put it into
      // the  _target_posteriors_ property
      if (pot_list.first.exists(joint)) {
        joint = new Potential< GUM_SCALAR >(*joint);
      } else {
        // remove the joint from new_pot_list so that it will not be
        // removed just after the next else block
        new_pot_list.clear();
      }
    } else {
      // combine all the potentials in new_pot_list with all the hard evidence
      // of the nodes in set
      _PotentialSet_ new_new_pot_list = new_pot_list;
      for (const auto node: hard_ev_nodes) {
        new_new_pot_list.insert(evidence[node]);
      }
      MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(_combination_op_);
      joint = fast_combination.combine(new_new_pot_list);
    }

    // remove the potentials that were created in new_pot_list
    for (auto pot: new_pot_list)
      if (!pot_list.first.exists(pot)) delete pot;

    // remove all the temporary potentials created in pot_list
    for (auto pot: pot_list.second)
      delete pot;

    // check that the joint posterior is different from a 0 vector: this would
    // indicate that some hard evidence are not compatible
    bool nonzero_found = false;
    for (Instantiation inst(*joint); !inst.end(); ++inst) {
      if ((*joint)[inst]) {
        nonzero_found = true;
        break;
      }
    }
    if (!nonzero_found) {
      // remove joint from memory to avoid memory leaks
      delete joint;
      GUM_ERROR(IncompatibleEvidence,
                "some evidence entered into the Bayes "
                "net are incompatible (their joint proba = 0)");
    }

    return joint;
  }
1086 
1087 
1088   /// returns the posterior of a given set of variables
1089   template < typename GUM_SCALAR >
1090   const Potential< GUM_SCALAR >&
jointPosterior_(const NodeSet & set)1091      VariableElimination< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
1092     // compute the joint posterior and normalize
1093     auto joint = unnormalizedJointPosterior_(set);
1094     joint->normalize();
1095 
1096     if (_target_posterior_ != nullptr) delete _target_posterior_;
1097     _target_posterior_ = joint;
1098 
1099     return *joint;
1100   }
1101 
1102 
  /// returns the posterior of a given set of variables
  // declared_target is ignored: Variable Elimination rebuilds its junction
  // tree for each query (see _createNewJT_), so the posterior of
  // wanted_target can always be computed directly, without marginalizing
  // a previously declared joint target.
  template < typename GUM_SCALAR >
  const Potential< GUM_SCALAR >&
     VariableElimination< GUM_SCALAR >::jointPosterior_(const NodeSet& wanted_target,
                                                        const NodeSet& declared_target) {
    return jointPosterior_(wanted_target);
  }
1110 
1111 
1112 } /* namespace gum */
1113 
1114 #endif   // DOXYGEN_SHOULD_SKIP_THIS
1115