1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** 23 * @file 24 * @brief A class for generic framework of learning algorithms that can easily 25 * be used. 26 * 27 * The pack currently contains K2, GreedyHillClimbing, miic, 3off2 and 28 * LocalSearchWithTabuList 29 * 30 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6) 31 */ 32 #ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H 33 #define GUM_LEARNING_GENERIC_BN_LEARNER_H 34 35 #include <sstream> 36 #include <memory> 37 38 #include <agrum/BN/BayesNet.h> 39 #include <agrum/agrum.h> 40 #include <agrum/tools/core/bijection.h> 41 #include <agrum/tools/core/sequence.h> 42 #include <agrum/tools/graphs/DAG.h> 43 44 #include <agrum/tools/database/DBTranslator4LabelizedVariable.h> 45 #include <agrum/tools/database/DBRowGeneratorParser.h> 46 #include <agrum/tools/database/DBInitializerFromCSV.h> 47 #include <agrum/tools/database/databaseTable.h> 48 #include <agrum/tools/database/DBRowGeneratorParser.h> 49 #include <agrum/tools/database/DBRowGenerator4CompleteRows.h> 50 #include <agrum/tools/database/DBRowGeneratorEM.h> 51 #include <agrum/tools/database/DBRowGeneratorSet.h> 52 53 #include <agrum/BN/learning/scores_and_tests/scoreAIC.h> 54 #include <agrum/BN/learning/scores_and_tests/scoreBD.h> 55 #include <agrum/BN/learning/scores_and_tests/scoreBDeu.h> 56 #include <agrum/BN/learning/scores_and_tests/scoreBIC.h> 57 #include <agrum/BN/learning/scores_and_tests/scoreK2.h> 58 #include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h> 59 60 #include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h> 61 #include <agrum/BN/learning/aprioris/aprioriNoApriori.h> 62 #include <agrum/BN/learning/aprioris/aprioriSmoothing.h> 63 #include <agrum/BN/learning/aprioris/aprioriBDeu.h> 64 65 #include <agrum/BN/learning/constraints/structuralConstraintDAG.h> 66 #include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h> 67 #include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h> 68 #include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h> 69 #include <agrum/BN/learning/constraints/structuralConstraintIndegree.h> 70 #include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h> 71 #include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h> 72 #include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h> 73 #include <agrum/BN/learning/constraints/structuralConstraintTabuList.h> 74 75 #include <agrum/BN/learning/structureUtils/graphChange.h> 76 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h> 77 #include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h> 78 #include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h> 79 80 #include <agrum/BN/learning/paramUtils/DAG2BNLearner.h> 81 #include <agrum/BN/learning/paramUtils/paramEstimatorML.h> 82 83 #include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h> 84 #include <agrum/tools/core/approximations/approximationSchemeListener.h> 85 86 #include <agrum/BN/learning/K2.h> 87 #include <agrum/BN/learning/Miic.h> 88 #include <agrum/BN/learning/greedyHillClimbing.h> 89 #include <agrum/BN/learning/localSearchWithTabuList.h> 90 91 #include <agrum/tools/core/signal/signaler.h> 92 93 namespace gum { 94 95 namespace learning { 96 97 class BNLearnerListener; 98 99 /** @class genericBNLearner 100 * @brief A pack of learning algorithms that can easily be used 101 * 102 * The pack currently contains K2, GreedyHillClimbing and 103 * LocalSearchWithTabuList also 3off2/miic 104 * @ingroup learning_group 105 */ 106 class genericBNLearner: public gum::IApproximationSchemeConfiguration { 107 // private: 108 public: 109 /// an enumeration enabling to select easily the score we wish to use 110 enum class ScoreType 111 { 112 AIC, 113 BD, 114 BDeu, 115 BIC, 116 K2, 117 LOG2LIKELIHOOD 118 }; 119 120 /// an enumeration to select the type of parameter estimation we shall 121 /// apply 122 enum class ParamEstimatorType 123 { 124 ML 125 }; 126 127 /// an enumeration to select the apriori 128 enum class AprioriType 129 { 130 NO_APRIORI, 131 SMOOTHING, 132 DIRICHLET_FROM_DATABASE, 133 BDEU 134 }; 135 136 /// an enumeration to select easily the learning algorithm to use 137 enum class AlgoType 138 { 139 K2, 140 GREEDY_HILL_CLIMBING, 141 LOCAL_SEARCH_WITH_TABU_LIST, 142 MIIC, 143 THREE_OFF_TWO 144 }; 145 146 147 /// a helper to easily read databases 148 class Database { 149 public: 150 // ######################################################################## 151 /// @name Constructors / Destructors 152 // ######################################################################## 153 /// @{ 154 155 /// default constructor 156 /** @param file the name of the CSV file containing the data 157 * @param missing_symbols the set of symbols in the CSV file that 158 * correspond to missing data 159 * @param induceTypes By default, all the values in the dataset are 160 * interpreted as "labels", i.e., as categorical values. But if some 161 * columns of the dataset have only numerical values, it would certainly 162 * be better totag them as corresponding to integer, range or continuous 163 * variables. By setting induceTypes to true, this is precisely what the 164 * BNLearner will do. 165 */ 166 explicit Database(const std::string& file, 167 const std::vector< std::string >& missing_symbols, 168 const bool induceTypes = false); 169 170 /// default constructor 171 /** @param db an already initialized database table that is used to 172 * fill the Database 173 */ 174 explicit Database(const DatabaseTable<>& db); 175 176 /// constructor for the aprioris 177 /** We must ensure that the variables of the Database are identical to 178 * those of the score database (else the countings used by the 179 * scores might be erroneous). However, we allow the variables to be 180 * ordered differently in the two databases: variables with the same 181 * name in both databases are supposed to be the same. 182 * @param file the name of the CSV file containing the data 183 * @param score_database the main database used for the learning 184 * @param missing_symbols the set of symbols in the CSV file that 185 * correspond to missing data 186 */ 187 Database(const std::string& filename, 188 Database& score_database, 189 const std::vector< std::string >& missing_symbols); 190 191 /// constructor with a BN providing the variables of interest 192 /** @param file the name of the CSV file containing the data 193 * @param bn a Bayesian network indicating which variables of the CSV 194 * file are used for learning 195 * @param missing_symbols the set of symbols in the CSV file that 196 * correspond to missing data 197 */ 198 template < typename GUM_SCALAR > 199 Database(const std::string& filename, 200 const gum::BayesNet< GUM_SCALAR >& bn, 201 const std::vector< std::string >& missing_symbols); 202 203 /// copy constructor 204 Database(const Database& from); 205 206 /// move constructor 207 Database(Database&& from); 208 209 /// destructor 210 ~Database(); 211 212 /// @} 213 214 // ######################################################################## 215 /// @name Operators 216 // ######################################################################## 217 /// @{ 218 219 /// copy operator 220 Database& operator=(const Database& from); 221 222 /// move operator 223 Database& operator=(Database&& from); 224 225 /// @} 226 227 // ######################################################################## 228 /// @name Accessors / Modifiers 229 // ######################################################################## 230 /// @{ 231 232 /// returns the parser for the database 233 DBRowGeneratorParser<>& parser(); 234 235 /// returns the domain sizes of the variables 236 const std::vector< std::size_t >& domainSizes() const; 237 238 /// returns the names of the variables in the database 239 const std::vector< std::string >& names() const; 240 241 /// returns the node id corresponding to a variable name 242 NodeId idFromName(const std::string& var_name) const; 243 244 /// returns the variable name corresponding to a given node id 245 const std::string& nameFromId(NodeId id) const; 246 247 /// returns the internal database table 248 const DatabaseTable<>& databaseTable() const; 249 250 /** @brief assign a weight to all the rows of the database so 251 * that the sum of their weights is equal to new_weight */ 252 void setDatabaseWeight(const double new_weight); 253 254 /// returns the mapping between node ids and their columns in the database 255 const Bijection< NodeId, std::size_t >& nodeId2Columns() const; 256 257 /// returns the set of missing symbols taken into account 258 const std::vector< std::string >& missingSymbols() const; 259 260 /// returns the number of records in the database 261 std::size_t nbRows() const; 262 263 /// returns the number of records in the database 264 std::size_t size() const; 265 266 /// sets the weight of the ith record 267 /** @throws OutOfBounds if i is outside the set of indices of the 268 * records or if the weight is negative 269 */ 270 void setWeight(const std::size_t i, const double weight); 271 272 /// returns the weight of the ith record 273 /** @throws OutOfBounds if i is outside the set of indices of the 274 * records */ 275 double weight(const std::size_t i) const; 276 277 /// returns the weight of the whole database 278 double weight() const; 279 280 281 /// @} 282 283 protected: 284 /// the database itself 285 DatabaseTable<> _database_; 286 287 /// the parser used for reading the database 288 DBRowGeneratorParser<>* _parser_{nullptr}; 289 290 /// the domain sizes of the variables (useful to speed-up computations) 291 std::vector< std::size_t > _domain_sizes_; 292 293 /// a bijection assigning to each variable name its NodeId 294 Bijection< NodeId, std::size_t > _nodeId2cols_; 295 296 /// the max number of threads authorized 297 #if defined(_OPENMP) && !defined(GUM_DEBUG_MODE) 298 Size _max_threads_number_{getMaxNumberOfThreads()}; 299 #else 300 Size _max_threads_number_{1}; 301 #endif /* GUM_DEBUG_MODE */ 302 303 /// the minimal number of rows to parse (on average) by thread 304 Size _min_nb_rows_per_thread_{100}; 305 306 private: 307 // returns the set of variables as a BN. This is convenient for 308 // the constructors of apriori Databases 309 template < typename GUM_SCALAR > 310 BayesNet< GUM_SCALAR > _BNVars_() const; 311 }; 312 313 /// sets the apriori weight 314 void _setAprioriWeight_(double weight); 315 316 public: 317 // ########################################################################## 318 /// @name Constructors / Destructors 319 // ########################################################################## 320 /// @{ 321 322 /** 323 * read the database file for the score / parameter estimation and var 324 * names 325 * @param filename the name of a CSV file containing the dataset 326 * @param missing_symbols the set of symbols in the CSV that should 327 * be interpreted as missing values 328 * @param induceTypes when some columns of the dataset have only numerical 329 * values, it is certainly be better to tag them as corresponding to integer, 330 * range or continuous variables. By setting induceTypes to true (default), this is 331 * precisely what the BNLearner will do. If inducedTypes is false, all the values in 332 * the dataset are interpreted as "labels", i.e., as categorical values. 333 */ 334 genericBNLearner(const std::string& filename, 335 const std::vector< std::string >& missingSymbols, 336 bool induceTypes = true); 337 338 genericBNLearner(const DatabaseTable<>& db); 339 340 /** 341 * read the database file for the score / parameter estimation and var 342 * names 343 * @param filename The file to learn from. 344 * @param src indicate for some nodes (not necessarily all the 345 * nodes of the BN) which modalities they should have and in which order 346 * these modalities should be stored into the nodes. For instance, if 347 * modalities = { 1 -> {True, False, Big} }, then the node of id 1 in the 348 * BN will have 3 modalities, the first one being True, the second one 349 * being False, and the third bein Big. 350 * The modalities specified by the user will be considered 351 * as being exactly those of the variables of the BN (as a consequence, 352 * if we find other values in the database, an exception will be raised 353 * during learning). 354 * @param missing_symbols the set of symbols in the CSV that should 355 * be interpreted as missing values 356 */ 357 template < typename GUM_SCALAR > 358 genericBNLearner(const std::string& filename, 359 const gum::BayesNet< GUM_SCALAR >& src, 360 const std::vector< std::string >& missing_symbols); 361 362 /// copy constructor 363 genericBNLearner(const genericBNLearner&); 364 365 /// move constructor 366 genericBNLearner(genericBNLearner&&); 367 368 /// destructor 369 virtual ~genericBNLearner(); 370 371 /// @} 372 373 // ########################################################################## 374 /// @name Operators 375 // ########################################################################## 376 /// @{ 377 378 /// copy operator 379 genericBNLearner& operator=(const genericBNLearner&); 380 381 /// move operator 382 genericBNLearner& operator=(genericBNLearner&&); 383 384 /// @} 385 386 // ########################################################################## 387 /// @name Accessors / Modifiers 388 // ########################################################################## 389 /// @{ 390 391 /// learn a structure from a file (must have read the db before) 392 DAG learnDAG(); 393 394 /// learn a partial structure from a file (must have read the db before and 395 /// must have selected miic or 3off2) 396 MixedGraph learnMixedStructure(); 397 398 /// sets an initial DAG structure 399 void setInitialDAG(const DAG&); 400 401 /// returns the initial DAG structure 402 DAG initialDAG(); 403 404 /// returns the names of the variables in the database 405 const std::vector< std::string >& names() const; 406 407 /// returns the domain sizes of the variables in the database 408 const std::vector< std::size_t >& domainSizes() const; 409 Size domainSize(NodeId var) const; 410 Size domainSize(const std::string& var) const; 411 412 /// returns the node id corresponding to a variable name 413 /** 414 * @throw MissingVariableInDatabase if a variable of the BN is not found 415 * in the database. 416 */ 417 NodeId idFromName(const std::string& var_name) const; 418 419 /// returns the database used by the BNLearner 420 const DatabaseTable<>& database() const; 421 422 /** @brief assign a weight to all the rows of the learning database so 423 * that the sum of their weights is equal to new_weight */ 424 void setDatabaseWeight(const double new_weight); 425 426 /// sets the weight of the ith record of the database 427 /** @throws OutOfBounds if i is outside the set of indices of the 428 * records or if the weight is negative 429 */ 430 void setRecordWeight(const std::size_t i, const double weight); 431 432 /// returns the weight of the ith record 433 /** @throws OutOfBounds if i is outside the set of indices of the 434 * records */ 435 double recordWeight(const std::size_t i) const; 436 437 /// returns the weight of the whole database 438 double databaseWeight() const; 439 440 /// returns the variable name corresponding to a given node id 441 const std::string& nameFromId(NodeId id) const; 442 443 /// use a new set of database rows' ranges to perform learning 444 /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows 445 * indices. The subsequent learnings are then performed only on the union 446 * of the rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when 447 * performing cross validation tasks, in which part of the database should 448 * be ignored. An empty set of ranges is equivalent to an interval [X,Y) 449 * ranging over the whole database. */ 450 template < template < typename > class XALLOC > 451 void useDatabaseRanges( 452 const std::vector< std::pair< std::size_t, std::size_t >, 453 XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges); 454 455 /// reset the ranges to the one range corresponding to the whole database 456 void clearDatabaseRanges(); 457 458 /// returns the current database rows' ranges used for learning 459 /** @return The method returns a vector of pairs [Xi,Yi) of indices of 460 * rows in the database. The learning is performed on these set of rows. 461 * @warning an empty set of ranges means the whole database. */ 462 const std::vector< std::pair< std::size_t, std::size_t > >& databaseRanges() const; 463 464 /// sets the ranges of rows to be used for cross-validation learning 465 /** When applied on (x,k), the method indicates to the subsequent learnings 466 * that they should be performed on the xth fold in a k-fold 467 * cross-validation context. For instance, if a database has 1000 rows, 468 * and if we perform a 10-fold cross-validation, then, the first learning 469 * fold (learning_fold=0) corresponds to rows interval [100,1000) and the 470 * test dataset corresponds to [0,100). The second learning fold 471 * (learning_fold=1) is [0,100) U [200,1000) and the corresponding test 472 * dataset is [100,200). 473 * @param learning_fold a number indicating the set of rows used for 474 * learning. If N denotes the size of the database, and k_fold represents 475 * the number of folds in the cross validation, then the set of rows 476 * used for testing is [learning_fold * N / k_fold, 477 * (learning_fold+1) * N / k_fold) and the learning database is the 478 * complement in the database 479 * @param k_fold the value of "k" in k-fold cross validation 480 * @return a pair [x,y) of rows' indices that corresponds to the indices 481 * of rows in the original database that constitute the test dataset 482 * @throws OutOfBounds is raised if k_fold is equal to 0 or learning_fold 483 * is greater than or eqal to k_fold, or if k_fold is greater than 484 * or equal to the size of the database. */ 485 std::pair< std::size_t, std::size_t > useCrossValidationFold(const std::size_t learning_fold, 486 const std::size_t k_fold); 487 488 489 /** 490 * Return the <statistic,pvalue> pair for chi2 test in the database 491 * @param id1 first variable 492 * @param id2 second variable 493 * @param knowing list of observed variables 494 * @return a std::pair<double,double> 495 */ 496 std::pair< double, double > 497 chi2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {}); 498 /** 499 * Return the <statistic,pvalue> pair for the BNLearner 500 * @param id1 first variable 501 * @param id2 second variable 502 * @param knowing list of observed variables 503 * @return a std::pair<double,double> 504 */ 505 std::pair< double, double > chi2(const std::string& name1, 506 const std::string& name2, 507 const std::vector< std::string >& knowing = {}); 508 509 /** 510 * Return the <statistic,pvalue> pair for for G2 test in the database 511 * @param id1 first variable 512 * @param id2 second variable 513 * @param knowing list of observed variables 514 * @return a std::pair<double,double> 515 */ 516 std::pair< double, double > 517 G2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {}); 518 /** 519 * Return the <statistic,pvalue> pair for for G2 test in the database 520 * @param id1 first variable 521 * @param id2 second variable 522 * @param knowing list of observed variables 523 * @return a std::pair<double,double> 524 */ 525 std::pair< double, double > G2(const std::string& name1, 526 const std::string& name2, 527 const std::vector< std::string >& knowing = {}); 528 529 /** 530 * Return the loglikelihood of vars in the base, conditioned by knowing for 531 * the BNLearner 532 * @param vars a vector of NodeIds 533 * @param knowing an optional vector of conditioning NodeIds 534 * @return a std::pair<double,double> 535 */ 536 double logLikelihood(const std::vector< NodeId >& vars, 537 const std::vector< NodeId >& knowing = {}); 538 539 /** 540 * Return the loglikelihood of vars in the base, conditioned by knowing for 541 * the BNLearner 542 * @param vars a vector of name of rows 543 * @param knowing an optional vector of conditioning rows 544 * @return a std::pair<double,double> 545 */ 546 double logLikelihood(const std::vector< std::string >& vars, 547 const std::vector< std::string >& knowing = {}); 548 549 /** 550 * Return the pseudoconts ofNodeIds vars in the base in a raw array 551 * @param vars a vector of 552 * @return a a std::vector<double> containing the contingency table 553 */ 554 std::vector< double > rawPseudoCount(const std::vector< NodeId >& vars); 555 556 /** 557 * Return the pseudoconts of vars in the base in a raw array 558 * @param vars a vector of name 559 * @return a std::vector<double> containing the contingency table 560 */ 561 std::vector< double > rawPseudoCount(const std::vector< std::string >& vars); 562 /** 563 * 564 * @return the number of cols in the database 565 */ 566 Size nbCols() const; 567 568 /** 569 * 570 * @return the number of rows in the database 571 */ 572 Size nbRows() const; 573 574 /** use The EM algorithm to learn paramters 575 * 576 * if epsilon=0, EM is not used 577 */ 578 void useEM(const double epsilon); 579 580 /// returns true if the learner's database has missing values 581 bool hasMissingValues() const; 582 583 /// @} 584 585 // ########################################################################## 586 /// @name Score selection 587 // ########################################################################## 588 /// @{ 589 590 /// indicate that we wish to use an AIC score 591 void useScoreAIC(); 592 593 /// indicate that we wish to use a BD score 594 void useScoreBD(); 595 596 /// indicate that we wish to use a BDeu score 597 void useScoreBDeu(); 598 599 /// indicate that we wish to use a BIC score 600 void useScoreBIC(); 601 602 /// indicate that we wish to use a K2 score 603 void useScoreK2(); 604 605 /// indicate that we wish to use a Log2Likelihood score 606 void useScoreLog2Likelihood(); 607 608 /// @} 609 610 // ########################################################################## 611 /// @name A priori selection / parameterization 612 // ########################################################################## 613 /// @{ 614 615 /// use no apriori 616 void useNoApriori(); 617 618 /// use the BDeu apriori 619 /** The BDeu apriori adds weight to all the cells of the countings 620 * tables. In other words, it adds weight rows in the database with 621 * equally probable values. */ 622 void useAprioriBDeu(double weight = 1); 623 624 /// use the apriori smoothing 625 /** @param weight pass in argument a weight if you wish to assign a weight 626 * to the smoothing, else the current weight of the genericBNLearner will 627 * be used. */ 628 void useAprioriSmoothing(double weight = 1); 629 630 /// use the Dirichlet apriori 631 void useAprioriDirichlet(const std::string& filename, double weight = 1); 632 633 634 /// checks whether the current score and apriori are compatible 635 /** @returns a non empty string if the apriori is somehow compatible with the 636 * score.*/ 637 std::string checkScoreAprioriCompatibility() const; 638 /// @} 639 640 // ########################################################################## 641 /// @name Learning algorithm selection 642 // ########################################################################## 643 /// @{ 644 645 /// indicate that we wish to use a greedy hill climbing algorithm 646 void useGreedyHillClimbing(); 647 648 /// indicate that we wish to use a local search with tabu list 649 /** @param tabu_size indicate the size of the tabu list 650 * @param nb_decrease indicate the max number of changes decreasing the 651 * score consecutively that we allow to apply */ 652 void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2); 653 654 /// indicate that we wish to use K2 655 void useK2(const Sequence< NodeId >& order); 656 657 /// indicate that we wish to use K2 658 void useK2(const std::vector< NodeId >& order); 659 660 /// indicate that we wish to use 3off2 661 void use3off2(); 662 663 /// indicate that we wish to use MIIC 664 void useMIIC(); 665 666 /// @} 667 668 // ########################################################################## 669 /// @name 3off2/MIIC parameterization and specific results 670 // ########################################################################## 671 /// @{ 672 /// indicate that we wish to use the NML correction for 3off2 and MIIC 673 /// @throws OperationNotAllowed when 3off2 is not the selected algorithm 674 void useNMLCorrection(); 675 /// indicate that we wish to use the MDL correction for 3off2 and MIIC 676 /// @throws OperationNotAllowed when 3off2 is not the selected algorithm 677 void useMDLCorrection(); 678 /// indicate that we wish to use the NoCorr correction for 3off2 and MIIC 679 /// @throws OperationNotAllowed when 3off2 is not the selected algorithm 680 void useNoCorrection(); 681 682 /// get the list of arcs hiding latent variables 683 /// @throws OperationNotAllowed when 3off2 or MIIC is not the selected algorithm 684 const std::vector< Arc > latentVariables() const; 685 686 /// @} 687 // ########################################################################## 688 /// @name Accessors / Modifiers for adding constraints on learning 689 // ########################################################################## 690 /// @{ 691 692 /// sets the max indegree 693 void setMaxIndegree(Size max_indegree); 694 695 /** 696 * sets a partial order on the nodes 697 * @param slice_order a NodeProperty given the rank (priority) of nodes in 698 * the partial order 699 */ 700 void setSliceOrder(const NodeProperty< NodeId >& slice_order); 701 702 /** 703 * sets a partial order on the nodes 704 * @param slices the list of list of variable names 705 */ 706 void setSliceOrder(const std::vector< std::vector< std::string > >& slices); 707 708 /// assign a set of forbidden arcs 709 void setForbiddenArcs(const ArcSet& set); 710 711 /// @name assign a new forbidden arc 712 /// @{ 713 void addForbiddenArc(const Arc& arc); 714 void addForbiddenArc(const NodeId tail, const NodeId head); 715 void addForbiddenArc(const std::string& tail, const std::string& head); 716 /// @} 717 718 /// @name remove a forbidden arc 719 /// @{ 720 void eraseForbiddenArc(const Arc& arc); 721 void eraseForbiddenArc(const NodeId tail, const NodeId head); 722 void eraseForbiddenArc(const std::string& tail, const std::string& head); 723 ///@} 724 725 /// assign a set of forbidden arcs 726 void setMandatoryArcs(const ArcSet& set); 727 728 /// @name assign a new forbidden arc 729 ///@{ 730 void addMandatoryArc(const Arc& arc); 731 void addMandatoryArc(const NodeId tail, const NodeId head); 732 void addMandatoryArc(const std::string& tail, const std::string& head); 733 ///@} 734 735 /// @name remove a forbidden arc 736 ///@{ 737 void eraseMandatoryArc(const Arc& arc); 738 void eraseMandatoryArc(const NodeId tail, const NodeId head); 739 void eraseMandatoryArc(const std::string& tail, const std::string& head); 740 /// @} 741 742 /// assign a set of forbidden edges 743 /// @warning Once at least one possible edge is defined, all other edges are 744 /// not possible anymore 745 /// @{ 746 void setPossibleEdges(const EdgeSet& set); 747 void setPossibleSkeleton(const UndiGraph& skeleton); 748 /// @} 749 750 /// @name assign a new possible edge 751 /// @warning Once at least one possible edge is defined, all other edges are 752 /// not possible anymore 753 /// @{ 754 void addPossibleEdge(const Edge& edge); 755 void addPossibleEdge(const NodeId tail, const NodeId head); 756 void addPossibleEdge(const std::string& tail, const std::string& head); 757 /// @} 758 759 /// @name remove a possible edge 760 /// @{ 761 void erasePossibleEdge(const Edge& edge); 762 void erasePossibleEdge(const NodeId tail, const NodeId head); 763 void erasePossibleEdge(const std::string& tail, const std::string& head); 764 ///@} 765 766 ///@} 767 768 protected: 769 /// the policy for typing variables 770 bool inducedTypes_; 771 772 /// the score selected for learning 773 ScoreType scoreType_{ScoreType::BDeu}; 774 775 /// the score used 776 Score<>* score_{nullptr}; 777 778 /// the type of the parameter estimator 779 ParamEstimatorType paramEstimatorType_{ParamEstimatorType::ML}; 780 781 /// epsilon for EM. if espilon=0.0 : no EM 782 double epsilonEM_{0.0}; 783 784 /// the selected correction for 3off2 and miic 785 CorrectedMutualInformation<>* mutualInfo_{nullptr}; 786 787 /// the a priori selected for the score and parameters 788 AprioriType aprioriType_{AprioriType::NO_APRIORI}; 789 790 /// the apriori used 791 Apriori<>* apriori_{nullptr}; 792 793 AprioriNoApriori<>* noApriori_{nullptr}; 794 795 /// the weight of the apriori 796 double aprioriWeight_{1.0f}; 797 798 /// the constraint for 2TBNs 799 StructuralConstraintSliceOrder constraintSliceOrder_; 800 801 /// the constraint for indegrees 802 StructuralConstraintIndegree constraintIndegree_; 803 804 /// the constraint for tabu lists 805 StructuralConstraintTabuList constraintTabuList_; 806 807 /// the constraint on forbidden arcs 808 StructuralConstraintForbiddenArcs constraintForbiddenArcs_; 809 810 /// the constraint on possible Edges 811 StructuralConstraintPossibleEdges constraintPossibleEdges_; 812 813 /// the constraint on mandatory arcs 814 StructuralConstraintMandatoryArcs constraintMandatoryArcs_; 815 816 /// the selected learning algorithm 817 AlgoType selectedAlgo_{AlgoType::GREEDY_HILL_CLIMBING}; 818 819 /// the K2 algorithm 820 K2 algoK2_; 821 822 /// the MIIC or 3off2 algorithm 823 Miic algoMiic3off2_; 824 825 /// the penalty used in 3off2 826 typename CorrectedMutualInformation<>::KModeTypes kmode3Off2_{ 827 CorrectedMutualInformation<>::KModeTypes::MDL}; 828 829 /// the parametric EM 830 DAG2BNLearner<> Dag2BN_; 831 832 /// the greedy hill climbing algorithm 833 GreedyHillClimbing greedyHillClimbing_; 834 835 /// the local search with tabu list algorithm 836 LocalSearchWithTabuList localSearchWithTabuList_; 837 838 /// the database to be used by the scores and parameter estimators 839 Database scoreDatabase_; 840 841 /// the set of rows' ranges within the database in which learning is done 842 std::vector< std::pair< std::size_t, std::size_t > > ranges_; 843 844 /// the database used by the Dirichlet a priori 845 Database* aprioriDatabase_{nullptr}; 846 847 /// the filename for the Dirichlet a priori, if any 848 std::string aprioriDbname_; 849 850 851 /// an initial DAG given to learners 852 DAG initialDag_; 853 854 /// the filename database 855 std::string filename_; 856 857 // size of the tabu list 858 Size nbDecreasingChanges_{2}; 859 860 // the current algorithm as an approximationScheme 861 const ApproximationScheme* currentAlgorithm_{nullptr}; 862 863 /// reads a file and returns a databaseVectInRam 864 static DatabaseTable<> readFile_(const std::string& filename, 865 const std::vector< std::string >& missing_symbols); 866 867 /// checks whether the extension of a CSV filename is correct 868 static void isCSVFileName_(const std::string& filename); 869 870 /// create the apriori used for learning 871 void createApriori_(); 872 873 /// create the score used for learning 874 void createScore_(); 875 876 /// create the parameter estimator used for learning 877 ParamEstimator<>* createParamEstimator_(DBRowGeneratorParser<>& parser, 878 bool take_into_account_score = true); 879 880 /// returns the DAG learnt 881 DAG learnDag_(); 882 883 /// prepares the initial graph for 3off2 or miic 884 MixedGraph prepareMiic3Off2_(); 885 886 /// returns the type (as a string) of a given apriori 887 const std::string& getAprioriType_() const; 888 889 /// create the Corrected Mutual Information instance for Miic/3off2 890 void createCorrectedMutualInformation_(); 891 892 893 public: 894 // ########################################################################## 895 /// @name redistribute signals AND implementation of interface 896 /// IApproximationSchemeConfiguration 897 // ########################################################################## 898 // in order to not pollute the proper code of genericBNLearner, we 899 // directly 900 // implement those 901 // very simples methods here. 902 /// {@ /// distribute signals setCurrentApproximationScheme(const ApproximationScheme * approximationScheme)903 INLINE void setCurrentApproximationScheme(const ApproximationScheme* approximationScheme) { 904 currentAlgorithm_ = approximationScheme; 905 } 906 distributeProgress(const ApproximationScheme * approximationScheme,Size pourcent,double error,double time)907 INLINE void distributeProgress(const ApproximationScheme* approximationScheme, 908 Size pourcent, 909 double error, 910 double time) { 911 setCurrentApproximationScheme(approximationScheme); 912 913 if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time); 914 }; 915 916 /// distribute signals distributeStop(const ApproximationScheme * approximationScheme,std::string message)917 INLINE void distributeStop(const ApproximationScheme* approximationScheme, 918 std::string message) { 919 setCurrentApproximationScheme(approximationScheme); 920 921 if (onStop.hasListener()) GUM_EMIT1(onStop, message); 922 }; 923 /// @} 924 925 /// Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)| 926 /// If the criterion was disabled it will be enabled 927 /// @{ 928 /// @throw OutOfBounds if eps<0 setEpsilon(double eps)929 void setEpsilon(double eps) { 930 algoK2_.approximationScheme().setEpsilon(eps); 931 greedyHillClimbing_.setEpsilon(eps); 932 localSearchWithTabuList_.setEpsilon(eps); 933 Dag2BN_.setEpsilon(eps); 934 }; 935 936 /// Get the value of epsilon epsilon()937 double epsilon() const { 938 if (currentAlgorithm_ != nullptr) 939 return currentAlgorithm_->epsilon(); 940 else 941 GUM_ERROR(FatalError, "No chosen algorithm for learning") 942 } 943 944 /// Disable stopping criterion on epsilon disableEpsilon()945 void disableEpsilon() { 946 algoK2_.approximationScheme().disableEpsilon(); 947 greedyHillClimbing_.disableEpsilon(); 948 localSearchWithTabuList_.disableEpsilon(); 949 Dag2BN_.disableEpsilon(); 950 }; 951 952 /// Enable stopping criterion on epsilon enableEpsilon()953 void enableEpsilon() { 954 algoK2_.approximationScheme().enableEpsilon(); 955 greedyHillClimbing_.enableEpsilon(); 956 localSearchWithTabuList_.enableEpsilon(); 957 Dag2BN_.enableEpsilon(); 958 }; 959 960 /// @return true if stopping criterion on epsilon is enabled, false 961 /// otherwise isEnabledEpsilon()962 bool isEnabledEpsilon() const { 963 if (currentAlgorithm_ != nullptr) 964 return currentAlgorithm_->isEnabledEpsilon(); 965 else 966 GUM_ERROR(FatalError, "No chosen algorithm for learning") 967 } 968 /// @} 969 970 /// Given that we approximate f(t), stopping criterion on 971 /// d/dt(|f(t+1)-f(t)|) 972 /// If the criterion was disabled it will be enabled 973 /// @{ 974 /// @throw OutOfBounds if rate<0 setMinEpsilonRate(double rate)975 void setMinEpsilonRate(double rate) { 976 algoK2_.approximationScheme().setMinEpsilonRate(rate); 977 greedyHillClimbing_.setMinEpsilonRate(rate); 978 localSearchWithTabuList_.setMinEpsilonRate(rate); 979 Dag2BN_.setMinEpsilonRate(rate); 980 }; 981 982 /// Get the value of the minimal epsilon rate minEpsilonRate()983 double minEpsilonRate() const { 984 if (currentAlgorithm_ != nullptr) 985 return currentAlgorithm_->minEpsilonRate(); 986 else 987 GUM_ERROR(FatalError, "No chosen algorithm for learning") 988 } 989 990 /// Disable stopping criterion on epsilon rate disableMinEpsilonRate()991 void disableMinEpsilonRate() { 992 algoK2_.approximationScheme().disableMinEpsilonRate(); 993 greedyHillClimbing_.disableMinEpsilonRate(); 994 localSearchWithTabuList_.disableMinEpsilonRate(); 995 Dag2BN_.disableMinEpsilonRate(); 996 }; 997 /// Enable stopping criterion on epsilon rate enableMinEpsilonRate()998 void enableMinEpsilonRate() { 999 algoK2_.approximationScheme().enableMinEpsilonRate(); 1000 greedyHillClimbing_.enableMinEpsilonRate(); 1001 localSearchWithTabuList_.enableMinEpsilonRate(); 1002 Dag2BN_.enableMinEpsilonRate(); 1003 }; 1004 /// @return true if stopping criterion on epsilon rate is enabled, false 1005 /// otherwise isEnabledMinEpsilonRate()1006 bool isEnabledMinEpsilonRate() const { 1007 if (currentAlgorithm_ != nullptr) 1008 return currentAlgorithm_->isEnabledMinEpsilonRate(); 1009 else 1010 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1011 } 1012 /// @} 1013 1014 /// stopping criterion on number of iterations 1015 /// @{ 1016 /// If the criterion was disabled it will be enabled 1017 /// @param max The maximum number of iterations 1018 /// @throw OutOfBounds if max<=1 setMaxIter(Size max)1019 void setMaxIter(Size max) { 1020 algoK2_.approximationScheme().setMaxIter(max); 1021 greedyHillClimbing_.setMaxIter(max); 1022 localSearchWithTabuList_.setMaxIter(max); 1023 Dag2BN_.setMaxIter(max); 1024 }; 1025 1026 /// @return the criterion on number of iterations maxIter()1027 Size maxIter() const { 1028 if (currentAlgorithm_ != nullptr) 1029 return currentAlgorithm_->maxIter(); 1030 else 1031 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1032 } 1033 1034 /// Disable stopping criterion on max iterations disableMaxIter()1035 void disableMaxIter() { 1036 algoK2_.approximationScheme().disableMaxIter(); 1037 greedyHillClimbing_.disableMaxIter(); 1038 localSearchWithTabuList_.disableMaxIter(); 1039 Dag2BN_.disableMaxIter(); 1040 }; 1041 /// Enable stopping criterion on max iterations enableMaxIter()1042 void enableMaxIter() { 1043 algoK2_.approximationScheme().enableMaxIter(); 1044 greedyHillClimbing_.enableMaxIter(); 1045 localSearchWithTabuList_.enableMaxIter(); 1046 Dag2BN_.enableMaxIter(); 1047 }; 1048 /// @return true if stopping criterion on max iterations is enabled, false 1049 /// otherwise isEnabledMaxIter()1050 bool isEnabledMaxIter() const { 1051 if (currentAlgorithm_ != nullptr) 1052 return currentAlgorithm_->isEnabledMaxIter(); 1053 else 1054 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1055 } 1056 /// @} 1057 1058 /// stopping criterion on timeout 1059 /// If the criterion was disabled it will be enabled 1060 /// @{ 1061 /// @throw OutOfBounds if timeout<=0.0 1062 /** timeout is time in second (double). 1063 */ setMaxTime(double timeout)1064 void setMaxTime(double timeout) { 1065 algoK2_.approximationScheme().setMaxTime(timeout); 1066 greedyHillClimbing_.setMaxTime(timeout); 1067 localSearchWithTabuList_.setMaxTime(timeout); 1068 Dag2BN_.setMaxTime(timeout); 1069 } 1070 1071 /// returns the timeout (in seconds) maxTime()1072 double maxTime() const { 1073 if (currentAlgorithm_ != nullptr) 1074 return currentAlgorithm_->maxTime(); 1075 else 1076 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1077 } 1078 1079 /// get the current running time in second (double) currentTime()1080 double currentTime() const { 1081 if (currentAlgorithm_ != nullptr) 1082 return currentAlgorithm_->currentTime(); 1083 else 1084 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1085 } 1086 1087 /// Disable stopping criterion on timeout disableMaxTime()1088 void disableMaxTime() { 1089 algoK2_.approximationScheme().disableMaxTime(); 1090 greedyHillClimbing_.disableMaxTime(); 1091 localSearchWithTabuList_.disableMaxTime(); 1092 Dag2BN_.disableMaxTime(); 1093 }; enableMaxTime()1094 void enableMaxTime() { 1095 algoK2_.approximationScheme().enableMaxTime(); 1096 greedyHillClimbing_.enableMaxTime(); 1097 localSearchWithTabuList_.enableMaxTime(); 1098 Dag2BN_.enableMaxTime(); 1099 }; 1100 /// @return true if stopping criterion on timeout is enabled, false 1101 /// otherwise isEnabledMaxTime()1102 bool isEnabledMaxTime() const { 1103 if (currentAlgorithm_ != nullptr) 1104 return currentAlgorithm_->isEnabledMaxTime(); 1105 else 1106 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1107 } 1108 /// @} 1109 1110 /// how many samples between 2 stopping isEnableds 1111 /// @{ 1112 /// @throw OutOfBounds if p<1 setPeriodSize(Size p)1113 void setPeriodSize(Size p) { 1114 algoK2_.approximationScheme().setPeriodSize(p); 1115 greedyHillClimbing_.setPeriodSize(p); 1116 localSearchWithTabuList_.setPeriodSize(p); 1117 Dag2BN_.setPeriodSize(p); 1118 }; 1119 periodSize()1120 Size periodSize() const { 1121 if (currentAlgorithm_ != nullptr) 1122 return currentAlgorithm_->periodSize(); 1123 else 1124 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1125 } 1126 /// @} 1127 1128 /// verbosity 1129 /// @{ setVerbosity(bool v)1130 void setVerbosity(bool v) { 1131 algoK2_.approximationScheme().setVerbosity(v); 1132 greedyHillClimbing_.setVerbosity(v); 1133 localSearchWithTabuList_.setVerbosity(v); 1134 Dag2BN_.setVerbosity(v); 1135 }; 1136 verbosity()1137 bool verbosity() const { 1138 if (currentAlgorithm_ != nullptr) 1139 return currentAlgorithm_->verbosity(); 1140 else 1141 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1142 } 1143 /// @} 1144 1145 /// history 1146 /// @{ 1147 stateApproximationScheme()1148 ApproximationSchemeSTATE stateApproximationScheme() const { 1149 if (currentAlgorithm_ != nullptr) 1150 return currentAlgorithm_->stateApproximationScheme(); 1151 else 1152 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1153 } 1154 1155 /// @throw OperationNotAllowed if scheme not performed nbrIterations()1156 Size nbrIterations() const { 1157 if (currentAlgorithm_ != nullptr) 1158 return currentAlgorithm_->nbrIterations(); 1159 else 1160 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1161 } 1162 1163 /// @throw OperationNotAllowed if scheme not performed or verbosity=false history()1164 const std::vector< double >& history() const { 1165 if (currentAlgorithm_ != nullptr) 1166 return currentAlgorithm_->history(); 1167 else 1168 GUM_ERROR(FatalError, "No chosen algorithm for learning") 1169 } 1170 /// @} 1171 }; 1172 1173 } /* namespace learning */ 1174 1175 } /* namespace gum */ 1176 1177 /// include the inlined functions if necessary 1178 #ifndef GUM_NO_INLINE 1179 # include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h> 1180 #endif /* GUM_NO_INLINE */ 1181 1182 #include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h> 1183 1184 #endif /* GUM_LEARNING_GENERIC_BN_LEARNER_H */ 1185