1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of the BayesNetFactory class.
25  *
26  * @author Lionel TORTI and Pierre-Henri WUILLEMIN(_at_LIP6)
27 
28  */
29 
#include <agrum/BN/BayesNetFactory.h>

#include <cstdio>
#include <memory>
#include <string>
31 
32 namespace gum {
33 
34   // Default constructor.
35   // @param bn A pointer over the BayesNet filled by this factory.
36   // @throw DuplicateElement Raised if two variables in bn share the same
37   //                         name.
38   template < typename GUM_SCALAR >
BayesNetFactory(BayesNet<GUM_SCALAR> * bn)39   INLINE BayesNetFactory< GUM_SCALAR >::BayesNetFactory(BayesNet< GUM_SCALAR >* bn) :
40       _parents_(0), _impl_(0), _bn_(bn) {
41     GUM_CONSTRUCTOR(BayesNetFactory);
42     _states_.push_back(factory_state::NONE);
43 
44     for (auto node: bn->nodes()) {
45       if (_varNameMap_.exists(bn->variable(node).name()))
46         GUM_ERROR(DuplicateElement, "Name already used: " << bn->variable(node).name())
47 
48       _varNameMap_.insert(bn->variable(node).name(), node);
49     }
50 
51     resetVerbose();
52   }
53 
54   // Copy constructor.
55   // The copy will have an exact copy of the constructed BayesNet in source.
56   template < typename GUM_SCALAR >
57   INLINE
BayesNetFactory(const BayesNetFactory<GUM_SCALAR> & source)58      BayesNetFactory< GUM_SCALAR >::BayesNetFactory(const BayesNetFactory< GUM_SCALAR >& source) :
59       _parents_(nullptr),
60       _impl_(nullptr), _bn_(nullptr) {
61     GUM_CONS_CPY(BayesNetFactory);
62 
63     if (source.state() != factory_state::NONE) {
64       GUM_ERROR(OperationNotAllowed, "Illegal state to proceed make a copy.")
65     } else {
66       _states_ = source._states_;
67       _bn_     = new BayesNet< GUM_SCALAR >(*(source._bn_));
68     }
69   }
70 
71   // Destructor
72   template < typename GUM_SCALAR >
~BayesNetFactory()73   INLINE BayesNetFactory< GUM_SCALAR >::~BayesNetFactory() {
74     GUM_DESTRUCTOR(BayesNetFactory);
75 
76     if (_parents_ != nullptr) delete _parents_;
77 
78     if (_impl_ != nullptr) {
79       //@todo better than throwing an exception from inside a destructor but
80       // still ...
81       std::cerr << "[BN factory] Implementation defined for a variable but not used. "
82                    "You should call endVariableDeclaration() before "
83                    "deleting me."
84                 << std::endl;
85       exit(1);
86     }
87   }
88 
89   // Returns the BayesNet created by this factory.
90   template < typename GUM_SCALAR >
bayesNet()91   INLINE BayesNet< GUM_SCALAR >* BayesNetFactory< GUM_SCALAR >::bayesNet() {
92     return _bn_;
93   }
94 
95   template < typename GUM_SCALAR >
varInBN(NodeId id)96   INLINE const DiscreteVariable& BayesNetFactory< GUM_SCALAR >::varInBN(NodeId id) {
97     return _bn_->variable(id);
98   }
99 
100   // Returns the current state of the factory.
101   template < typename GUM_SCALAR >
state()102   INLINE IBayesNetFactory::factory_state BayesNetFactory< GUM_SCALAR >::state() const {
103     // This is ok because there is always at least the state NONE in the stack.
104     return _states_.back();
105   }
106 
107   // Returns the NodeId of a variable given it's name.
108   // @throw NotFound Raised if no variable matches the name.
109   template < typename GUM_SCALAR >
variableId(const std::string & name)110   INLINE NodeId BayesNetFactory< GUM_SCALAR >::variableId(const std::string& name) const {
111     try {
112       return _varNameMap_[name];
113     } catch (NotFound&) { GUM_ERROR(NotFound, name) }
114   }
115 
116   // Returns a constant reference on a variable given it's name.
117   // @throw NotFound Raised if no variable matches the name.
118   template < typename GUM_SCALAR >
119   INLINE const DiscreteVariable&
variable(const std::string & name)120                BayesNetFactory< GUM_SCALAR >::variable(const std::string& name) const {
121     try {
122       return _bn_->variable(variableId(name));
123     } catch (NotFound&) { GUM_ERROR(NotFound, name) }
124   }
125 
126   // Returns the domainSize of the cpt for the node n.
127   // @throw NotFound raised if no such NodeId exists.
128   // @throw OperationNotAllowed if there is no Bayesian networks.
129   template < typename GUM_SCALAR >
cptDomainSize(const NodeId n)130   INLINE Size BayesNetFactory< GUM_SCALAR >::cptDomainSize(const NodeId n) const {
131     return _bn_->cpt(n).domainSize();
132   }
133 
134   // Tells the factory that we're in a network declaration.
135   template < typename GUM_SCALAR >
startNetworkDeclaration()136   INLINE void BayesNetFactory< GUM_SCALAR >::startNetworkDeclaration() {
137     if (state() != factory_state::NONE) {
138       _illegalStateError_("startNetworkDeclaration");
139     } else {
140       _states_.push_back(factory_state::NETWORK);
141     }
142   }
143 
144   // Tells the factory to add a property to the current network.
145   template < typename GUM_SCALAR >
addNetworkProperty(const std::string & propName,const std::string & propValue)146   INLINE void BayesNetFactory< GUM_SCALAR >::addNetworkProperty(const std::string& propName,
147                                                                 const std::string& propValue) {
148     _bn_->setProperty(propName, propValue);
149   }
150 
151   // Tells the factory that we're out of a network declaration.
152   template < typename GUM_SCALAR >
endNetworkDeclaration()153   INLINE void BayesNetFactory< GUM_SCALAR >::endNetworkDeclaration() {
154     if (state() != factory_state::NETWORK) {
155       _illegalStateError_("endNetworkDeclaration");
156     } else {
157       _states_.pop_back();
158     }
159   }
160 
161   // Tells the factory that we're in a variable declaration.
162   // A variable is considered as a LabelizedVariable while its type is not defined.
163   template < typename GUM_SCALAR >
startVariableDeclaration()164   INLINE void BayesNetFactory< GUM_SCALAR >::startVariableDeclaration() {
165     if (state() != factory_state::NONE) {
166       _illegalStateError_("startVariableDeclaration");
167     } else {
168       _states_.push_back(factory_state::VARIABLE);
169       _stringBag_.push_back("name");
170       _stringBag_.push_back("desc");
171       _stringBag_.push_back("L");
172     }
173   }
174 
175   // Tells the factory the current variable's name.
176   template < typename GUM_SCALAR >
variableName(const std::string & name)177   INLINE void BayesNetFactory< GUM_SCALAR >::variableName(const std::string& name) {
178     if (state() != factory_state::VARIABLE) {
179       _illegalStateError_("variableName");
180     } else {
181       if (_varNameMap_.exists(name)) { GUM_ERROR(DuplicateElement, "Name already used: " << name) }
182 
183       _foo_flag_     = true;
184       _stringBag_[0] = name;
185     }
186   }
187 
188   // Tells the factory the current variable's description.
189   template < typename GUM_SCALAR >
variableDescription(const std::string & desc)190   INLINE void BayesNetFactory< GUM_SCALAR >::variableDescription(const std::string& desc) {
191     if (state() != factory_state::VARIABLE) {
192       _illegalStateError_("variableDescription");
193     } else {
194       _bar_flag_     = true;
195       _stringBag_[1] = desc;
196     }
197   }
198 
199   // Tells the factory the current variable's type.
200   // L : Labelized
201   // R : Range
202   // C : Continuous
203   // D : Discretized
204   template < typename GUM_SCALAR >
variableType(const gum::VarType & type)205   INLINE void BayesNetFactory< GUM_SCALAR >::variableType(const gum::VarType& type) {
206     if (state() != factory_state::VARIABLE) {
207       _illegalStateError_("variableType");
208     } else {
209       switch (type) {
210         case VarType::Discretized:
211           _stringBag_[2] = "D";
212           break;
213         case VarType::Range:
214           _stringBag_[2] = "R";
215           break;
216         case VarType::Integer:
217           _stringBag_[2] = "I";
218           break;
219         case VarType::Continuous:
220           GUM_ERROR(OperationNotAllowed,
221                     "Continuous variable (" + _stringBag_[0]
222                        + ") are not supported in Bayesian networks.")
223         case VarType::Labelized:
224           _stringBag_[2] = "L";
225           break;
226       }
227     }
228   }
229 
230   // Adds a modality to the current variable.
231   // @throw DuplicateElement If the current variable already has a modality
232   //                         with the same name.
233   template < typename GUM_SCALAR >
addModality(const std::string & name)234   INLINE void BayesNetFactory< GUM_SCALAR >::addModality(const std::string& name) {
235     if (state() != factory_state::VARIABLE) {
236       _illegalStateError_("addModality");
237     } else {
238       _checkModalityInBag_(name);
239       _stringBag_.push_back(name);
240     }
241   }
242 
243   // Adds a modality to the current variable.
244   // @throw DuplicateElement If the current variable already has a modality
245   //                         with the same name.
246   template < typename GUM_SCALAR >
addMin(const long & min)247   INLINE void BayesNetFactory< GUM_SCALAR >::addMin(const long& min) {
248     if (state() != factory_state::VARIABLE) {
249       _illegalStateError_("addMin");
250     } else {
251       _stringBag_.push_back(std::to_string(min));
252     }
253   }
254 
255   // Adds a modality to the current variable.
256   // @throw DuplicateElement If the current variable already has a modality
257   //                         with the same name.
258   template < typename GUM_SCALAR >
addMax(const long & max)259   INLINE void BayesNetFactory< GUM_SCALAR >::addMax(const long& max) {
260     if (state() != factory_state::VARIABLE) {
261       _illegalStateError_("addMin");
262     } else {
263       _stringBag_.push_back(std::to_string(max));
264     }
265   }
266 
267   // Adds a modality to the current variable.
268   // @throw DuplicateElement If the current variable already has a modality
269   //                         with the same name.
270   template < typename GUM_SCALAR >
addTick(const GUM_SCALAR & tick)271   INLINE void BayesNetFactory< GUM_SCALAR >::addTick(const GUM_SCALAR& tick) {
272     if (state() != factory_state::VARIABLE) {
273       _illegalStateError_("addTick");
274     } else {
275       _stringBag_.push_back(std::to_string(tick));
276     }
277   }
278 
279   // @brief Defines the implementation to use for Potential.
280   // @warning The implementation must be empty.
281   // @warning The pointer is always delegated to Potential! No copy of it
282   //          is made.
283   // @todo When copy of a MultiDimImplementation is available use a copy
284   //       behaviour for this method.
285   // @throw NotFound Raised if no variable matches var.
286   // @throw OperationNotAllowed Raised if impl is not empty.
287   // @throw OperationNotAllowed If an implementation is already defined for the
288   //                            current variable.
289   template < typename GUM_SCALAR >
290   INLINE void
setVariableCPTImplementation(MultiDimAdressable * adressable)291      BayesNetFactory< GUM_SCALAR >::setVariableCPTImplementation(MultiDimAdressable* adressable) {
292     MultiDimImplementation< GUM_SCALAR >* impl
293        = dynamic_cast< MultiDimImplementation< GUM_SCALAR >* >(adressable);
294 
295     if (state() != factory_state::VARIABLE) {
296       _illegalStateError_("setVariableCPTImplementation");
297     } else {
298       if (impl == 0) {
299         GUM_ERROR(OperationNotAllowed,
300                   "An implementation for this variable is already "
301                   "defined.")
302       } else if (impl->nbrDim() > 0) {
303         GUM_ERROR(OperationNotAllowed, "This implementation is not empty.")
304       }
305 
306       _impl_ = impl;
307     }
308   }
309 
310   // Tells the factory that we're out of a variable declaration.
311   template < typename GUM_SCALAR >
endVariableDeclaration()312   INLINE NodeId BayesNetFactory< GUM_SCALAR >::endVariableDeclaration() {
313     if (state() != factory_state::VARIABLE) {
314       _illegalStateError_("endVariableDeclaration");
315     } else if (_foo_flag_ && (_stringBag_.size() > 4)) {
316       DiscreteVariable* var = nullptr;
317 
318       // if the current variable is a LabelizedVariable
319       if (_stringBag_[2] == "L") {
320         LabelizedVariable* l
321            = new LabelizedVariable(_stringBag_[0], (_bar_flag_) ? _stringBag_[1] : "", 0);
322 
323         for (size_t i = 3; i < _stringBag_.size(); ++i) {
324           l->addLabel(_stringBag_[i]);
325         }
326 
327         var = l;
328         // if the current variable is a RangeVariable
329       } else if (_stringBag_[2] == "I") {
330         // try to create the domain of the variable
331         std::vector< int > domain;
332         for (size_t i = 3; i < _stringBag_.size(); ++i) {
333           domain.push_back(std::stoi(_stringBag_[i]));
334         }
335 
336         IntegerVariable* v
337            = new IntegerVariable(_stringBag_[0], (_bar_flag_) ? _stringBag_[1] : "", domain);
338         var = v;
339       } else if (_stringBag_[2] == "R") {
340         RangeVariable* r = new RangeVariable(_stringBag_[0],
341                                              (_bar_flag_) ? _stringBag_[1] : "",
342                                              std::stol(_stringBag_[3]),
343                                              std::stol(_stringBag_[4]));
344 
345         var = r;
346         // if the current variable is a DiscretizedVariable
347       } else if (_stringBag_[2] == "D") {
348         DiscretizedVariable< GUM_SCALAR >* d
349            = new DiscretizedVariable< GUM_SCALAR >(_stringBag_[0],
350                                                    (_bar_flag_) ? _stringBag_[1] : "");
351 
352         for (size_t i = 3; i < _stringBag_.size(); ++i) {
353           d->addTick(std::stof(_stringBag_[i]));
354         }
355 
356         var = d;
357       }
358 
359       if (_impl_ != 0) {
360         _varNameMap_.insert(var->name(), _bn_->add(*var, _impl_));
361         _impl_ = 0;
362       } else {
363         _varNameMap_.insert(var->name(), _bn_->add(*var));
364       }
365 
366       NodeId retVal = _varNameMap_[var->name()];
367 
368       delete var;
369 
370       _resetParts_();
371       _states_.pop_back();
372 
373       return retVal;
374     } else {
375       std::stringstream msg;
376       msg << "Not enough modalities (";
377 
378       if (_stringBag_.size() > 3) {
379         msg << _stringBag_.size() - 3;
380       } else {
381         msg << 0;
382       }
383 
384       msg << ") declared for variable ";
385 
386       if (_foo_flag_) {
387         msg << _stringBag_[0];
388       } else {
389         msg << "unknown";
390       }
391 
392       _resetParts_();
393 
394       _states_.pop_back();
395       GUM_ERROR(OperationNotAllowed, msg.str())
396     }
397 
398     // For noisy compilers
399     return 0;
400   }
401 
402   // Tells the factory that we're declaring parents for some variable.
403   // @var The concerned variable's name.
404   template < typename GUM_SCALAR >
startParentsDeclaration(const std::string & var)405   INLINE void BayesNetFactory< GUM_SCALAR >::startParentsDeclaration(const std::string& var) {
406     if (state() != factory_state::NONE) {
407       _illegalStateError_("startParentsDeclaration");
408     } else {
409       _checkVariableName_(var);
410       std::vector< std::string >::iterator iter = _stringBag_.begin();
411       _stringBag_.insert(iter, var);
412       _states_.push_back(factory_state::PARENTS);
413     }
414   }
415 
416   // Tells the factory for which variable we're declaring parents.
417   // @var The parent's name.
418   // @throw NotFound Raised if var does not exists.
419   template < typename GUM_SCALAR >
addParent(const std::string & var)420   INLINE void BayesNetFactory< GUM_SCALAR >::addParent(const std::string& var) {
421     if (state() != factory_state::PARENTS) {
422       _illegalStateError_("addParent");
423     } else {
424       _checkVariableName_(var);
425       _stringBag_.push_back(var);
426     }
427   }
428 
429   // Tells the factory that we've finished declaring parents for some
430   // variable. When parents exist, endParentsDeclaration creates some arcs.
431   // These arcs are created in the inverse order of the order of the parent
432   // specification.
433   template < typename GUM_SCALAR >
endParentsDeclaration()434   INLINE void BayesNetFactory< GUM_SCALAR >::endParentsDeclaration() {
435     if (state() != factory_state::PARENTS) {
436       _illegalStateError_("endParentsDeclaration");
437     } else {
438       NodeId id = _varNameMap_[_stringBag_[0]];
439 
440       // PLEASE NOTE THAT THE ORDER IS INVERSE
441 
442       for (size_t i = _stringBag_.size() - 1; i > 0; --i) {
443         _bn_->addArc(_varNameMap_[_stringBag_[i]], id);
444       }
445 
446       _resetParts_();
447 
448       _states_.pop_back();
449     }
450   }
451 
452   // Tells the factory that we're declaring a conditional probability table
453   // for some variable.
454   // @param var The concerned variable's name.
455   template < typename GUM_SCALAR >
456   INLINE void
startRawProbabilityDeclaration(const std::string & var)457      BayesNetFactory< GUM_SCALAR >::startRawProbabilityDeclaration(const std::string& var) {
458     if (state() != factory_state::NONE) {
459       _illegalStateError_("startRawProbabilityDeclaration");
460     } else {
461       _checkVariableName_(var);
462       _stringBag_.push_back(var);
463       _states_.push_back(factory_state::RAW_CPT);
464     }
465   }
466 
467   // @brief Fills the variable's table with the values in rawTable.
468   // Parse the parents in the same order in which they were added to the
469   // variable.
470   // Given a sequence [var, p_1, p_2, ...,p_n-1, p_n] of parents, modalities are
471   // parsed
472   // in the given order (if all p_i are binary):
473   // [0, 0, ..., 0, 0], [0, 0, ..., 0, 1],
474   // [0, 0, ..., 1, 0], [0, 0, ..., 1, 1],
475   // ...,
476   // [1, 1, ..., 1, 0], [1, 1, ..., 1, 1].
477   // @param rawTable The raw table.
478   template < typename GUM_SCALAR >
479   INLINE void
rawConditionalTable(const std::vector<std::string> & variables,const std::vector<float> & rawTable)480      BayesNetFactory< GUM_SCALAR >::rawConditionalTable(const std::vector< std::string >& variables,
481                                                         const std::vector< float >& rawTable) {
482     if (state() != factory_state::RAW_CPT) {
483       _illegalStateError_("rawConditionalTable");
484     } else {
485       _fillProbaWithValuesTable_(variables, rawTable);
486     }
487   }
488 
489   template < typename GUM_SCALAR >
_fillProbaWithValuesTable_(const std::vector<std::string> & variables,const std::vector<float> & rawTable)490   INLINE void BayesNetFactory< GUM_SCALAR >::_fillProbaWithValuesTable_(
491      const std::vector< std::string >& variables,
492      const std::vector< float >&       rawTable) {
493     const Potential< GUM_SCALAR >& table = _bn_->cpt(_varNameMap_[_stringBag_[0]]);
494     Instantiation                  cptInst(table);
495 
496     List< const DiscreteVariable* > varList;
497 
498     for (size_t i = 0; i < variables.size(); ++i) {
499       varList.pushBack(&(_bn_->variable(_varNameMap_[variables[i]])));
500     }
501 
502     // varList.pushFront(&( _bn_->variable( _varNameMap_[ _stringBag_[0]])));
503 
504     Idx nbrVar = varList.size();
505 
506     std::vector< Idx > modCounter;
507 
508     // initializing the array
509     for (NodeId i = 0; i < nbrVar; i++) {
510       modCounter.push_back(Idx(0));
511     }
512 
513     Idx j = 0;
514 
515     do {
516       for (NodeId i = 0; i < nbrVar; i++) {
517         cptInst.chgVal(*(varList[i]), modCounter[i]);
518       }
519 
520       if (j < rawTable.size()) {
521         table.set(cptInst, (GUM_SCALAR)rawTable[j]);
522       } else {
523         table.set(cptInst, (GUM_SCALAR)0);
524       }
525 
526       j++;
527     } while (_increment_(modCounter, varList));
528   }
529 
530   template < typename GUM_SCALAR >
531   INLINE void
rawConditionalTable(const std::vector<float> & rawTable)532      BayesNetFactory< GUM_SCALAR >::rawConditionalTable(const std::vector< float >& rawTable) {
533     if (state() != factory_state::RAW_CPT) {
534       _illegalStateError_("rawConditionalTable");
535     } else {
536       _fillProbaWithValuesTable_(rawTable);
537     }
538   }
539 
540   template < typename GUM_SCALAR >
_fillProbaWithValuesTable_(const std::vector<float> & rawTable)541   INLINE void BayesNetFactory< GUM_SCALAR >::_fillProbaWithValuesTable_(
542      const std::vector< float >& rawTable) {
543     const Potential< GUM_SCALAR >& table = _bn_->cpt(_varNameMap_[_stringBag_[0]]);
544 
545     Instantiation cptInst(table);
546 
547     // the main loop is on the first variables. The others are in the right
548     // order.
549     const DiscreteVariable& first = table.variable(0);
550     Idx                     j     = 0;
551 
552     for (cptInst.setFirstVar(first); !cptInst.end(); cptInst.incVar(first)) {
553       for (cptInst.setFirstNotVar(first); !cptInst.end(); cptInst.incNotVar(first))
554         table.set(cptInst, (j < rawTable.size()) ? (GUM_SCALAR)rawTable[j++] : (GUM_SCALAR)0);
555 
556       cptInst.unsetEnd();
557     }
558   }
559 
560   template < typename GUM_SCALAR >
_increment_(std::vector<gum::Idx> & modCounter,List<const DiscreteVariable * > & varList)561   INLINE bool BayesNetFactory< GUM_SCALAR >::_increment_(std::vector< gum::Idx >& modCounter,
562                                                          List< const DiscreteVariable* >& varList) {
563     bool last = true;
564 
565     for (NodeId j = 0; j < modCounter.size(); j++) {
566       last = (modCounter[j] == (varList[j]->domainSize() - 1)) && last;
567 
568       if (!last) break;
569     }
570 
571     if (last) { return false; }
572 
573     bool add = false;
574 
575     NodeId i = NodeId(varList.size() - 1);
576 
577     do {
578       if (modCounter[i] == (varList[i]->domainSize() - 1)) {
579         modCounter[i] = 0;
580         add           = true;
581       } else {
582         modCounter[i] += 1;
583         add = false;
584       }
585 
586       i--;
587     } while (add);
588 
589     return true;
590   }
591 
592   // Tells the factory that we finished declaring a conditional probability
593   // table.
594   template < typename GUM_SCALAR >
endRawProbabilityDeclaration()595   INLINE void BayesNetFactory< GUM_SCALAR >::endRawProbabilityDeclaration() {
596     if (state() != factory_state::RAW_CPT) {
597       _illegalStateError_("endRawProbabilityDeclaration");
598     } else {
599       _resetParts_();
600       _states_.pop_back();
601     }
602   }
603 
604   // Tells the factory that we're starting a factorized declaration.
605   template < typename GUM_SCALAR >
606   INLINE void
startFactorizedProbabilityDeclaration(const std::string & var)607      BayesNetFactory< GUM_SCALAR >::startFactorizedProbabilityDeclaration(const std::string& var) {
608     if (state() != factory_state::NONE) {
609       _illegalStateError_("startFactorizedProbabilityDeclaration");
610     } else {
611       _checkVariableName_(var);
612       std::vector< std::string >::iterator iter = _stringBag_.begin();
613       _stringBag_.insert(iter, var);
614       _states_.push_back(factory_state::FACT_CPT);
615     }
616   }
617 
618   // Tells the factory that we start an entry of a factorized conditional
619   // probability table.
620   template < typename GUM_SCALAR >
startFactorizedEntry()621   INLINE void BayesNetFactory< GUM_SCALAR >::startFactorizedEntry() {
622     if (state() != factory_state::FACT_CPT) {
623       _illegalStateError_("startFactorizedEntry");
624     } else {
625       _parents_ = new Instantiation();
626       _states_.push_back(factory_state::FACT_ENTRY);
627     }
628   }
629 
630   // Tells the factory that we finished declaring a conditional probability
631   // table.
632   template < typename GUM_SCALAR >
endFactorizedEntry()633   INLINE void BayesNetFactory< GUM_SCALAR >::endFactorizedEntry() {
634     if (state() != factory_state::FACT_ENTRY) {
635       _illegalStateError_("endFactorizedEntry");
636     } else {
637       delete _parents_;
638       _parents_ = 0;
639       _states_.pop_back();
640     }
641   }
642 
643   // Tells the factory on which modality we want to instantiate one of
644   // variable's parent.
645   template < typename GUM_SCALAR >
setParentModality(const std::string & parent,const std::string & modality)646   INLINE void BayesNetFactory< GUM_SCALAR >::setParentModality(const std::string& parent,
647                                                                const std::string& modality) {
648     if (state() != factory_state::FACT_ENTRY) {
649       _illegalStateError_("string");
650     } else {
651       _checkVariableName_(parent);
652       Idx id = _checkVariableModality_(parent, modality);
653       (*_parents_) << _bn_->variable(_varNameMap_[parent]);
654       _parents_->chgVal(_bn_->variable(_varNameMap_[parent]), id);
655     }
656   }
657 
658   // @brief Gives the values of the variable with respect to precedent
659   //        parents modality.
660   // If some parents have no modality set, then we apply values for all
661   // instantiations of that parent.
662   //
663   // This means you can declare a default value for the table by doing
664   // @code
665   // BayesNetFactory factory;
666   // // Do stuff
667   // factory.startVariableDeclaration();
668   // factory.variableName("foo");
669   // factory.endVariableDeclaration();
670   // factory.startParentsDeclaration("foo");
671   // // add parents
672   // factory.endParentsDeclaration();
673   // factory.startFactorizedProbabilityDeclaration("foo");
674   // std::vector<float> seq;
675   // seq.insert(0.4); // if foo true
676   // seq.insert(O.6); // if foo false
677   // factory.setVariableValues(seq); // fills the table with a default value
678   // // finish your stuff
679   // factory.endFactorizedProbabilityDeclaration();
680   // @code
681   // as for raw Probability, if value's size is different than the number of
682   // modalities of the current variable, we don't use the supplementary values and
683   // we fill by 0 the missing values.
684   template < typename GUM_SCALAR >
685   INLINE void
setVariableValuesUnchecked(const std::vector<float> & values)686      BayesNetFactory< GUM_SCALAR >::setVariableValuesUnchecked(const std::vector< float >& values) {
687     if (state() != factory_state::FACT_ENTRY) {
688       _illegalStateError_("setVariableValues");
689     } else {
690       const DiscreteVariable& var   = _bn_->variable(_varNameMap_[_stringBag_[0]]);
691       NodeId                  varId = _varNameMap_[_stringBag_[0]];
692 
693       if (_parents_->domainSize() > 0) {
694         Instantiation inst(_bn_->cpt(_varNameMap_[var.name()]));
695         inst.setVals(*_parents_);
696         // Creating an instantiation containing all the variables not ins
697         //  _parents_.
698         Instantiation inst_default;
699         inst_default << var;
700 
701         for (auto node: _bn_->parents(varId)) {
702           if (!_parents_->contains(_bn_->variable(node))) { inst_default << _bn_->variable(node); }
703         }
704 
705         // Filling the variable's table.
706         for (inst.setFirstIn(inst_default); !inst.end(); inst.incIn(inst_default)) {
707           (_bn_->cpt(varId))
708              .set(inst,
709                   inst.val(var) < values.size() ? (GUM_SCALAR)values[inst.val(var)]
710                                                 : (GUM_SCALAR)0);
711         }
712       } else {
713         Instantiation inst(_bn_->cpt(_varNameMap_[var.name()]));
714         Instantiation var_inst;
715         var_inst << var;
716 
717         for (var_inst.setFirst(); !var_inst.end(); ++var_inst) {
718           inst.setVals(var_inst);
719 
720           for (inst.setFirstOut(var_inst); !inst.end(); inst.incOut(var_inst)) {
721             (_bn_->cpt(varId))
722                .set(inst,
723                     inst.val(var) < values.size() ? (GUM_SCALAR)values[inst.val(var)]
724                                                   : (GUM_SCALAR)0);
725           }
726         }
727       }
728     }
729   }
730 
731   template < typename GUM_SCALAR >
setVariableValues(const std::vector<float> & values)732   INLINE void BayesNetFactory< GUM_SCALAR >::setVariableValues(const std::vector< float >& values) {
733     if (state() != factory_state::FACT_ENTRY) {
734       _illegalStateError_("setVariableValues");
735     } else {
736       const DiscreteVariable& var = _bn_->variable(_varNameMap_[_stringBag_[0]]);
737       //     Checking consistency between values and var.
738 
739       if (values.size() != var.domainSize()) {
740         GUM_ERROR(OperationNotAllowed,
741                   var.name() << " : invalid number of modalities: found " << values.size()
742                              << " while needed " << var.domainSize())
743       }
744 
745       setVariableValuesUnchecked(values);
746     }
747   }
748 
749   // Tells the factory that we finished declaring a conditional probability
750   // table.
751   template < typename GUM_SCALAR >
endFactorizedProbabilityDeclaration()752   INLINE void BayesNetFactory< GUM_SCALAR >::endFactorizedProbabilityDeclaration() {
753     if (state() != factory_state::FACT_CPT) {
754       _illegalStateError_("endFactorizedProbabilityDeclaration");
755     } else {
756       _resetParts_();
757       _states_.pop_back();
758     }
759   }
760 
761   // @brief Define a variable.
762   // You can only call this method is the factory is in the NONE or NETWORK
763   // state.
764   // The variable is added by copy.
765   // @param var The pointer over a DiscreteVariable used to define a new
766   //            variable in the built BayesNet.
767   // @throw DuplicateElement Raised if a variable with the same name already
768   //                         exists.
769   // @throw OperationNotAllowed Raised if redefineParents == false and if table
770   //                            is not a valid CPT for var in the current state
771   //                            of the BayesNet.
772   template < typename GUM_SCALAR >
setVariable(const DiscreteVariable & var)773   INLINE void BayesNetFactory< GUM_SCALAR >::setVariable(const DiscreteVariable& var) {
774     if ((state() != factory_state::NONE)) {
775       _illegalStateError_("setVariable");
776     } else {
777       try {
778         _checkVariableName_(var.name());
779         GUM_ERROR(DuplicateElement, "Name already used: " << var.name())
780       } catch (NotFound&) {
781         // The var name is unused
782         _varNameMap_.insert(var.name(), _bn_->add(var));
783       }
784     }
785   }
786 
787   // @brief Define a variable's CPT.
  // You can only call this method if the factory is in the NONE state.
  // Be careful that table is given to the built BayesNet, so it will be
  // deleted with it, and you should not directly access it after you call
  // this method.
  // When the redefineParents flag is set to true the constructed BayesNet's
  // DAG is changed to fit with table's definition.
  // @param varName The name of the concerned variable.
  // @param table A pointer over the CPT used for the variable.
  // @param redefineParents If true redefine parents of the variable to match
  //                        table's set of variables.
800   //
801   // @throw NotFound Raised if no variable matches var.
802   // @throw OperationNotAllowed Raised if redefineParents == false and if table
803   //                            is not a valid CPT for var in the current state
804   //                            of the BayesNet.
805   template < typename GUM_SCALAR >
setVariableCPT(const std::string & varName,MultiDimAdressable * table,bool redefineParents)806   INLINE void BayesNetFactory< GUM_SCALAR >::setVariableCPT(const std::string&  varName,
807                                                             MultiDimAdressable* table,
808                                                             bool                redefineParents) {
809     auto pot = dynamic_cast< Potential< GUM_SCALAR >* >(table);
810 
811     if (state() != factory_state::NONE) {
812       _illegalStateError_("setVariableCPT");
813     } else {
814       _checkVariableName_(varName);
815       const DiscreteVariable& var   = _bn_->variable(_varNameMap_[varName]);
816       NodeId                  varId = _varNameMap_[varName];
817       // If we have to change the structure of the BayesNet, then we call a sub
818       // method.
819 
820       if (redefineParents) {
821         _setCPTAndParents_(var, pot);
822       } else if (pot->contains(var)) {
823         for (auto node: _bn_->parents(varId)) {
824           if (!pot->contains(_bn_->variable(node))) {
825             GUM_ERROR(OperationNotAllowed, "The CPT is not valid in the current BayesNet.")
826           }
827         }
828 
829         // CPT are created when a variable is added.
830         _bn_->_unsafeChangePotential_(varId, pot);
831       }
832     }
833   }
834 
835   // Raise an OperationNotAllowed with the message "Illegal state."
836   template < typename GUM_SCALAR >
_illegalStateError_(const std::string & s)837   INLINE void BayesNetFactory< GUM_SCALAR >::_illegalStateError_(const std::string& s) {
838     std::string msg = "Illegal state call (";
839     msg += s;
840     msg += ") in state ";
841 
842     switch (state()) {
843       case factory_state::NONE: {
844         msg += "NONE";
845         break;
846       }
847 
848       case factory_state::NETWORK: {
849         msg += "NETWORK";
850         break;
851       }
852 
853       case factory_state::VARIABLE: {
854         msg += "VARIABLE";
855         break;
856       }
857 
858       case factory_state::PARENTS: {
859         msg += "PARENTS";
860         break;
861       }
862 
863       case factory_state::RAW_CPT: {
864         msg += "RAW_CPT";
865         break;
866       }
867 
868       case factory_state::FACT_CPT: {
869         msg += "FACT_CPT";
870         break;
871       }
872 
873       case factory_state::FACT_ENTRY: {
874         msg += "FACT_ENTRY";
875         break;
876       }
877 
878       default: {
879         msg += "Unknown state";
880       }
881     }
882 
883     GUM_ERROR(OperationNotAllowed, msg)
884   }
885 
886   // Check if a variable with the given name exists, if not raise an NotFound
887   // exception.
888   template < typename GUM_SCALAR >
_checkVariableName_(const std::string & name)889   INLINE void BayesNetFactory< GUM_SCALAR >::_checkVariableName_(const std::string& name) {
890     if (!_varNameMap_.exists(name)) { GUM_ERROR(NotFound, name) }
891   }
892 
893   // Check if var exists and if mod is one of it's modality, if not raise an
894   // NotFound exception.
895   template < typename GUM_SCALAR >
_checkVariableModality_(const std::string & name,const std::string & mod)896   INLINE Idx BayesNetFactory< GUM_SCALAR >::_checkVariableModality_(const std::string& name,
897                                                                     const std::string& mod) {
898     _checkVariableName_(name);
899     const DiscreteVariable& var = _bn_->variable(_varNameMap_[name]);
900 
901     for (Idx i = 0; i < var.domainSize(); ++i) {
902       if (mod == var.label(i)) { return i; }
903     }
904 
905     GUM_ERROR(NotFound, mod)
906   }
907 
908   // Check if in  _stringBag_ there is no other modality with the same name.
909   template < typename GUM_SCALAR >
_checkModalityInBag_(const std::string & mod)910   INLINE void BayesNetFactory< GUM_SCALAR >::_checkModalityInBag_(const std::string& mod) {
911     for (size_t i = 3; i < _stringBag_.size(); ++i) {
912       if (mod == _stringBag_[i]) { GUM_ERROR(DuplicateElement, "Label already used: " << mod) }
913     }
914   }
915 
916   // Sub method of setVariableCPT() which redefine the BayesNet's DAG with
917   // respect to table.
918   template < typename GUM_SCALAR >
_setCPTAndParents_(const DiscreteVariable & var,Potential<GUM_SCALAR> * table)919   INLINE void BayesNetFactory< GUM_SCALAR >::_setCPTAndParents_(const DiscreteVariable&  var,
920                                                                 Potential< GUM_SCALAR >* table) {
921     NodeId varId = _varNameMap_[var.name()];
922     _bn_->dag_.eraseParents(varId);
923 
924     for (auto v: table->variablesSequence()) {
925       if (v != (&var)) {
926         _checkVariableName_(v->name());
927         _bn_->dag_.addArc(_varNameMap_[v->name()], varId);
928       }
929     }
930 
931     // CPT are created when a variable is added.
932     _bn_->_unsafeChangePotential_(varId, table);
933   }
934 
935   // Reset the different parts used to constructed the BayesNet.
936   template < typename GUM_SCALAR >
_resetParts_()937   INLINE void BayesNetFactory< GUM_SCALAR >::_resetParts_() {
938     _foo_flag_ = false;
939     _bar_flag_ = false;
940     _stringBag_.clear();
941   }
942 } /* namespace gum */