1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** 23 * @file 24 * @brief Implementation of the BayesNetFactory class. 25 * 26 * @author Lionel TORTI and Pierre-Henri WUILLEMIN(_at_LIP6) 27 28 */ 29 30 #include <agrum/BN/BayesNetFactory.h> 31 32 namespace gum { 33 34 // Default constructor. 35 // @param bn A pointer over the BayesNet filled by this factory. 36 // @throw DuplicateElement Raised if two variables in bn share the same 37 // name. 38 template < typename GUM_SCALAR > BayesNetFactory(BayesNet<GUM_SCALAR> * bn)39 INLINE BayesNetFactory< GUM_SCALAR >::BayesNetFactory(BayesNet< GUM_SCALAR >* bn) : 40 _parents_(0), _impl_(0), _bn_(bn) { 41 GUM_CONSTRUCTOR(BayesNetFactory); 42 _states_.push_back(factory_state::NONE); 43 44 for (auto node: bn->nodes()) { 45 if (_varNameMap_.exists(bn->variable(node).name())) 46 GUM_ERROR(DuplicateElement, "Name already used: " << bn->variable(node).name()) 47 48 _varNameMap_.insert(bn->variable(node).name(), node); 49 } 50 51 resetVerbose(); 52 } 53 54 // Copy constructor. 55 // The copy will have an exact copy of the constructed BayesNet in source. 56 template < typename GUM_SCALAR > 57 INLINE BayesNetFactory(const BayesNetFactory<GUM_SCALAR> & source)58 BayesNetFactory< GUM_SCALAR >::BayesNetFactory(const BayesNetFactory< GUM_SCALAR >& source) : 59 _parents_(nullptr), 60 _impl_(nullptr), _bn_(nullptr) { 61 GUM_CONS_CPY(BayesNetFactory); 62 63 if (source.state() != factory_state::NONE) { 64 GUM_ERROR(OperationNotAllowed, "Illegal state to proceed make a copy.") 65 } else { 66 _states_ = source._states_; 67 _bn_ = new BayesNet< GUM_SCALAR >(*(source._bn_)); 68 } 69 } 70 71 // Destructor 72 template < typename GUM_SCALAR > ~BayesNetFactory()73 INLINE BayesNetFactory< GUM_SCALAR >::~BayesNetFactory() { 74 GUM_DESTRUCTOR(BayesNetFactory); 75 76 if (_parents_ != nullptr) delete _parents_; 77 78 if (_impl_ != nullptr) { 79 //@todo better than throwing an exception from inside a destructor but 80 // still ... 81 std::cerr << "[BN factory] Implementation defined for a variable but not used. " 82 "You should call endVariableDeclaration() before " 83 "deleting me." 84 << std::endl; 85 exit(1); 86 } 87 } 88 89 // Returns the BayesNet created by this factory. 90 template < typename GUM_SCALAR > bayesNet()91 INLINE BayesNet< GUM_SCALAR >* BayesNetFactory< GUM_SCALAR >::bayesNet() { 92 return _bn_; 93 } 94 95 template < typename GUM_SCALAR > varInBN(NodeId id)96 INLINE const DiscreteVariable& BayesNetFactory< GUM_SCALAR >::varInBN(NodeId id) { 97 return _bn_->variable(id); 98 } 99 100 // Returns the current state of the factory. 101 template < typename GUM_SCALAR > state()102 INLINE IBayesNetFactory::factory_state BayesNetFactory< GUM_SCALAR >::state() const { 103 // This is ok because there is always at least the state NONE in the stack. 104 return _states_.back(); 105 } 106 107 // Returns the NodeId of a variable given it's name. 108 // @throw NotFound Raised if no variable matches the name. 109 template < typename GUM_SCALAR > variableId(const std::string & name)110 INLINE NodeId BayesNetFactory< GUM_SCALAR >::variableId(const std::string& name) const { 111 try { 112 return _varNameMap_[name]; 113 } catch (NotFound&) { GUM_ERROR(NotFound, name) } 114 } 115 116 // Returns a constant reference on a variable given it's name. 117 // @throw NotFound Raised if no variable matches the name. 118 template < typename GUM_SCALAR > 119 INLINE const DiscreteVariable& variable(const std::string & name)120 BayesNetFactory< GUM_SCALAR >::variable(const std::string& name) const { 121 try { 122 return _bn_->variable(variableId(name)); 123 } catch (NotFound&) { GUM_ERROR(NotFound, name) } 124 } 125 126 // Returns the domainSize of the cpt for the node n. 127 // @throw NotFound raised if no such NodeId exists. 128 // @throw OperationNotAllowed if there is no Bayesian networks. 129 template < typename GUM_SCALAR > cptDomainSize(const NodeId n)130 INLINE Size BayesNetFactory< GUM_SCALAR >::cptDomainSize(const NodeId n) const { 131 return _bn_->cpt(n).domainSize(); 132 } 133 134 // Tells the factory that we're in a network declaration. 135 template < typename GUM_SCALAR > startNetworkDeclaration()136 INLINE void BayesNetFactory< GUM_SCALAR >::startNetworkDeclaration() { 137 if (state() != factory_state::NONE) { 138 _illegalStateError_("startNetworkDeclaration"); 139 } else { 140 _states_.push_back(factory_state::NETWORK); 141 } 142 } 143 144 // Tells the factory to add a property to the current network. 145 template < typename GUM_SCALAR > addNetworkProperty(const std::string & propName,const std::string & propValue)146 INLINE void BayesNetFactory< GUM_SCALAR >::addNetworkProperty(const std::string& propName, 147 const std::string& propValue) { 148 _bn_->setProperty(propName, propValue); 149 } 150 151 // Tells the factory that we're out of a network declaration. 152 template < typename GUM_SCALAR > endNetworkDeclaration()153 INLINE void BayesNetFactory< GUM_SCALAR >::endNetworkDeclaration() { 154 if (state() != factory_state::NETWORK) { 155 _illegalStateError_("endNetworkDeclaration"); 156 } else { 157 _states_.pop_back(); 158 } 159 } 160 161 // Tells the factory that we're in a variable declaration. 162 // A variable is considered as a LabelizedVariable while its type is not defined. 163 template < typename GUM_SCALAR > startVariableDeclaration()164 INLINE void BayesNetFactory< GUM_SCALAR >::startVariableDeclaration() { 165 if (state() != factory_state::NONE) { 166 _illegalStateError_("startVariableDeclaration"); 167 } else { 168 _states_.push_back(factory_state::VARIABLE); 169 _stringBag_.push_back("name"); 170 _stringBag_.push_back("desc"); 171 _stringBag_.push_back("L"); 172 } 173 } 174 175 // Tells the factory the current variable's name. 176 template < typename GUM_SCALAR > variableName(const std::string & name)177 INLINE void BayesNetFactory< GUM_SCALAR >::variableName(const std::string& name) { 178 if (state() != factory_state::VARIABLE) { 179 _illegalStateError_("variableName"); 180 } else { 181 if (_varNameMap_.exists(name)) { GUM_ERROR(DuplicateElement, "Name already used: " << name) } 182 183 _foo_flag_ = true; 184 _stringBag_[0] = name; 185 } 186 } 187 188 // Tells the factory the current variable's description. 189 template < typename GUM_SCALAR > variableDescription(const std::string & desc)190 INLINE void BayesNetFactory< GUM_SCALAR >::variableDescription(const std::string& desc) { 191 if (state() != factory_state::VARIABLE) { 192 _illegalStateError_("variableDescription"); 193 } else { 194 _bar_flag_ = true; 195 _stringBag_[1] = desc; 196 } 197 } 198 199 // Tells the factory the current variable's type. 200 // L : Labelized 201 // R : Range 202 // C : Continuous 203 // D : Discretized 204 template < typename GUM_SCALAR > variableType(const gum::VarType & type)205 INLINE void BayesNetFactory< GUM_SCALAR >::variableType(const gum::VarType& type) { 206 if (state() != factory_state::VARIABLE) { 207 _illegalStateError_("variableType"); 208 } else { 209 switch (type) { 210 case VarType::Discretized: 211 _stringBag_[2] = "D"; 212 break; 213 case VarType::Range: 214 _stringBag_[2] = "R"; 215 break; 216 case VarType::Integer: 217 _stringBag_[2] = "I"; 218 break; 219 case VarType::Continuous: 220 GUM_ERROR(OperationNotAllowed, 221 "Continuous variable (" + _stringBag_[0] 222 + ") are not supported in Bayesian networks.") 223 case VarType::Labelized: 224 _stringBag_[2] = "L"; 225 break; 226 } 227 } 228 } 229 230 // Adds a modality to the current variable. 231 // @throw DuplicateElement If the current variable already has a modality 232 // with the same name. 233 template < typename GUM_SCALAR > addModality(const std::string & name)234 INLINE void BayesNetFactory< GUM_SCALAR >::addModality(const std::string& name) { 235 if (state() != factory_state::VARIABLE) { 236 _illegalStateError_("addModality"); 237 } else { 238 _checkModalityInBag_(name); 239 _stringBag_.push_back(name); 240 } 241 } 242 243 // Adds a modality to the current variable. 244 // @throw DuplicateElement If the current variable already has a modality 245 // with the same name. 246 template < typename GUM_SCALAR > addMin(const long & min)247 INLINE void BayesNetFactory< GUM_SCALAR >::addMin(const long& min) { 248 if (state() != factory_state::VARIABLE) { 249 _illegalStateError_("addMin"); 250 } else { 251 _stringBag_.push_back(std::to_string(min)); 252 } 253 } 254 255 // Adds a modality to the current variable. 256 // @throw DuplicateElement If the current variable already has a modality 257 // with the same name. 258 template < typename GUM_SCALAR > addMax(const long & max)259 INLINE void BayesNetFactory< GUM_SCALAR >::addMax(const long& max) { 260 if (state() != factory_state::VARIABLE) { 261 _illegalStateError_("addMin"); 262 } else { 263 _stringBag_.push_back(std::to_string(max)); 264 } 265 } 266 267 // Adds a modality to the current variable. 268 // @throw DuplicateElement If the current variable already has a modality 269 // with the same name. 270 template < typename GUM_SCALAR > addTick(const GUM_SCALAR & tick)271 INLINE void BayesNetFactory< GUM_SCALAR >::addTick(const GUM_SCALAR& tick) { 272 if (state() != factory_state::VARIABLE) { 273 _illegalStateError_("addTick"); 274 } else { 275 _stringBag_.push_back(std::to_string(tick)); 276 } 277 } 278 279 // @brief Defines the implementation to use for Potential. 280 // @warning The implementation must be empty. 281 // @warning The pointer is always delegated to Potential! No copy of it 282 // is made. 283 // @todo When copy of a MultiDimImplementation is available use a copy 284 // behaviour for this method. 285 // @throw NotFound Raised if no variable matches var. 286 // @throw OperationNotAllowed Raised if impl is not empty. 287 // @throw OperationNotAllowed If an implementation is already defined for the 288 // current variable. 289 template < typename GUM_SCALAR > 290 INLINE void setVariableCPTImplementation(MultiDimAdressable * adressable)291 BayesNetFactory< GUM_SCALAR >::setVariableCPTImplementation(MultiDimAdressable* adressable) { 292 MultiDimImplementation< GUM_SCALAR >* impl 293 = dynamic_cast< MultiDimImplementation< GUM_SCALAR >* >(adressable); 294 295 if (state() != factory_state::VARIABLE) { 296 _illegalStateError_("setVariableCPTImplementation"); 297 } else { 298 if (impl == 0) { 299 GUM_ERROR(OperationNotAllowed, 300 "An implementation for this variable is already " 301 "defined.") 302 } else if (impl->nbrDim() > 0) { 303 GUM_ERROR(OperationNotAllowed, "This implementation is not empty.") 304 } 305 306 _impl_ = impl; 307 } 308 } 309 310 // Tells the factory that we're out of a variable declaration. 311 template < typename GUM_SCALAR > endVariableDeclaration()312 INLINE NodeId BayesNetFactory< GUM_SCALAR >::endVariableDeclaration() { 313 if (state() != factory_state::VARIABLE) { 314 _illegalStateError_("endVariableDeclaration"); 315 } else if (_foo_flag_ && (_stringBag_.size() > 4)) { 316 DiscreteVariable* var = nullptr; 317 318 // if the current variable is a LabelizedVariable 319 if (_stringBag_[2] == "L") { 320 LabelizedVariable* l 321 = new LabelizedVariable(_stringBag_[0], (_bar_flag_) ? _stringBag_[1] : "", 0); 322 323 for (size_t i = 3; i < _stringBag_.size(); ++i) { 324 l->addLabel(_stringBag_[i]); 325 } 326 327 var = l; 328 // if the current variable is a RangeVariable 329 } else if (_stringBag_[2] == "I") { 330 // try to create the domain of the variable 331 std::vector< int > domain; 332 for (size_t i = 3; i < _stringBag_.size(); ++i) { 333 domain.push_back(std::stoi(_stringBag_[i])); 334 } 335 336 IntegerVariable* v 337 = new IntegerVariable(_stringBag_[0], (_bar_flag_) ? _stringBag_[1] : "", domain); 338 var = v; 339 } else if (_stringBag_[2] == "R") { 340 RangeVariable* r = new RangeVariable(_stringBag_[0], 341 (_bar_flag_) ? _stringBag_[1] : "", 342 std::stol(_stringBag_[3]), 343 std::stol(_stringBag_[4])); 344 345 var = r; 346 // if the current variable is a DiscretizedVariable 347 } else if (_stringBag_[2] == "D") { 348 DiscretizedVariable< GUM_SCALAR >* d 349 = new DiscretizedVariable< GUM_SCALAR >(_stringBag_[0], 350 (_bar_flag_) ? _stringBag_[1] : ""); 351 352 for (size_t i = 3; i < _stringBag_.size(); ++i) { 353 d->addTick(std::stof(_stringBag_[i])); 354 } 355 356 var = d; 357 } 358 359 if (_impl_ != 0) { 360 _varNameMap_.insert(var->name(), _bn_->add(*var, _impl_)); 361 _impl_ = 0; 362 } else { 363 _varNameMap_.insert(var->name(), _bn_->add(*var)); 364 } 365 366 NodeId retVal = _varNameMap_[var->name()]; 367 368 delete var; 369 370 _resetParts_(); 371 _states_.pop_back(); 372 373 return retVal; 374 } else { 375 std::stringstream msg; 376 msg << "Not enough modalities ("; 377 378 if (_stringBag_.size() > 3) { 379 msg << _stringBag_.size() - 3; 380 } else { 381 msg << 0; 382 } 383 384 msg << ") declared for variable "; 385 386 if (_foo_flag_) { 387 msg << _stringBag_[0]; 388 } else { 389 msg << "unknown"; 390 } 391 392 _resetParts_(); 393 394 _states_.pop_back(); 395 GUM_ERROR(OperationNotAllowed, msg.str()) 396 } 397 398 // For noisy compilers 399 return 0; 400 } 401 402 // Tells the factory that we're declaring parents for some variable. 403 // @var The concerned variable's name. 404 template < typename GUM_SCALAR > startParentsDeclaration(const std::string & var)405 INLINE void BayesNetFactory< GUM_SCALAR >::startParentsDeclaration(const std::string& var) { 406 if (state() != factory_state::NONE) { 407 _illegalStateError_("startParentsDeclaration"); 408 } else { 409 _checkVariableName_(var); 410 std::vector< std::string >::iterator iter = _stringBag_.begin(); 411 _stringBag_.insert(iter, var); 412 _states_.push_back(factory_state::PARENTS); 413 } 414 } 415 416 // Tells the factory for which variable we're declaring parents. 417 // @var The parent's name. 418 // @throw NotFound Raised if var does not exists. 419 template < typename GUM_SCALAR > addParent(const std::string & var)420 INLINE void BayesNetFactory< GUM_SCALAR >::addParent(const std::string& var) { 421 if (state() != factory_state::PARENTS) { 422 _illegalStateError_("addParent"); 423 } else { 424 _checkVariableName_(var); 425 _stringBag_.push_back(var); 426 } 427 } 428 429 // Tells the factory that we've finished declaring parents for some 430 // variable. When parents exist, endParentsDeclaration creates some arcs. 431 // These arcs are created in the inverse order of the order of the parent 432 // specification. 433 template < typename GUM_SCALAR > endParentsDeclaration()434 INLINE void BayesNetFactory< GUM_SCALAR >::endParentsDeclaration() { 435 if (state() != factory_state::PARENTS) { 436 _illegalStateError_("endParentsDeclaration"); 437 } else { 438 NodeId id = _varNameMap_[_stringBag_[0]]; 439 440 // PLEASE NOTE THAT THE ORDER IS INVERSE 441 442 for (size_t i = _stringBag_.size() - 1; i > 0; --i) { 443 _bn_->addArc(_varNameMap_[_stringBag_[i]], id); 444 } 445 446 _resetParts_(); 447 448 _states_.pop_back(); 449 } 450 } 451 452 // Tells the factory that we're declaring a conditional probability table 453 // for some variable. 454 // @param var The concerned variable's name. 455 template < typename GUM_SCALAR > 456 INLINE void startRawProbabilityDeclaration(const std::string & var)457 BayesNetFactory< GUM_SCALAR >::startRawProbabilityDeclaration(const std::string& var) { 458 if (state() != factory_state::NONE) { 459 _illegalStateError_("startRawProbabilityDeclaration"); 460 } else { 461 _checkVariableName_(var); 462 _stringBag_.push_back(var); 463 _states_.push_back(factory_state::RAW_CPT); 464 } 465 } 466 467 // @brief Fills the variable's table with the values in rawTable. 468 // Parse the parents in the same order in which they were added to the 469 // variable. 470 // Given a sequence [var, p_1, p_2, ...,p_n-1, p_n] of parents, modalities are 471 // parsed 472 // in the given order (if all p_i are binary): 473 // [0, 0, ..., 0, 0], [0, 0, ..., 0, 1], 474 // [0, 0, ..., 1, 0], [0, 0, ..., 1, 1], 475 // ..., 476 // [1, 1, ..., 1, 0], [1, 1, ..., 1, 1]. 477 // @param rawTable The raw table. 478 template < typename GUM_SCALAR > 479 INLINE void rawConditionalTable(const std::vector<std::string> & variables,const std::vector<float> & rawTable)480 BayesNetFactory< GUM_SCALAR >::rawConditionalTable(const std::vector< std::string >& variables, 481 const std::vector< float >& rawTable) { 482 if (state() != factory_state::RAW_CPT) { 483 _illegalStateError_("rawConditionalTable"); 484 } else { 485 _fillProbaWithValuesTable_(variables, rawTable); 486 } 487 } 488 489 template < typename GUM_SCALAR > _fillProbaWithValuesTable_(const std::vector<std::string> & variables,const std::vector<float> & rawTable)490 INLINE void BayesNetFactory< GUM_SCALAR >::_fillProbaWithValuesTable_( 491 const std::vector< std::string >& variables, 492 const std::vector< float >& rawTable) { 493 const Potential< GUM_SCALAR >& table = _bn_->cpt(_varNameMap_[_stringBag_[0]]); 494 Instantiation cptInst(table); 495 496 List< const DiscreteVariable* > varList; 497 498 for (size_t i = 0; i < variables.size(); ++i) { 499 varList.pushBack(&(_bn_->variable(_varNameMap_[variables[i]]))); 500 } 501 502 // varList.pushFront(&( _bn_->variable( _varNameMap_[ _stringBag_[0]]))); 503 504 Idx nbrVar = varList.size(); 505 506 std::vector< Idx > modCounter; 507 508 // initializing the array 509 for (NodeId i = 0; i < nbrVar; i++) { 510 modCounter.push_back(Idx(0)); 511 } 512 513 Idx j = 0; 514 515 do { 516 for (NodeId i = 0; i < nbrVar; i++) { 517 cptInst.chgVal(*(varList[i]), modCounter[i]); 518 } 519 520 if (j < rawTable.size()) { 521 table.set(cptInst, (GUM_SCALAR)rawTable[j]); 522 } else { 523 table.set(cptInst, (GUM_SCALAR)0); 524 } 525 526 j++; 527 } while (_increment_(modCounter, varList)); 528 } 529 530 template < typename GUM_SCALAR > 531 INLINE void rawConditionalTable(const std::vector<float> & rawTable)532 BayesNetFactory< GUM_SCALAR >::rawConditionalTable(const std::vector< float >& rawTable) { 533 if (state() != factory_state::RAW_CPT) { 534 _illegalStateError_("rawConditionalTable"); 535 } else { 536 _fillProbaWithValuesTable_(rawTable); 537 } 538 } 539 540 template < typename GUM_SCALAR > _fillProbaWithValuesTable_(const std::vector<float> & rawTable)541 INLINE void BayesNetFactory< GUM_SCALAR >::_fillProbaWithValuesTable_( 542 const std::vector< float >& rawTable) { 543 const Potential< GUM_SCALAR >& table = _bn_->cpt(_varNameMap_[_stringBag_[0]]); 544 545 Instantiation cptInst(table); 546 547 // the main loop is on the first variables. The others are in the right 548 // order. 549 const DiscreteVariable& first = table.variable(0); 550 Idx j = 0; 551 552 for (cptInst.setFirstVar(first); !cptInst.end(); cptInst.incVar(first)) { 553 for (cptInst.setFirstNotVar(first); !cptInst.end(); cptInst.incNotVar(first)) 554 table.set(cptInst, (j < rawTable.size()) ? (GUM_SCALAR)rawTable[j++] : (GUM_SCALAR)0); 555 556 cptInst.unsetEnd(); 557 } 558 } 559 560 template < typename GUM_SCALAR > _increment_(std::vector<gum::Idx> & modCounter,List<const DiscreteVariable * > & varList)561 INLINE bool BayesNetFactory< GUM_SCALAR >::_increment_(std::vector< gum::Idx >& modCounter, 562 List< const DiscreteVariable* >& varList) { 563 bool last = true; 564 565 for (NodeId j = 0; j < modCounter.size(); j++) { 566 last = (modCounter[j] == (varList[j]->domainSize() - 1)) && last; 567 568 if (!last) break; 569 } 570 571 if (last) { return false; } 572 573 bool add = false; 574 575 NodeId i = NodeId(varList.size() - 1); 576 577 do { 578 if (modCounter[i] == (varList[i]->domainSize() - 1)) { 579 modCounter[i] = 0; 580 add = true; 581 } else { 582 modCounter[i] += 1; 583 add = false; 584 } 585 586 i--; 587 } while (add); 588 589 return true; 590 } 591 592 // Tells the factory that we finished declaring a conditional probability 593 // table. 594 template < typename GUM_SCALAR > endRawProbabilityDeclaration()595 INLINE void BayesNetFactory< GUM_SCALAR >::endRawProbabilityDeclaration() { 596 if (state() != factory_state::RAW_CPT) { 597 _illegalStateError_("endRawProbabilityDeclaration"); 598 } else { 599 _resetParts_(); 600 _states_.pop_back(); 601 } 602 } 603 604 // Tells the factory that we're starting a factorized declaration. 605 template < typename GUM_SCALAR > 606 INLINE void startFactorizedProbabilityDeclaration(const std::string & var)607 BayesNetFactory< GUM_SCALAR >::startFactorizedProbabilityDeclaration(const std::string& var) { 608 if (state() != factory_state::NONE) { 609 _illegalStateError_("startFactorizedProbabilityDeclaration"); 610 } else { 611 _checkVariableName_(var); 612 std::vector< std::string >::iterator iter = _stringBag_.begin(); 613 _stringBag_.insert(iter, var); 614 _states_.push_back(factory_state::FACT_CPT); 615 } 616 } 617 618 // Tells the factory that we start an entry of a factorized conditional 619 // probability table. 620 template < typename GUM_SCALAR > startFactorizedEntry()621 INLINE void BayesNetFactory< GUM_SCALAR >::startFactorizedEntry() { 622 if (state() != factory_state::FACT_CPT) { 623 _illegalStateError_("startFactorizedEntry"); 624 } else { 625 _parents_ = new Instantiation(); 626 _states_.push_back(factory_state::FACT_ENTRY); 627 } 628 } 629 630 // Tells the factory that we finished declaring a conditional probability 631 // table. 632 template < typename GUM_SCALAR > endFactorizedEntry()633 INLINE void BayesNetFactory< GUM_SCALAR >::endFactorizedEntry() { 634 if (state() != factory_state::FACT_ENTRY) { 635 _illegalStateError_("endFactorizedEntry"); 636 } else { 637 delete _parents_; 638 _parents_ = 0; 639 _states_.pop_back(); 640 } 641 } 642 643 // Tells the factory on which modality we want to instantiate one of 644 // variable's parent. 645 template < typename GUM_SCALAR > setParentModality(const std::string & parent,const std::string & modality)646 INLINE void BayesNetFactory< GUM_SCALAR >::setParentModality(const std::string& parent, 647 const std::string& modality) { 648 if (state() != factory_state::FACT_ENTRY) { 649 _illegalStateError_("string"); 650 } else { 651 _checkVariableName_(parent); 652 Idx id = _checkVariableModality_(parent, modality); 653 (*_parents_) << _bn_->variable(_varNameMap_[parent]); 654 _parents_->chgVal(_bn_->variable(_varNameMap_[parent]), id); 655 } 656 } 657 658 // @brief Gives the values of the variable with respect to precedent 659 // parents modality. 660 // If some parents have no modality set, then we apply values for all 661 // instantiations of that parent. 662 // 663 // This means you can declare a default value for the table by doing 664 // @code 665 // BayesNetFactory factory; 666 // // Do stuff 667 // factory.startVariableDeclaration(); 668 // factory.variableName("foo"); 669 // factory.endVariableDeclaration(); 670 // factory.startParentsDeclaration("foo"); 671 // // add parents 672 // factory.endParentsDeclaration(); 673 // factory.startFactorizedProbabilityDeclaration("foo"); 674 // std::vector<float> seq; 675 // seq.insert(0.4); // if foo true 676 // seq.insert(O.6); // if foo false 677 // factory.setVariableValues(seq); // fills the table with a default value 678 // // finish your stuff 679 // factory.endFactorizedProbabilityDeclaration(); 680 // @code 681 // as for raw Probability, if value's size is different than the number of 682 // modalities of the current variable, we don't use the supplementary values and 683 // we fill by 0 the missing values. 684 template < typename GUM_SCALAR > 685 INLINE void setVariableValuesUnchecked(const std::vector<float> & values)686 BayesNetFactory< GUM_SCALAR >::setVariableValuesUnchecked(const std::vector< float >& values) { 687 if (state() != factory_state::FACT_ENTRY) { 688 _illegalStateError_("setVariableValues"); 689 } else { 690 const DiscreteVariable& var = _bn_->variable(_varNameMap_[_stringBag_[0]]); 691 NodeId varId = _varNameMap_[_stringBag_[0]]; 692 693 if (_parents_->domainSize() > 0) { 694 Instantiation inst(_bn_->cpt(_varNameMap_[var.name()])); 695 inst.setVals(*_parents_); 696 // Creating an instantiation containing all the variables not ins 697 // _parents_. 698 Instantiation inst_default; 699 inst_default << var; 700 701 for (auto node: _bn_->parents(varId)) { 702 if (!_parents_->contains(_bn_->variable(node))) { inst_default << _bn_->variable(node); } 703 } 704 705 // Filling the variable's table. 706 for (inst.setFirstIn(inst_default); !inst.end(); inst.incIn(inst_default)) { 707 (_bn_->cpt(varId)) 708 .set(inst, 709 inst.val(var) < values.size() ? (GUM_SCALAR)values[inst.val(var)] 710 : (GUM_SCALAR)0); 711 } 712 } else { 713 Instantiation inst(_bn_->cpt(_varNameMap_[var.name()])); 714 Instantiation var_inst; 715 var_inst << var; 716 717 for (var_inst.setFirst(); !var_inst.end(); ++var_inst) { 718 inst.setVals(var_inst); 719 720 for (inst.setFirstOut(var_inst); !inst.end(); inst.incOut(var_inst)) { 721 (_bn_->cpt(varId)) 722 .set(inst, 723 inst.val(var) < values.size() ? (GUM_SCALAR)values[inst.val(var)] 724 : (GUM_SCALAR)0); 725 } 726 } 727 } 728 } 729 } 730 731 template < typename GUM_SCALAR > setVariableValues(const std::vector<float> & values)732 INLINE void BayesNetFactory< GUM_SCALAR >::setVariableValues(const std::vector< float >& values) { 733 if (state() != factory_state::FACT_ENTRY) { 734 _illegalStateError_("setVariableValues"); 735 } else { 736 const DiscreteVariable& var = _bn_->variable(_varNameMap_[_stringBag_[0]]); 737 // Checking consistency between values and var. 738 739 if (values.size() != var.domainSize()) { 740 GUM_ERROR(OperationNotAllowed, 741 var.name() << " : invalid number of modalities: found " << values.size() 742 << " while needed " << var.domainSize()) 743 } 744 745 setVariableValuesUnchecked(values); 746 } 747 } 748 749 // Tells the factory that we finished declaring a conditional probability 750 // table. 751 template < typename GUM_SCALAR > endFactorizedProbabilityDeclaration()752 INLINE void BayesNetFactory< GUM_SCALAR >::endFactorizedProbabilityDeclaration() { 753 if (state() != factory_state::FACT_CPT) { 754 _illegalStateError_("endFactorizedProbabilityDeclaration"); 755 } else { 756 _resetParts_(); 757 _states_.pop_back(); 758 } 759 } 760 761 // @brief Define a variable. 762 // You can only call this method is the factory is in the NONE or NETWORK 763 // state. 764 // The variable is added by copy. 765 // @param var The pointer over a DiscreteVariable used to define a new 766 // variable in the built BayesNet. 767 // @throw DuplicateElement Raised if a variable with the same name already 768 // exists. 769 // @throw OperationNotAllowed Raised if redefineParents == false and if table 770 // is not a valid CPT for var in the current state 771 // of the BayesNet. 772 template < typename GUM_SCALAR > setVariable(const DiscreteVariable & var)773 INLINE void BayesNetFactory< GUM_SCALAR >::setVariable(const DiscreteVariable& var) { 774 if ((state() != factory_state::NONE)) { 775 _illegalStateError_("setVariable"); 776 } else { 777 try { 778 _checkVariableName_(var.name()); 779 GUM_ERROR(DuplicateElement, "Name already used: " << var.name()) 780 } catch (NotFound&) { 781 // The var name is unused 782 _varNameMap_.insert(var.name(), _bn_->add(var)); 783 } 784 } 785 } 786 787 // @brief Define a variable's CPT. 788 // You can only call this method if the factory is in the NONE or NETWORK 789 // state. 790 // Be careful that table is given to the built BayesNet, so it will be 791 // deleted with it, and you should not directly access it after you call 792 // this method. 793 // When the redefineParents flag is set to true the constructed BayesNet's 794 // DAG is changed to fit with table's definition. 795 // @param var The name of the concerned variable. 796 // @param table A pointer over the CPT used for var. 797 // @param redefineParents If true redefine parents of the variable to match 798 // table's 799 // variables set. 800 // 801 // @throw NotFound Raised if no variable matches var. 802 // @throw OperationNotAllowed Raised if redefineParents == false and if table 803 // is not a valid CPT for var in the current state 804 // of the BayesNet. 805 template < typename GUM_SCALAR > setVariableCPT(const std::string & varName,MultiDimAdressable * table,bool redefineParents)806 INLINE void BayesNetFactory< GUM_SCALAR >::setVariableCPT(const std::string& varName, 807 MultiDimAdressable* table, 808 bool redefineParents) { 809 auto pot = dynamic_cast< Potential< GUM_SCALAR >* >(table); 810 811 if (state() != factory_state::NONE) { 812 _illegalStateError_("setVariableCPT"); 813 } else { 814 _checkVariableName_(varName); 815 const DiscreteVariable& var = _bn_->variable(_varNameMap_[varName]); 816 NodeId varId = _varNameMap_[varName]; 817 // If we have to change the structure of the BayesNet, then we call a sub 818 // method. 819 820 if (redefineParents) { 821 _setCPTAndParents_(var, pot); 822 } else if (pot->contains(var)) { 823 for (auto node: _bn_->parents(varId)) { 824 if (!pot->contains(_bn_->variable(node))) { 825 GUM_ERROR(OperationNotAllowed, "The CPT is not valid in the current BayesNet.") 826 } 827 } 828 829 // CPT are created when a variable is added. 830 _bn_->_unsafeChangePotential_(varId, pot); 831 } 832 } 833 } 834 835 // Raise an OperationNotAllowed with the message "Illegal state." 836 template < typename GUM_SCALAR > _illegalStateError_(const std::string & s)837 INLINE void BayesNetFactory< GUM_SCALAR >::_illegalStateError_(const std::string& s) { 838 std::string msg = "Illegal state call ("; 839 msg += s; 840 msg += ") in state "; 841 842 switch (state()) { 843 case factory_state::NONE: { 844 msg += "NONE"; 845 break; 846 } 847 848 case factory_state::NETWORK: { 849 msg += "NETWORK"; 850 break; 851 } 852 853 case factory_state::VARIABLE: { 854 msg += "VARIABLE"; 855 break; 856 } 857 858 case factory_state::PARENTS: { 859 msg += "PARENTS"; 860 break; 861 } 862 863 case factory_state::RAW_CPT: { 864 msg += "RAW_CPT"; 865 break; 866 } 867 868 case factory_state::FACT_CPT: { 869 msg += "FACT_CPT"; 870 break; 871 } 872 873 case factory_state::FACT_ENTRY: { 874 msg += "FACT_ENTRY"; 875 break; 876 } 877 878 default: { 879 msg += "Unknown state"; 880 } 881 } 882 883 GUM_ERROR(OperationNotAllowed, msg) 884 } 885 886 // Check if a variable with the given name exists, if not raise an NotFound 887 // exception. 888 template < typename GUM_SCALAR > _checkVariableName_(const std::string & name)889 INLINE void BayesNetFactory< GUM_SCALAR >::_checkVariableName_(const std::string& name) { 890 if (!_varNameMap_.exists(name)) { GUM_ERROR(NotFound, name) } 891 } 892 893 // Check if var exists and if mod is one of it's modality, if not raise an 894 // NotFound exception. 895 template < typename GUM_SCALAR > _checkVariableModality_(const std::string & name,const std::string & mod)896 INLINE Idx BayesNetFactory< GUM_SCALAR >::_checkVariableModality_(const std::string& name, 897 const std::string& mod) { 898 _checkVariableName_(name); 899 const DiscreteVariable& var = _bn_->variable(_varNameMap_[name]); 900 901 for (Idx i = 0; i < var.domainSize(); ++i) { 902 if (mod == var.label(i)) { return i; } 903 } 904 905 GUM_ERROR(NotFound, mod) 906 } 907 908 // Check if in _stringBag_ there is no other modality with the same name. 909 template < typename GUM_SCALAR > _checkModalityInBag_(const std::string & mod)910 INLINE void BayesNetFactory< GUM_SCALAR >::_checkModalityInBag_(const std::string& mod) { 911 for (size_t i = 3; i < _stringBag_.size(); ++i) { 912 if (mod == _stringBag_[i]) { GUM_ERROR(DuplicateElement, "Label already used: " << mod) } 913 } 914 } 915 916 // Sub method of setVariableCPT() which redefine the BayesNet's DAG with 917 // respect to table. 918 template < typename GUM_SCALAR > _setCPTAndParents_(const DiscreteVariable & var,Potential<GUM_SCALAR> * table)919 INLINE void BayesNetFactory< GUM_SCALAR >::_setCPTAndParents_(const DiscreteVariable& var, 920 Potential< GUM_SCALAR >* table) { 921 NodeId varId = _varNameMap_[var.name()]; 922 _bn_->dag_.eraseParents(varId); 923 924 for (auto v: table->variablesSequence()) { 925 if (v != (&var)) { 926 _checkVariableName_(v->name()); 927 _bn_->dag_.addArc(_varNameMap_[v->name()], varId); 928 } 929 } 930 931 // CPT are created when a variable is added. 932 _bn_->_unsafeChangePotential_(varId, table); 933 } 934 935 // Reset the different parts used to constructed the BayesNet. 936 template < typename GUM_SCALAR > _resetParts_()937 INLINE void BayesNetFactory< GUM_SCALAR >::_resetParts_() { 938 _foo_flag_ = false; 939 _bar_flag_ = false; 940 _stringBag_.clear(); 941 } 942 } /* namespace gum */