/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

/** @file
 * @brief A pack of learning algorithms that can easily be used
 *
 * The pack currently contains K2, GreedyHillClimbing and
 * LocalSearchWithTabuList. A minimal usage sketch is given below.
 *
 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6)
 */
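/**
 * @par Minimal usage sketch
 * Illustrative only (not compiled with the library): it assumes a hypothetical
 * CSV file "data.csv" with a header row and relies on the public BNLearner API
 * (useGreedyHillClimbing(), learnBN()):
 * @code
 * #include <iostream>
 * #include <agrum/BN/learning/BNLearner.h>
 *
 * int main() {
 *   gum::learning::BNLearner< double > learner("data.csv");
 *   learner.useGreedyHillClimbing();      // score-based structure search
 *   const auto bn = learner.learnBN();    // learns both the structure and the CPTs
 *   std::cout << bn.dag().sizeArcs() << " arcs learnt" << std::endl;
 *   return 0;
 * }
 * @endcode
 */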
#include <fstream>

#ifndef DOXYGEN_SHOULD_SKIP_THIS

// to help IDE parser
#  include <agrum/BN/learning/BNLearner.h>

#  include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h>

namespace gum {

  namespace learning {
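    /// constructor reading the learning database from a file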
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const std::string&                filename,
                                       const bool                        induceTypes,
                                       const std::vector< std::string >& missingSymbols) :
        genericBNLearner(filename, missingSymbols, induceTypes) {
      GUM_CONSTRUCTOR(BNLearner);
    }

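    /// constructor using an already constructed database table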
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) : genericBNLearner(db) {
      GUM_CONSTRUCTOR(BNLearner);
    }

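    /// constructor reading the database from a file, the given Bayes net
    /// providing the variables and their labels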
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const std::string&                 filename,
                                       const gum::BayesNet< GUM_SCALAR >& bn,
                                       const std::vector< std::string >&  missing_symbols) :
        genericBNLearner(filename, bn, missing_symbols) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    /// copy constructor
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) : genericBNLearner(src) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    /// move constructor
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) :
        genericBNLearner(std::move(src)) {
      GUM_CONSTRUCTOR(BNLearner);
    }

    /// destructor
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::~BNLearner() {
      GUM_DESTRUCTOR(BNLearner);
    }

    /// @}

    // ##########################################################################
    /// @name Operators
    // ##########################################################################
    /// @{

    /// copy operator
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >&
       BNLearner< GUM_SCALAR >::operator=(const BNLearner< GUM_SCALAR >& src) {
      genericBNLearner::operator=(src);
      return *this;
    }

    /// move operator
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) {
      genericBNLearner::operator=(std::move(src));
      return *this;
    }

    /// learn a Bayes Net from a file
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
      // create the score, the apriori and the estimator
      auto notification = checkScoreAprioriCompatibility();
      if (notification != "") { std::cout << "[aGrUM notification] " << notification << std::endl; }
      createApriori_();
      createScore_();

      std::unique_ptr< ParamEstimator<> > param_estimator(
         createParamEstimator_(scoreDatabase_.parser(), true));

      return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
    }

    /// learns a BN (its parameters) when its structure is known
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(const DAG& dag,
                                                                    bool takeIntoAccountScore) {
      // if the dag contains no node, return an empty BN
      if (dag.size() == 0) return BayesNet< GUM_SCALAR >();

      // check that the dag corresponds to the database
      std::vector< NodeId > ids;
      ids.reserve(dag.sizeNodes());
      for (const auto node: dag)
        ids.push_back(node);
      std::sort(ids.begin(), ids.end());

      if (ids.back() >= scoreDatabase_.names().size()) {
        std::stringstream str;
        str << "Learning parameters corresponding to the dag is impossible "
            << "because the database does not contain the following nodeID";
        std::vector< NodeId > bad_ids;
        for (const auto node: ids) {
          if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
        }
        if (bad_ids.size() > 1) str << 's';
        str << ": ";
        bool deja = false;
        for (const auto node: bad_ids) {
          if (deja)
            str << ", ";
          else
            deja = true;
          str << node;
        }
        GUM_ERROR(MissingVariableInDatabase, str.str())
      }

      // create the apriori
      createApriori_();

      if (epsilonEM_ == 0.0) {
        // check that the database does not contain any missing value
        if (scoreDatabase_.databaseTable().hasMissingValues()
            || ((aprioriDatabase_ != nullptr)
                && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE)
                && aprioriDatabase_->databaseTable().hasMissingValues())) {
          GUM_ERROR(MissingValueInDatabase,
                    "In general, the BNLearner is unable to cope with "
                       << "missing values in databases. To learn parameters in "
                       << "such situations, you should first use method "
                       << "useEM()");
        }

        // create the usual estimator
        DBRowGeneratorParser<>              parser(scoreDatabase_.databaseTable().handler(),
                                                   DBRowGeneratorSet<>());
        std::unique_ptr< ParamEstimator<> > param_estimator(
           createParamEstimator_(parser, takeIntoAccountScore));

        return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
      } else {
        // EM !
        BNLearnerListener listener(this, Dag2BN_);

        // get the column types
        const auto&       database = scoreDatabase_.databaseTable();
        const std::size_t nb_vars  = database.nbVariables();
        const std::vector< gum::learning::DBTranslatedValueType > col_types(
           nb_vars,
           gum::learning::DBTranslatedValueType::DISCRETE);

        // create the bootstrap estimator
        DBRowGenerator4CompleteRows<> generator_bootstrap(col_types);
        DBRowGeneratorSet<>           genset_bootstrap;
        genset_bootstrap.insertGenerator(generator_bootstrap);
        DBRowGeneratorParser<>              parser_bootstrap(database.handler(), genset_bootstrap);
        std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap(
           createParamEstimator_(parser_bootstrap, takeIntoAccountScore));

        // create the EM estimator
        BayesNet< GUM_SCALAR >         dummy_bn;
        DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
        DBRowGenerator<>&              gen_EM = generator_EM;   // fix for g++-4.8
        DBRowGeneratorSet<>            genset_EM;
        genset_EM.insertGenerator(gen_EM);
        DBRowGeneratorParser<>              parser_EM(database.handler(), genset_EM);
        std::unique_ptr< ParamEstimator<> > param_estimator_EM(
           createParamEstimator_(parser_EM, takeIntoAccountScore));

        Dag2BN_.setEpsilon(epsilonEM_);
        return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()),
                                              *(param_estimator_EM.get()),
                                              dag);
      }
    }


    /// learns a BN (its parameters) when its structure is known
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
      return learnParameters(initialDag_, take_into_account_score);
    }
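
    // Illustrative sketch (not part of the library) of how the EM branch above is
    // typically reached: 'dag' stands for a known structure and "data_with_holes.csv"
    // is a hypothetical CSV file containing missing values.
    //
    //   gum::learning::BNLearner< double > learner("data_with_holes.csv");
    //   learner.useEM(1e-4);   // a non-zero epsilon enables EM (see epsilonEM_ above)
    //   const auto bn = learner.learnParameters(dag, true);
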
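    /// reads the header of the CSV file and returns, for each column whose name
    /// matches a variable of src, the sequence of labels of that variable
    /// (keyed by the column index)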
    template < typename GUM_SCALAR >
    NodeProperty< Sequence< std::string > >
       BNLearner< GUM_SCALAR >::_labelsFromBN_(const std::string&            filename,
                                               const BayesNet< GUM_SCALAR >& src) {
      std::ifstream in(filename, std::ifstream::in);

      if ((in.rdstate() & std::ifstream::failbit) != 0) {
        GUM_ERROR(gum::IOError, "File " << filename << " not found")
      }

      CSVParser<> parser(in, filename);
      parser.next();
      auto names = parser.current();

      NodeProperty< Sequence< std::string > > modals;

      for (gum::Idx col = 0; col < names.size(); col++) {
        try {
          gum::NodeId graphId = src.idFromName(names[col]);
          modals.insert(col, gum::Sequence< std::string >());

          for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
            modals[col].insert(src.variable(graphId).label(i));
        } catch (const gum::NotFound&) {
          // not a problem: this column does not correspond to a variable of the BN
        }
      }

      return modals;
    }

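    /// returns a pretty-printed description of the current state of the learner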
    template < typename GUM_SCALAR >
    std::string BNLearner< GUM_SCALAR >::toString() const {
      const auto st = state();

      Size maxkey = 0;
      for (const auto& tuple: st)
        if (std::get< 0 >(tuple).length() > maxkey) maxkey = std::get< 0 >(tuple).length();

      std::stringstream s;
      for (const auto& tuple: st) {
        s << std::setiosflags(std::ios::left) << std::setw(maxkey) << std::get< 0 >(tuple) << " : "
          << std::get< 1 >(tuple);
        if (std::get< 2 >(tuple) != "") s << "  (" << std::get< 2 >(tuple) << ")";
        s << std::endl;
      }
      return s.str();
    }

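    /// returns the current state of the learner as (key, value, comment) triples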
    template < typename GUM_SCALAR >
    std::vector< std::tuple< std::string, std::string, std::string > >
       BNLearner< GUM_SCALAR >::state() const {
      std::vector< std::tuple< std::string, std::string, std::string > > vals;

      std::string key;
      std::string comment;
      const auto& db = database();

      vals.emplace_back("Filename", filename_, "");
      vals.emplace_back("Size",
                        "(" + std::to_string(nbRows()) + "," + std::to_string(nbCols()) + ")",
                        "");

      std::string vars = "";
      for (NodeId i = 0; i < db.nbVariables(); i++) {
        if (i > 0) vars += ", ";
        vars += nameFromId(i) + "[" + std::to_string(db.domainSize(i)) + "]";
      }
      vals.emplace_back("Variables", vars, "");
      vals.emplace_back("Induced types", inducedTypes_ ? "True" : "False", "");
      vals.emplace_back("Missing values", hasMissingValues() ? "True" : "False", "");

      key = "Algorithm";
      switch (selectedAlgo_) {
        case AlgoType::GREEDY_HILL_CLIMBING:
          vals.emplace_back(key, "Greedy Hill Climbing", "");
          break;
        case AlgoType::K2: {
          vals.emplace_back(key, "K2", "");
          const auto& k2order = algoK2_.order();
          vars                = "";
          for (NodeId i = 0; i < k2order.size(); i++) {
            if (i > 0) vars += ", ";
            vars += nameFromId(k2order.atPos(i));
          }
          vals.emplace_back("K2 order", vars, "");
        } break;
        case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST:
          vals.emplace_back(key, "Local Search with Tabu List", "");
          vals.emplace_back("Tabu list size", std::to_string(nbDecreasingChanges_), "");
          break;
        case AlgoType::THREE_OFF_TWO:
          vals.emplace_back(key, "3off2", "");
          break;
        case AlgoType::MIIC:
          vals.emplace_back(key, "MIIC", "");
          break;
        default:
          vals.emplace_back(key, "(unknown)", "?");
          break;
      }

      if (selectedAlgo_ != AlgoType::MIIC && selectedAlgo_ != AlgoType::THREE_OFF_TWO) {
        key = "Score";
        switch (scoreType_) {
          case ScoreType::K2:
            vals.emplace_back(key, "K2", "");
            break;
          case ScoreType::AIC:
            vals.emplace_back(key, "AIC", "");
            break;
          case ScoreType::BIC:
            vals.emplace_back(key, "BIC", "");
            break;
          case ScoreType::BD:
            vals.emplace_back(key, "BD", "");
            break;
          case ScoreType::BDeu:
            vals.emplace_back(key, "BDeu", "");
            break;
          case ScoreType::LOG2LIKELIHOOD:
            vals.emplace_back(key, "Log2Likelihood", "");
            break;
          default:
            vals.emplace_back(key, "(unknown)", "?");
            break;
        }
      } else {
        key = "Correction";
        switch (kmode3Off2_) {
          case CorrectedMutualInformation<>::KModeTypes::MDL:
            vals.emplace_back(key, "MDL", "");
            break;
          case CorrectedMutualInformation<>::KModeTypes::NML:
            vals.emplace_back(key, "NML", "");
            break;
          case CorrectedMutualInformation<>::KModeTypes::NoCorr:
            vals.emplace_back(key, "No correction", "");
            break;
          default:
            vals.emplace_back(key, "(unknown)", "?");
            break;
        }
      }


      key     = "Prior";
      comment = checkScoreAprioriCompatibility();
      switch (aprioriType_) {
        case AprioriType::NO_APRIORI:
          vals.emplace_back(key, "-", comment);
          break;
        case AprioriType::DIRICHLET_FROM_DATABASE:
          vals.emplace_back(key, "Dirichlet", comment);
          vals.emplace_back("Dirichlet database", aprioriDbname_, "");
          break;
        case AprioriType::BDEU:
          vals.emplace_back(key, "BDEU", comment);
          break;
        case AprioriType::SMOOTHING:
          vals.emplace_back(key, "Smoothing", comment);
          break;
        default:
          vals.emplace_back(key, "(unknown)", "?");
          break;
      }

      if (aprioriType_ != AprioriType::NO_APRIORI)
        vals.emplace_back("Prior weight", std::to_string(aprioriWeight_), "");

      if (databaseWeight() != double(nbRows())) {
        vals.emplace_back("Database weight", std::to_string(databaseWeight()), "");
      }

      if (epsilonEM_ > 0.0) {
        comment = "";
        if (!hasMissingValues()) comment = "But no missing values in this database";
        vals.emplace_back("EM", "True", "");
        vals.emplace_back("EM epsilon", std::to_string(epsilonEM_), comment);
      }

      std::string res;
      bool        nofirst;
      if (constraintIndegree_.maxIndegree() < std::numeric_limits< Size >::max()) {
        vals.emplace_back("Constraint Max InDegree",
                          std::to_string(constraintIndegree_.maxIndegree()),
                          "Used only for score-based algorithms.");
      }
      if (!constraintForbiddenArcs_.arcs().empty()) {
        res     = "{";
        nofirst = false;
        for (const auto& arc: constraintForbiddenArcs_.arcs()) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head());
        }
        res += "}";
        vals.emplace_back("Constraint Forbidden Arcs", res, "");
      }
      if (!constraintMandatoryArcs_.arcs().empty()) {
        res     = "{";
        nofirst = false;
        for (const auto& arc: constraintMandatoryArcs_.arcs()) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head());
        }
        res += "}";
        vals.emplace_back("Constraint Mandatory Arcs", res, "");
      }
      if (!constraintPossibleEdges_.edges().empty()) {
        res     = "{";
        nofirst = false;
        for (const auto& edge: constraintPossibleEdges_.edges()) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(edge.first()) + "--" + nameFromId(edge.second());
        }
        res += "}";
        vals.emplace_back("Constraint Possible Edges",
                          res,
                          "Used only for score-based algorithms.");
      }
      if (!constraintSliceOrder_.sliceOrder().empty()) {
        res               = "{";
        nofirst           = false;
        const auto& order = constraintSliceOrder_.sliceOrder();
        for (const auto& p: order) {
          if (nofirst)
            res += ", ";
          else
            nofirst = true;
          res += nameFromId(p.first) + ":" + std::to_string(p.second);
        }
        res += "}";
        vals.emplace_back("Constraint Slice Order", res, "Used only for score-based algorithms.");
      }
      if (initialDag_.size() != 0) {
        vals.emplace_back("Initial DAG", "True", initialDag_.toDot());
      }

      return vals;
    }

    template < typename GUM_SCALAR >
    INLINE std::ostream& operator<<(std::ostream& output, const BNLearner< GUM_SCALAR >& learner) {
      output << learner.toString();
      return output;
    }

  } /* namespace learning */

} /* namespace gum */

#endif /* DOXYGEN_SHOULD_SKIP_THIS */