/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief A class providing a generic framework for learning algorithms that
 * can easily be used.
 *
 * The pack currently contains K2, GreedyHillClimbing, MIIC, 3off2 and
 * LocalSearchWithTabuList.
 *
 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6)
 */
#ifndef GUM_LEARNING_GENERIC_BN_LEARNER_H
#define GUM_LEARNING_GENERIC_BN_LEARNER_H

#include <sstream>
#include <memory>

#include <agrum/BN/BayesNet.h>
#include <agrum/agrum.h>
#include <agrum/tools/core/bijection.h>
#include <agrum/tools/core/sequence.h>
#include <agrum/tools/graphs/DAG.h>

#include <agrum/tools/database/DBTranslator4LabelizedVariable.h>
#include <agrum/tools/database/DBRowGeneratorParser.h>
#include <agrum/tools/database/DBInitializerFromCSV.h>
#include <agrum/tools/database/databaseTable.h>
#include <agrum/tools/database/DBRowGenerator4CompleteRows.h>
#include <agrum/tools/database/DBRowGeneratorEM.h>
#include <agrum/tools/database/DBRowGeneratorSet.h>

#include <agrum/BN/learning/scores_and_tests/scoreAIC.h>
#include <agrum/BN/learning/scores_and_tests/scoreBD.h>
#include <agrum/BN/learning/scores_and_tests/scoreBDeu.h>
#include <agrum/BN/learning/scores_and_tests/scoreBIC.h>
#include <agrum/BN/learning/scores_and_tests/scoreK2.h>
#include <agrum/BN/learning/scores_and_tests/scoreLog2Likelihood.h>

#include <agrum/BN/learning/aprioris/aprioriDirichletFromDatabase.h>
#include <agrum/BN/learning/aprioris/aprioriNoApriori.h>
#include <agrum/BN/learning/aprioris/aprioriSmoothing.h>
#include <agrum/BN/learning/aprioris/aprioriBDeu.h>

#include <agrum/BN/learning/constraints/structuralConstraintDAG.h>
#include <agrum/BN/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/BN/learning/constraints/structuralConstraintForbiddenArcs.h>
#include <agrum/BN/learning/constraints/structuralConstraintPossibleEdges.h>
#include <agrum/BN/learning/constraints/structuralConstraintIndegree.h>
#include <agrum/BN/learning/constraints/structuralConstraintMandatoryArcs.h>
#include <agrum/BN/learning/constraints/structuralConstraintSetStatic.h>
#include <agrum/BN/learning/constraints/structuralConstraintSliceOrder.h>
#include <agrum/BN/learning/constraints/structuralConstraintTabuList.h>

#include <agrum/BN/learning/structureUtils/graphChange.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4DiGraph.h>
#include <agrum/BN/learning/structureUtils/graphChangesGenerator4K2.h>
#include <agrum/BN/learning/structureUtils/graphChangesSelector4DiGraph.h>

#include <agrum/BN/learning/paramUtils/DAG2BNLearner.h>
#include <agrum/BN/learning/paramUtils/paramEstimatorML.h>

#include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h>
#include <agrum/tools/core/approximations/approximationSchemeListener.h>

#include <agrum/BN/learning/K2.h>
#include <agrum/BN/learning/Miic.h>
#include <agrum/BN/learning/greedyHillClimbing.h>
#include <agrum/BN/learning/localSearchWithTabuList.h>

#include <agrum/tools/core/signal/signaler.h>
namespace gum {

  namespace learning {

    class BNLearnerListener;

    /** @class genericBNLearner
     * @brief A pack of learning algorithms that can easily be used
     *
     * The pack currently contains K2, GreedyHillClimbing and
     * LocalSearchWithTabuList, as well as 3off2/MIIC.
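     *
     * A minimal usage sketch (the file name, missing symbols and choice of
     * score/algorithm below are illustrative):
     * @code
     * gum::learning::genericBNLearner learner("data.csv", {"?", "N/A"});
     * learner.useScoreBIC();             // score used during the search
     * learner.useGreedyHillClimbing();   // structure learning algorithm
     * gum::DAG dag = learner.learnDAG(); // learn the structure
     * @endcode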
     * @ingroup learning_group
     */
    class genericBNLearner: public gum::IApproximationSchemeConfiguration {
      // private:
      public:
      /// an enumeration for easily selecting the score we wish to use
      enum class ScoreType
      {
        AIC,
        BD,
        BDeu,
        BIC,
        K2,
        LOG2LIKELIHOOD
      };

      /// an enumeration to select the type of parameter estimation we shall
      /// apply
      enum class ParamEstimatorType
      {
        ML
      };

      /// an enumeration to select the apriori
      enum class AprioriType
      {
        NO_APRIORI,
        SMOOTHING,
        DIRICHLET_FROM_DATABASE,
        BDEU
      };

      /// an enumeration for easily selecting the learning algorithm to use
      enum class AlgoType
      {
        K2,
        GREEDY_HILL_CLIMBING,
        LOCAL_SEARCH_WITH_TABU_LIST,
        MIIC,
        THREE_OFF_TWO
      };


      /// a helper to easily read databases
      class Database {
        public:
        // ########################################################################
        /// @name Constructors / Destructors
        // ########################################################################
        /// @{

        /// default constructor
        /** @param file the name of the CSV file containing the data
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         * @param induceTypes By default, all the values in the dataset are
         * interpreted as "labels", i.e., as categorical values. But if some
         * columns of the dataset have only numerical values, it would certainly
         * be better to tag them as corresponding to integer, range or continuous
         * variables. By setting induceTypes to true, this is precisely what the
         * BNLearner will do.
         */
        explicit Database(const std::string&                file,
                          const std::vector< std::string >& missing_symbols,
                          const bool                        induceTypes = false);

        /// constructor from an already initialized database table
        /** @param db an already initialized database table that is used to
         * fill the Database
         */
        explicit Database(const DatabaseTable<>& db);

        /// constructor for the aprioris
        /** We must ensure that the variables of the Database are identical to
         * those of the score database (otherwise the counts used by the
         * scores might be erroneous). However, we allow the variables to be
         * ordered differently in the two databases: variables with the same
         * name in both databases are assumed to be the same.
         * @param filename the name of the CSV file containing the data
         * @param score_database the main database used for the learning
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         */
        Database(const std::string&                filename,
                 Database&                         score_database,
                 const std::vector< std::string >& missing_symbols);

        /// constructor with a BN providing the variables of interest
        /** @param filename the name of the CSV file containing the data
         * @param bn a Bayesian network indicating which variables of the CSV
         * file are used for learning
         * @param missing_symbols the set of symbols in the CSV file that
         * correspond to missing data
         */
        template < typename GUM_SCALAR >
        Database(const std::string&                 filename,
                 const gum::BayesNet< GUM_SCALAR >& bn,
                 const std::vector< std::string >&  missing_symbols);

        /// copy constructor
        Database(const Database& from);

        /// move constructor
        Database(Database&& from);

        /// destructor
        ~Database();

        /// @}

        // ########################################################################
        /// @name Operators
        // ########################################################################
        /// @{

        /// copy operator
        Database& operator=(const Database& from);

        /// move operator
        Database& operator=(Database&& from);

        /// @}

        // ########################################################################
        /// @name Accessors / Modifiers
        // ########################################################################
        /// @{

        /// returns the parser for the database
        DBRowGeneratorParser<>& parser();

        /// returns the domain sizes of the variables
        const std::vector< std::size_t >& domainSizes() const;

        /// returns the names of the variables in the database
        const std::vector< std::string >& names() const;

        /// returns the node id corresponding to a variable name
        NodeId idFromName(const std::string& var_name) const;

        /// returns the variable name corresponding to a given node id
        const std::string& nameFromId(NodeId id) const;

        /// returns the internal database table
        const DatabaseTable<>& databaseTable() const;

        /** @brief assign a weight to all the rows of the database so
         * that the sum of their weights is equal to new_weight */
        void setDatabaseWeight(const double new_weight);

        /// returns the mapping between node ids and their columns in the database
        const Bijection< NodeId, std::size_t >& nodeId2Columns() const;

        /// returns the set of missing symbols taken into account
        const std::vector< std::string >& missingSymbols() const;

        /// returns the number of records in the database
        std::size_t nbRows() const;

        /// returns the number of records in the database
        std::size_t size() const;

        /// sets the weight of the ith record
        /** @throws OutOfBounds if i is outside the set of indices of the
         * records or if the weight is negative
         */
        void setWeight(const std::size_t i, const double weight);

        /// returns the weight of the ith record
        /** @throws OutOfBounds if i is outside the set of indices of the
         * records */
        double weight(const std::size_t i) const;

        /// returns the weight of the whole database
        double weight() const;


        /// @}

        protected:
        /// the database itself
        DatabaseTable<> _database_;

        /// the parser used for reading the database
        DBRowGeneratorParser<>* _parser_{nullptr};

        /// the domain sizes of the variables (useful to speed-up computations)
        std::vector< std::size_t > _domain_sizes_;

        /// a bijection assigning to each variable name its NodeId
        Bijection< NodeId, std::size_t > _nodeId2cols_;

/// the max number of threads authorized
#if defined(_OPENMP) && !defined(GUM_DEBUG_MODE)
        Size _max_threads_number_{getMaxNumberOfThreads()};
#else
        Size _max_threads_number_{1};
#endif /* _OPENMP && !GUM_DEBUG_MODE */

        /// the minimal number of rows to parse (on average) per thread
        Size _min_nb_rows_per_thread_{100};

        private:
        // returns the set of variables as a BN. This is convenient for
        // the constructors of apriori Databases
        template < typename GUM_SCALAR >
        BayesNet< GUM_SCALAR > _BNVars_() const;
      };

      /// sets the apriori weight
      void _setAprioriWeight_(double weight);

      public:
      // ##########################################################################
      /// @name Constructors / Destructors
      // ##########################################################################
      /// @{

      /**
       * reads the database file used for the score / parameter estimation and
       * the variable names
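       *
       * A minimal sketch (the file name and the missing symbols are
       * illustrative):
       * @code
       * gum::learning::genericBNLearner learner("data.csv", {"?", "N/A"});
       * @endcode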
       * @param filename the name of a CSV file containing the dataset
       * @param missing_symbols the set of symbols in the CSV that should
       * be interpreted as missing values
       * @param induceTypes when some columns of the dataset have only numerical
       * values, it is certainly better to tag them as corresponding to integer,
       * range or continuous variables. By setting induceTypes to true (default),
       * this is precisely what the BNLearner will do. If induceTypes is false,
       * all the values in the dataset are interpreted as "labels", i.e., as
       * categorical values.
       */
      genericBNLearner(const std::string&                filename,
                       const std::vector< std::string >& missingSymbols,
                       bool                              induceTypes = true);

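      /// constructor from an already initialized database table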
      genericBNLearner(const DatabaseTable<>& db);

      /**
       * reads the database file used for the score / parameter estimation and
       * the variable names
       * @param filename The file to learn from.
       * @param src indicates for some nodes (not necessarily all the
       * nodes of the BN) which modalities they should have and in which order
       * these modalities should be stored into the nodes. For instance, if
       * modalities = { 1 -> {True, False, Big} }, then the node of id 1 in the
       * BN will have 3 modalities, the first one being True, the second one
       * being False, and the third one being Big.
       * The modalities specified by the user will be considered
       * as being exactly those of the variables of the BN (as a consequence,
       * if we find other values in the database, an exception will be raised
       * during learning).
       * @param missing_symbols the set of symbols in the CSV that should
       * be interpreted as missing values
       */
      template < typename GUM_SCALAR >
      genericBNLearner(const std::string&                 filename,
                       const gum::BayesNet< GUM_SCALAR >& src,
                       const std::vector< std::string >&  missing_symbols);

      /// copy constructor
      genericBNLearner(const genericBNLearner&);

      /// move constructor
      genericBNLearner(genericBNLearner&&);

      /// destructor
      virtual ~genericBNLearner();

      /// @}

      // ##########################################################################
      /// @name Operators
      // ##########################################################################
      /// @{

      /// copy operator
      genericBNLearner& operator=(const genericBNLearner&);

      /// move operator
      genericBNLearner& operator=(genericBNLearner&&);

      /// @}

      // ##########################################################################
      /// @name Accessors / Modifiers
      // ##########################################################################
      /// @{

      /// learn a structure from a file (must have read the db before)
      DAG learnDAG();

      /// learn a partial structure from a file (must have read the db before
      /// and must have selected MIIC or 3off2)
      MixedGraph learnMixedStructure();

      /// sets an initial DAG structure
      void setInitialDAG(const DAG&);

      /// returns the initial DAG structure
      DAG initialDAG();

      /// returns the names of the variables in the database
      const std::vector< std::string >& names() const;

      /// returns the domain sizes of the variables in the database
      const std::vector< std::size_t >& domainSizes() const;
      Size                              domainSize(NodeId var) const;
      Size                              domainSize(const std::string& var) const;

      /// returns the node id corresponding to a variable name
      /**
       * @throw MissingVariableInDatabase if a variable of the BN is not found
       * in the database.
       */
      NodeId idFromName(const std::string& var_name) const;

      /// returns the database used by the BNLearner
      const DatabaseTable<>& database() const;

      /** @brief assign a weight to all the rows of the learning database so
       * that the sum of their weights is equal to new_weight */
      void setDatabaseWeight(const double new_weight);

      /// sets the weight of the ith record of the database
      /** @throws OutOfBounds if i is outside the set of indices of the
       * records or if the weight is negative
       */
      void setRecordWeight(const std::size_t i, const double weight);

      /// returns the weight of the ith record
      /** @throws OutOfBounds if i is outside the set of indices of the
       * records */
      double recordWeight(const std::size_t i) const;

      /// returns the weight of the whole database
      double databaseWeight() const;

      /// returns the variable name corresponding to a given node id
      const std::string& nameFromId(NodeId id) const;

      /// use a new set of database rows' ranges to perform learning
      /** @param new_ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's
       * rows indices. The subsequent learnings are then performed only on the
       * union of the rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when
       * performing cross validation tasks, in which part of the database should
       * be ignored. An empty set of ranges is equivalent to an interval [X,Y)
       * ranging over the whole database. */
      template < template < typename > class XALLOC >
      void useDatabaseRanges(
         const std::vector< std::pair< std::size_t, std::size_t >,
                            XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
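
      // A minimal sketch of useDatabaseRanges (the learner instance and the
      // row indices are illustrative): learn only on rows [0,500) and
      // [1000,2000):
      //   std::vector< std::pair< std::size_t, std::size_t > > ranges
      //      = {{0, 500}, {1000, 2000}};
      //   learner.useDatabaseRanges(ranges);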

      /// reset the ranges to the one range corresponding to the whole database
      void clearDatabaseRanges();

      /// returns the current database rows' ranges used for learning
      /** @return The method returns a vector of pairs [Xi,Yi) of indices of
       * rows in the database. The learning is performed on this set of rows.
       * @warning an empty set of ranges means the whole database. */
      const std::vector< std::pair< std::size_t, std::size_t > >& databaseRanges() const;

      /// sets the ranges of rows to be used for cross-validation learning
      /** When applied on (x,k), the method indicates to the subsequent learnings
       * that they should be performed on the xth fold in a k-fold
       * cross-validation context. For instance, if a database has 1000 rows,
       * and if we perform a 10-fold cross-validation, then the first learning
       * fold (learning_fold=0) corresponds to the rows interval [100,1000) and
       * the test dataset corresponds to [0,100). The second learning fold
       * (learning_fold=1) is [0,100) U [200,1000) and the corresponding test
       * dataset is [100,200).
       * @param learning_fold a number indicating the set of rows used for
       * learning. If N denotes the size of the database, and k_fold represents
       * the number of folds in the cross validation, then the set of rows
       * used for testing is [learning_fold * N / k_fold,
       * (learning_fold+1) * N / k_fold) and the learning database is the
       * complement in the database
       * @param k_fold the value of "k" in k-fold cross validation
       * @return a pair [x,y) of rows' indices that corresponds to the indices
       * of rows in the original database that constitute the test dataset
       * @throws OutOfBounds is raised if k_fold is equal to 0 or learning_fold
       * is greater than or equal to k_fold, or if k_fold is greater than
       * or equal to the size of the database. */
      std::pair< std::size_t, std::size_t > useCrossValidationFold(const std::size_t learning_fold,
                                                                   const std::size_t k_fold);
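
      // A k-fold cross-validation sketch (the learner instance and k = 10 are
      // illustrative):
      //   for (std::size_t fold = 0; fold < 10; ++fold) {
      //     const auto test_rows = learner.useCrossValidationFold(fold, 10);
      //     // learn here, then evaluate on the test rows
      //     // [test_rows.first, test_rows.second)
      //   }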


      /**
       * Return the <statistic,pvalue> pair for the chi2 test in the database
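       *
       * For instance (the learner instance and the variable names are
       * illustrative):
       * @code
       * const auto res = learner.chi2(learner.idFromName("smoking"),
       *                               learner.idFromName("cancer"));
       * // res.first is the chi2 statistic, res.second its p-value
       * @endcode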
       * @param id1 first variable
       * @param id2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double >
         chi2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {});
      /**
       * Return the <statistic,pvalue> pair for the chi2 test in the database
       * @param name1 first variable
       * @param name2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double > chi2(const std::string&                name1,
                                       const std::string&                name2,
                                       const std::vector< std::string >& knowing = {});

      /**
       * Return the <statistic,pvalue> pair for the G2 test in the database
       * @param id1 first variable
       * @param id2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double >
         G2(const NodeId id1, const NodeId id2, const std::vector< NodeId >& knowing = {});
      /**
       * Return the <statistic,pvalue> pair for the G2 test in the database
       * @param name1 first variable
       * @param name2 second variable
       * @param knowing list of observed variables
       * @return a std::pair<double,double>
       */
      std::pair< double, double > G2(const std::string&                name1,
                                     const std::string&                name2,
                                     const std::vector< std::string >& knowing = {});

      /**
       * Return the log-likelihood of vars in the database, conditioned on
       * knowing, for the BNLearner
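       *
       * For instance (the learner instance and the node ids are illustrative):
       * @code
       * const double ll = learner.logLikelihood({0, 1}, {2});
       * @endcode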
       * @param vars a vector of NodeIds
       * @param knowing an optional vector of conditioning NodeIds
       * @return a double
       */
      double logLikelihood(const std::vector< NodeId >& vars,
                           const std::vector< NodeId >& knowing = {});

      /**
       * Return the log-likelihood of vars in the database, conditioned on
       * knowing, for the BNLearner
       * @param vars a vector of variable names
       * @param knowing an optional vector of conditioning variable names
       * @return a double
       */
      double logLikelihood(const std::vector< std::string >& vars,
                           const std::vector< std::string >& knowing = {});

      /**
       * Return the pseudo-counts of the NodeIds vars in the database as a raw
       * array
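       *
       * For instance (the learner instance and the node ids are illustrative):
       * @code
       * // contingency table of the joint pseudo-counts of nodes 0 and 1
       * const std::vector< double > counts = learner.rawPseudoCount({0, 1});
       * @endcode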
       * @param vars a vector of NodeIds
       * @return a std::vector<double> containing the contingency table
       */
      std::vector< double > rawPseudoCount(const std::vector< NodeId >& vars);

      /**
       * Return the pseudo-counts of vars in the database as a raw array
       * @param vars a vector of names
       * @return a std::vector<double> containing the contingency table
       */
      std::vector< double > rawPseudoCount(const std::vector< std::string >& vars);
      /**
       *
       * @return the number of columns in the database
       */
      Size nbCols() const;

      /**
       *
       * @return the number of rows in the database
       */
      Size nbRows() const;

      /** use the EM algorithm to learn parameters
       *
       * if epsilon=0, EM is not used
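       *
       * A minimal sketch (the learner instance and the threshold are
       * illustrative):
       * @code
       * learner.useEM(1e-4);   // enable EM with epsilon = 1e-4 (0 disables EM)
       * @endcode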
       */
      void useEM(const double epsilon);

      /// returns true if the learner's database has missing values
      bool hasMissingValues() const;

      /// @}

      // ##########################################################################
      /// @name Score selection
      // ##########################################################################
      /// @{

      /// indicate that we wish to use an AIC score
      void useScoreAIC();

      /// indicate that we wish to use a BD score
      void useScoreBD();

      /// indicate that we wish to use a BDeu score
      void useScoreBDeu();

      /// indicate that we wish to use a BIC score
      void useScoreBIC();

      /// indicate that we wish to use a K2 score
      void useScoreK2();

      /// indicate that we wish to use a Log2Likelihood score
      void useScoreLog2Likelihood();

      /// @}

      // ##########################################################################
      /// @name A priori selection / parameterization
      // ##########################################################################
      /// @{

      /// use no apriori
      void useNoApriori();

      /// use the BDeu apriori
      /** The BDeu apriori adds weight to all the cells of the counting
       * tables. In other words, it is equivalent to adding weight rows to the
       * database, in which all the values are equally probable. */
      void useAprioriBDeu(double weight = 1);

      /// use the apriori smoothing
      /** @param weight pass a weight as argument if you wish to assign a
       * weight to the smoothing, else the current weight of the
       * genericBNLearner will be used. */
      void useAprioriSmoothing(double weight = 1);
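
      // For instance, a smoothing apriori with weight 0.5 (the learner
      // instance and the value are illustrative):
      //   learner.useAprioriSmoothing(0.5);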

      /// use the Dirichlet apriori
      void useAprioriDirichlet(const std::string& filename, double weight = 1);


      /// checks whether the current score and apriori are compatible
      /** @returns a non-empty string (a warning message) if the apriori is
       * not fully compatible with the score. */
      std::string checkScoreAprioriCompatibility() const;
      /// @}

      // ##########################################################################
      /// @name Learning algorithm selection
      // ##########################################################################
      /// @{

      /// indicate that we wish to use a greedy hill climbing algorithm
      void useGreedyHillClimbing();

      /// indicate that we wish to use a local search with tabu list
      /** @param tabu_size indicates the size of the tabu list
       * @param nb_decrease indicates the max number of consecutive
       * score-decreasing changes that we allow to apply */
      void useLocalSearchWithTabuList(Size tabu_size = 100, Size nb_decrease = 2);

      /// indicate that we wish to use K2
      void useK2(const Sequence< NodeId >& order);

      /// indicate that we wish to use K2
      void useK2(const std::vector< NodeId >& order);
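
      // For instance, running K2 with a topological order on node ids
      // (the learner instance and the order are illustrative):
      //   learner.useK2(std::vector< gum::NodeId >{0, 1, 2, 3});
      //   gum::DAG dag = learner.learnDAG();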

      /// indicate that we wish to use 3off2
      void use3off2();

      /// indicate that we wish to use MIIC
      void useMIIC();

      /// @}

      // ##########################################################################
      /// @name 3off2/MIIC parameterization and specific results
      // ##########################################################################
      /// @{
      /// indicate that we wish to use the NML correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      void useNMLCorrection();
      /// indicate that we wish to use the MDL correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      void useMDLCorrection();
      /// indicate that we wish to use the NoCorr correction for 3off2 and MIIC
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      void useNoCorrection();

      /// get the list of arcs hiding latent variables
      /// @throws OperationNotAllowed when neither 3off2 nor MIIC is the selected algorithm
      const std::vector< Arc > latentVariables() const;

      /// @}
      // ##########################################################################
      /// @name Accessors / Modifiers for adding constraints on learning
      // ##########################################################################
      /// @{

      /// sets the max indegree
      void setMaxIndegree(Size max_indegree);

      /**
       * sets a partial order on the nodes
       * @param slice_order a NodeProperty giving the rank (priority) of nodes
       * in the partial order
       */
      void setSliceOrder(const NodeProperty< NodeId >& slice_order);

      /**
       * sets a partial order on the nodes
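       *
       * For instance (the learner instance and the variable names are
       * illustrative):
       * @code
       * learner.setSliceOrder({{"A", "B"}, {"C"}});   // A and B in slice 0, C in slice 1
       * @endcode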
       * @param slices the list of lists of variable names
       */
      void setSliceOrder(const std::vector< std::vector< std::string > >& slices);

      /// assign a set of forbidden arcs
      void setForbiddenArcs(const ArcSet& set);

      /// @name assign a new forbidden arc
      /// @{
      void addForbiddenArc(const Arc& arc);
      void addForbiddenArc(const NodeId tail, const NodeId head);
      void addForbiddenArc(const std::string& tail, const std::string& head);
      /// @}

      /// @name remove a forbidden arc
      /// @{
      void eraseForbiddenArc(const Arc& arc);
      void eraseForbiddenArc(const NodeId tail, const NodeId head);
      void eraseForbiddenArc(const std::string& tail, const std::string& head);
      ///@}

      /// assign a set of mandatory arcs
      void setMandatoryArcs(const ArcSet& set);

      /// @name assign a new mandatory arc (see the sketch after this group)
      ///@{
      void addMandatoryArc(const Arc& arc);
      void addMandatoryArc(const NodeId tail, const NodeId head);
      void addMandatoryArc(const std::string& tail, const std::string& head);
      ///@}
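
      // For instance (the learner instance and the variable names are
      // illustrative):
      //   learner.addMandatoryArc("smoking", "cancer"); // arc in every learnt DAG
      //   learner.addForbiddenArc("cancer", "smoking"); // arc never added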

      /// @name remove a mandatory arc
      ///@{
      void eraseMandatoryArc(const Arc& arc);
      void eraseMandatoryArc(const NodeId tail, const NodeId head);
      void eraseMandatoryArc(const std::string& tail, const std::string& head);
      /// @}

      /// assign a set of possible edges
      /// @warning Once at least one possible edge is defined, all other edges
      /// are not possible anymore
      /// @{
      void setPossibleEdges(const EdgeSet& set);
      void setPossibleSkeleton(const UndiGraph& skeleton);
      /// @}

      /// @name assign a new possible edge
      /// @warning Once at least one possible edge is defined, all other edges
      /// are not possible anymore
      /// @{
      void addPossibleEdge(const Edge& edge);
      void addPossibleEdge(const NodeId tail, const NodeId head);
      void addPossibleEdge(const std::string& tail, const std::string& head);
      /// @}

      /// @name remove a possible edge
      /// @{
      void erasePossibleEdge(const Edge& edge);
      void erasePossibleEdge(const NodeId tail, const NodeId head);
      void erasePossibleEdge(const std::string& tail, const std::string& head);
      ///@}

      ///@}

      protected:
      /// the policy for typing variables
      bool inducedTypes_;

      /// the score selected for learning
      ScoreType scoreType_{ScoreType::BDeu};

      /// the score used
      Score<>* score_{nullptr};

      /// the type of the parameter estimator
      ParamEstimatorType paramEstimatorType_{ParamEstimatorType::ML};

      /// epsilon for EM. if epsilon=0.0 : no EM
      double epsilonEM_{0.0};

      /// the selected correction for 3off2 and MIIC
      CorrectedMutualInformation<>* mutualInfo_{nullptr};

      /// the a priori selected for the score and parameters
      AprioriType aprioriType_{AprioriType::NO_APRIORI};

      /// the apriori used
      Apriori<>* apriori_{nullptr};

      AprioriNoApriori<>* noApriori_{nullptr};

      /// the weight of the apriori
      double aprioriWeight_{1.0};

      /// the constraint for 2TBNs
      StructuralConstraintSliceOrder constraintSliceOrder_;

      /// the constraint for indegrees
      StructuralConstraintIndegree constraintIndegree_;

      /// the constraint for tabu lists
      StructuralConstraintTabuList constraintTabuList_;

      /// the constraint on forbidden arcs
      StructuralConstraintForbiddenArcs constraintForbiddenArcs_;

      /// the constraint on possible edges
      StructuralConstraintPossibleEdges constraintPossibleEdges_;

      /// the constraint on mandatory arcs
      StructuralConstraintMandatoryArcs constraintMandatoryArcs_;

      /// the selected learning algorithm
      AlgoType selectedAlgo_{AlgoType::GREEDY_HILL_CLIMBING};

      /// the K2 algorithm
      K2 algoK2_;

      /// the MIIC or 3off2 algorithm
      Miic algoMiic3off2_;

      /// the penalty used in 3off2
      typename CorrectedMutualInformation<>::KModeTypes kmode3Off2_{
         CorrectedMutualInformation<>::KModeTypes::MDL};

      /// the parametric EM
      DAG2BNLearner<> Dag2BN_;

      /// the greedy hill climbing algorithm
      GreedyHillClimbing greedyHillClimbing_;

      /// the local search with tabu list algorithm
      LocalSearchWithTabuList localSearchWithTabuList_;

      /// the database to be used by the scores and parameter estimators
      Database scoreDatabase_;

      /// the set of rows' ranges within the database in which learning is done
      std::vector< std::pair< std::size_t, std::size_t > > ranges_;

      /// the database used by the Dirichlet a priori
      Database* aprioriDatabase_{nullptr};

      /// the filename for the Dirichlet a priori, if any
      std::string aprioriDbname_;


      /// an initial DAG given to learners
      DAG initialDag_;

      /// the name of the database file
      std::string filename_;

      // the max number of consecutive score-decreasing changes allowed in the
      // local search with tabu list
      Size nbDecreasingChanges_{2};

      // the current algorithm as an approximationScheme
      const ApproximationScheme* currentAlgorithm_{nullptr};

      /// reads a file and returns a databaseVectInRam
      static DatabaseTable<> readFile_(const std::string&                filename,
                                       const std::vector< std::string >& missing_symbols);

      /// checks whether the extension of a CSV filename is correct
      static void isCSVFileName_(const std::string& filename);

      /// create the apriori used for learning
      void createApriori_();

      /// create the score used for learning
      void createScore_();

      /// create the parameter estimator used for learning
      ParamEstimator<>* createParamEstimator_(DBRowGeneratorParser<>& parser,
                                              bool take_into_account_score = true);

      /// returns the DAG learnt
      DAG learnDag_();

      /// prepares the initial graph for 3off2 or MIIC
      MixedGraph prepareMiic3Off2_();

      /// returns the type (as a string) of a given apriori
      const std::string& getAprioriType_() const;

      /// create the Corrected Mutual Information instance for MIIC/3off2
      void createCorrectedMutualInformation_();


      public:
      // ##########################################################################
      /// @name redistribute signals AND implementation of interface
      /// IApproximationSchemeConfiguration
      // ##########################################################################
      // In order not to pollute the code of genericBNLearner, we directly
      // implement these very simple methods here.
      /// @{

      /// distribute signals
      INLINE void setCurrentApproximationScheme(const ApproximationScheme* approximationScheme) {
        currentAlgorithm_ = approximationScheme;
      }

      /// distribute signals
      INLINE void distributeProgress(const ApproximationScheme* approximationScheme,
                                     Size                       pourcent,
                                     double                     error,
                                     double                     time) {
        setCurrentApproximationScheme(approximationScheme);

        if (onProgress.hasListener()) GUM_EMIT3(onProgress, pourcent, error, time);
      };

      /// distribute signals
      INLINE void distributeStop(const ApproximationScheme* approximationScheme,
                                 std::string                message) {
        setCurrentApproximationScheme(approximationScheme);

        if (onStop.hasListener()) GUM_EMIT1(onStop, message);
      };
      /// @}

      /// Given that we approximate f(t), stopping criterion on |f(t+1)-f(t)|
      /// If the criterion was disabled it will be enabled
      /// @{
      /// @throw OutOfBounds if eps<0
      void setEpsilon(double eps) {
        algoK2_.approximationScheme().setEpsilon(eps);
        greedyHillClimbing_.setEpsilon(eps);
        localSearchWithTabuList_.setEpsilon(eps);
        Dag2BN_.setEpsilon(eps);
      };

      /// Get the value of epsilon
      double epsilon() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->epsilon();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on epsilon
      void disableEpsilon() {
        algoK2_.approximationScheme().disableEpsilon();
        greedyHillClimbing_.disableEpsilon();
        localSearchWithTabuList_.disableEpsilon();
        Dag2BN_.disableEpsilon();
      };

      /// Enable stopping criterion on epsilon
      void enableEpsilon() {
        algoK2_.approximationScheme().enableEpsilon();
        greedyHillClimbing_.enableEpsilon();
        localSearchWithTabuList_.enableEpsilon();
        Dag2BN_.enableEpsilon();
      };

      /// @return true if stopping criterion on epsilon is enabled, false
      /// otherwise
      bool isEnabledEpsilon() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledEpsilon();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// Given that we approximate f(t), stopping criterion on
      /// d/dt(|f(t+1)-f(t)|)
      /// If the criterion was disabled it will be enabled
      /// @{
      /// @throw OutOfBounds if rate<0
      void setMinEpsilonRate(double rate) {
        algoK2_.approximationScheme().setMinEpsilonRate(rate);
        greedyHillClimbing_.setMinEpsilonRate(rate);
        localSearchWithTabuList_.setMinEpsilonRate(rate);
        Dag2BN_.setMinEpsilonRate(rate);
      };

      /// Get the value of the minimal epsilon rate
      double minEpsilonRate() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->minEpsilonRate();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on epsilon rate
      void disableMinEpsilonRate() {
        algoK2_.approximationScheme().disableMinEpsilonRate();
        greedyHillClimbing_.disableMinEpsilonRate();
        localSearchWithTabuList_.disableMinEpsilonRate();
        Dag2BN_.disableMinEpsilonRate();
      };
      /// Enable stopping criterion on epsilon rate
      void enableMinEpsilonRate() {
        algoK2_.approximationScheme().enableMinEpsilonRate();
        greedyHillClimbing_.enableMinEpsilonRate();
        localSearchWithTabuList_.enableMinEpsilonRate();
        Dag2BN_.enableMinEpsilonRate();
      };
      /// @return true if stopping criterion on epsilon rate is enabled, false
      /// otherwise
      bool isEnabledMinEpsilonRate() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMinEpsilonRate();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// stopping criterion on number of iterations
      /// @{
      /// If the criterion was disabled it will be enabled
      /// @param max The maximum number of iterations
      /// @throw OutOfBounds if max<=1
      void setMaxIter(Size max) {
        algoK2_.approximationScheme().setMaxIter(max);
        greedyHillClimbing_.setMaxIter(max);
        localSearchWithTabuList_.setMaxIter(max);
        Dag2BN_.setMaxIter(max);
      };

      /// @return the criterion on number of iterations
      Size maxIter() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->maxIter();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on max iterations
      void disableMaxIter() {
        algoK2_.approximationScheme().disableMaxIter();
        greedyHillClimbing_.disableMaxIter();
        localSearchWithTabuList_.disableMaxIter();
        Dag2BN_.disableMaxIter();
      };
      /// Enable stopping criterion on max iterations
      void enableMaxIter() {
        algoK2_.approximationScheme().enableMaxIter();
        greedyHillClimbing_.enableMaxIter();
        localSearchWithTabuList_.enableMaxIter();
        Dag2BN_.enableMaxIter();
      };
      /// @return true if stopping criterion on max iterations is enabled, false
      /// otherwise
      bool isEnabledMaxIter() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMaxIter();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// stopping criterion on timeout
      /// If the criterion was disabled it will be enabled
      /// @{
      /// @throw OutOfBounds if timeout<=0.0
      /** timeout is expressed in seconds (double).
       */
      void setMaxTime(double timeout) {
        algoK2_.approximationScheme().setMaxTime(timeout);
        greedyHillClimbing_.setMaxTime(timeout);
        localSearchWithTabuList_.setMaxTime(timeout);
        Dag2BN_.setMaxTime(timeout);
      }

      /// returns the timeout (in seconds)
      double maxTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->maxTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// get the current running time in seconds (double)
      double currentTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->currentTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// Disable stopping criterion on timeout
      void disableMaxTime() {
        algoK2_.approximationScheme().disableMaxTime();
        greedyHillClimbing_.disableMaxTime();
        localSearchWithTabuList_.disableMaxTime();
        Dag2BN_.disableMaxTime();
      };
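      /// Enable stopping criterion on timeout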
      void enableMaxTime() {
        algoK2_.approximationScheme().enableMaxTime();
        greedyHillClimbing_.enableMaxTime();
        localSearchWithTabuList_.enableMaxTime();
        Dag2BN_.enableMaxTime();
      };
      /// @return true if stopping criterion on timeout is enabled, false
      /// otherwise
      bool isEnabledMaxTime() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->isEnabledMaxTime();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// how many samples between two checks of the stopping criteria
      /// @{
      /// @throw OutOfBounds if p<1
      void setPeriodSize(Size p) {
        algoK2_.approximationScheme().setPeriodSize(p);
        greedyHillClimbing_.setPeriodSize(p);
        localSearchWithTabuList_.setPeriodSize(p);
        Dag2BN_.setPeriodSize(p);
      };

      Size periodSize() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->periodSize();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// verbosity
      /// @{
      void setVerbosity(bool v) {
        algoK2_.approximationScheme().setVerbosity(v);
        greedyHillClimbing_.setVerbosity(v);
        localSearchWithTabuList_.setVerbosity(v);
        Dag2BN_.setVerbosity(v);
      };

      bool verbosity() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->verbosity();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}

      /// history
      /// @{

      ApproximationSchemeSTATE stateApproximationScheme() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->stateApproximationScheme();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// @throw OperationNotAllowed if scheme not performed
      Size nbrIterations() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->nbrIterations();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }

      /// @throw OperationNotAllowed if scheme not performed or verbosity=false
      const std::vector< double >& history() const {
        if (currentAlgorithm_ != nullptr)
          return currentAlgorithm_->history();
        else
          GUM_ERROR(FatalError, "No chosen algorithm for learning")
      }
      /// @}
    };

  } /* namespace learning */

} /* namespace gum */

/// include the inlined functions if necessary
#ifndef GUM_NO_INLINE
#  include <agrum/BN/learning/BNLearnUtils/genericBNLearner_inl.h>
#endif /* GUM_NO_INLINE */

#include <agrum/BN/learning/BNLearnUtils/genericBNLearner_tpl.h>

#endif /* GUM_LEARNING_GENERIC_BN_LEARNER_H */