1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** @file 23 * @brief A pack of learning algorithms that can easily be used 24 * 25 * The pack currently contains K2, GreedyHillClimbing and 26 *LocalSearchWithTabuList 27 * 28 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6) 29 */ 30 #include <fstream> 31 32 #ifndef DOXYGEN_SHOULD_SKIP_THIS 33 34 // to help IDE parser 35 # include <agrum/BN/learning/BNLearner.h> 36 37 # include <agrum/BN/learning/BNLearnUtils/BNLearnerListener.h> 38 39 namespace gum { 40 41 namespace learning { 42 template < typename GUM_SCALAR > BNLearner(const std::string & filename,const bool induceTypes,const std::vector<std::string> & missingSymbols)43 BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename, 44 const bool induceTypes, 45 const std::vector< std::string >& missingSymbols) : 46 genericBNLearner(filename, missingSymbols, induceTypes) { 47 GUM_CONSTRUCTOR(BNLearner); 48 } 49 50 template < typename GUM_SCALAR > BNLearner(const DatabaseTable<> & db)51 BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable<>& db) : genericBNLearner(db) { 52 GUM_CONSTRUCTOR(BNLearner); 53 } 54 55 template < typename GUM_SCALAR > BNLearner(const std::string & filename,const gum::BayesNet<GUM_SCALAR> & bn,const std::vector<std::string> & missing_symbols)56 BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename, 57 const gum::BayesNet< GUM_SCALAR >& bn, 58 const std::vector< std::string >& missing_symbols) : 59 genericBNLearner(filename, bn, missing_symbols) { 60 GUM_CONSTRUCTOR(BNLearner); 61 } 62 63 /// copy constructor 64 template < typename GUM_SCALAR > BNLearner(const BNLearner<GUM_SCALAR> & src)65 BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) : genericBNLearner(src) { 66 GUM_CONSTRUCTOR(BNLearner); 67 } 68 69 /// move constructor 70 template < typename GUM_SCALAR > BNLearner(BNLearner<GUM_SCALAR> && src)71 BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : genericBNLearner(src) { 72 GUM_CONSTRUCTOR(BNLearner); 73 } 74 75 /// destructor 76 template < typename GUM_SCALAR > ~BNLearner()77 BNLearner< GUM_SCALAR >::~BNLearner() { 78 GUM_DESTRUCTOR(BNLearner); 79 } 80 81 /// @} 82 83 // ########################################################################## 84 /// @name Operators 85 // ########################################################################## 86 /// @{ 87 88 /// copy operator 89 template < typename GUM_SCALAR > 90 BNLearner< GUM_SCALAR >& 91 BNLearner< GUM_SCALAR >::operator=(const BNLearner< GUM_SCALAR >& src) { 92 genericBNLearner::operator=(src); 93 return *this; 94 } 95 96 /// move operator 97 template < typename GUM_SCALAR > 98 BNLearner< GUM_SCALAR >& BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) { 99 genericBNLearner::operator=(std::move(src)); 100 return *this; 101 } 102 103 /// learn a Bayes Net from a file 104 template < typename GUM_SCALAR > learnBN()105 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() { 106 // create the score, the apriori and the estimator 107 auto notification = checkScoreAprioriCompatibility(); 108 if (notification != "") { std::cout << "[aGrUM notification] " << notification << std::endl; } 109 createApriori_(); 110 createScore_(); 111 112 std::unique_ptr< ParamEstimator<> > param_estimator( 113 createParamEstimator_(scoreDatabase_.parser(), true)); 114 115 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_()); 116 } 117 118 /// learns a BN (its parameters) when its structure is known 119 template < typename GUM_SCALAR > learnParameters(const DAG & dag,bool takeIntoAccountScore)120 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(const DAG& dag, 121 bool takeIntoAccountScore) { 122 // if the dag contains no node, return an empty BN 123 if (dag.size() == 0) return BayesNet< GUM_SCALAR >(); 124 125 // check that the dag corresponds to the database 126 std::vector< NodeId > ids; 127 ids.reserve(dag.sizeNodes()); 128 for (const auto node: dag) 129 ids.push_back(node); 130 std::sort(ids.begin(), ids.end()); 131 132 if (ids.back() >= scoreDatabase_.names().size()) { 133 std::stringstream str; 134 str << "Learning parameters corresponding to the dag is impossible " 135 << "because the database does not contain the following nodeID"; 136 std::vector< NodeId > bad_ids; 137 for (const auto node: ids) { 138 if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node); 139 } 140 if (bad_ids.size() > 1) str << 's'; 141 str << ": "; 142 bool deja = false; 143 for (const auto node: bad_ids) { 144 if (deja) 145 str << ", "; 146 else 147 deja = true; 148 str << node; 149 } 150 GUM_ERROR(MissingVariableInDatabase, str.str()) 151 } 152 153 // create the apriori 154 createApriori_(); 155 156 if (epsilonEM_ == 0.0) { 157 // check that the database does not contain any missing value 158 if (scoreDatabase_.databaseTable().hasMissingValues() 159 || ((aprioriDatabase_ != nullptr) 160 && (aprioriType_ == AprioriType::DIRICHLET_FROM_DATABASE) 161 && aprioriDatabase_->databaseTable().hasMissingValues())) { 162 GUM_ERROR(MissingValueInDatabase, 163 "In general, the BNLearner is unable to cope with " 164 << "missing values in databases. To learn parameters in " 165 << "such situations, you should first use method " 166 << "useEM()"); 167 } 168 169 // create the usual estimator 170 DBRowGeneratorParser<> parser(scoreDatabase_.databaseTable().handler(), 171 DBRowGeneratorSet<>()); 172 std::unique_ptr< ParamEstimator<> > param_estimator( 173 createParamEstimator_(parser, takeIntoAccountScore)); 174 175 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag); 176 } else { 177 // EM ! 178 BNLearnerListener listener(this, Dag2BN_); 179 180 // get the column types 181 const auto& database = scoreDatabase_.databaseTable(); 182 const std::size_t nb_vars = database.nbVariables(); 183 const std::vector< gum::learning::DBTranslatedValueType > col_types( 184 nb_vars, 185 gum::learning::DBTranslatedValueType::DISCRETE); 186 187 // create the bootstrap estimator 188 DBRowGenerator4CompleteRows<> generator_bootstrap(col_types); 189 DBRowGeneratorSet<> genset_bootstrap; 190 genset_bootstrap.insertGenerator(generator_bootstrap); 191 DBRowGeneratorParser<> parser_bootstrap(database.handler(), genset_bootstrap); 192 std::unique_ptr< ParamEstimator<> > param_estimator_bootstrap( 193 createParamEstimator_(parser_bootstrap, takeIntoAccountScore)); 194 195 // create the EM estimator 196 BayesNet< GUM_SCALAR > dummy_bn; 197 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn); 198 DBRowGenerator<>& gen_EM = generator_EM; // fix for g++-4.8 199 DBRowGeneratorSet<> genset_EM; 200 genset_EM.insertGenerator(gen_EM); 201 DBRowGeneratorParser<> parser_EM(database.handler(), genset_EM); 202 std::unique_ptr< ParamEstimator<> > param_estimator_EM( 203 createParamEstimator_(parser_EM, takeIntoAccountScore)); 204 205 Dag2BN_.setEpsilon(epsilonEM_); 206 return Dag2BN_.createBN< GUM_SCALAR >(*(param_estimator_bootstrap.get()), 207 *(param_estimator_EM.get()), 208 dag); 209 } 210 } 211 212 213 /// learns a BN (its parameters) when its structure is known 214 template < typename GUM_SCALAR > learnParameters(bool take_into_account_score)215 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) { 216 return learnParameters(initialDag_, take_into_account_score); 217 } 218 219 220 template < typename GUM_SCALAR > 221 NodeProperty< Sequence< std::string > > _labelsFromBN_(const std::string & filename,const BayesNet<GUM_SCALAR> & src)222 BNLearner< GUM_SCALAR >::_labelsFromBN_(const std::string& filename, 223 const BayesNet< GUM_SCALAR >& src) { 224 std::ifstream in(filename, std::ifstream::in); 225 226 if ((in.rdstate() & std::ifstream::failbit) != 0) { 227 GUM_ERROR(gum::IOError, "File " << filename << " not found") 228 } 229 230 CSVParser<> parser(in, filename); 231 parser.next(); 232 auto names = parser.current(); 233 234 NodeProperty< Sequence< std::string > > modals; 235 236 for (gum::Idx col = 0; col < names.size(); col++) { 237 try { 238 gum::NodeId graphId = src.idFromName(names[col]); 239 modals.insert(col, gum::Sequence< std::string >()); 240 241 for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i) 242 modals[col].insert(src.variable(graphId).label(i)); 243 } catch (const gum::NotFound&) { 244 // no problem : a column not in the BN... 245 } 246 } 247 248 return modals; 249 } 250 251 252 template < typename GUM_SCALAR > toString()253 std::string BNLearner< GUM_SCALAR >::toString() const { 254 const auto st = state(); 255 256 Size maxkey = 0; 257 for (const auto& tuple: st) 258 if (std::get< 0 >(tuple).length() > maxkey) maxkey = std::get< 0 >(tuple).length(); 259 260 std::stringstream s; 261 for (const auto& tuple: st) { 262 s << std::setiosflags(std::ios::left) << std::setw(maxkey) << std::get< 0 >(tuple) << " : " 263 << std::get< 1 >(tuple); 264 if (std::get< 2 >(tuple) != "") s << " (" << std::get< 2 >(tuple) << ")"; 265 s << std::endl; 266 } 267 return s.str(); 268 } 269 270 template < typename GUM_SCALAR > 271 std::vector< std::tuple< std::string, std::string, std::string > > state()272 BNLearner< GUM_SCALAR >::state() const { 273 std::vector< std::tuple< std::string, std::string, std::string > > vals; 274 275 std::string key; 276 std::string comment; 277 const auto& db = database(); 278 279 vals.emplace_back("Filename", filename_, ""); 280 vals.emplace_back("Size", 281 "(" + std::to_string(nbRows()) + "," + std::to_string(nbCols()) + ")", 282 ""); 283 284 std::string vars = ""; 285 for (NodeId i = 0; i < db.nbVariables(); i++) { 286 if (i > 0) vars += ", "; 287 vars += nameFromId(i) + "[" + std::to_string(db.domainSize(i)) + "]"; 288 } 289 vals.emplace_back("Variables", vars, ""); 290 vals.emplace_back("Induced types", inducedTypes_ ? "True" : "False", ""); 291 vals.emplace_back("Missing values", hasMissingValues() ? "True" : "False", ""); 292 293 key = "Algorithm"; 294 switch (selectedAlgo_) { 295 case AlgoType::GREEDY_HILL_CLIMBING: 296 vals.emplace_back(key, "Greedy Hill Climbing", ""); 297 break; 298 case AlgoType::K2: { 299 vals.emplace_back(key, "K2", ""); 300 const auto& k2order = algoK2_.order(); 301 vars = ""; 302 for (NodeId i = 0; i < k2order.size(); i++) { 303 if (i > 0) vars += ", "; 304 vars += nameFromId(k2order.atPos(i)); 305 } 306 vals.emplace_back("K2 order", vars, ""); 307 } break; 308 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST: 309 vals.emplace_back(key, "Local Search with Tabu List", ""); 310 vals.emplace_back("Tabu list size", std::to_string(nbDecreasingChanges_), ""); 311 break; 312 case AlgoType::THREE_OFF_TWO: 313 vals.emplace_back(key, "3off2", ""); 314 break; 315 case AlgoType::MIIC: 316 vals.emplace_back(key, "MIIC", ""); 317 break; 318 default: 319 vals.emplace_back(key, "(unknown)", "?"); 320 break; 321 } 322 323 if (selectedAlgo_ != AlgoType::MIIC && selectedAlgo_ != AlgoType::THREE_OFF_TWO) { 324 key = "Score"; 325 switch (scoreType_) { 326 case ScoreType::K2: 327 vals.emplace_back(key, "K2", ""); 328 break; 329 case ScoreType::AIC: 330 vals.emplace_back(key, "AIC", ""); 331 break; 332 case ScoreType::BIC: 333 vals.emplace_back(key, "BIC", ""); 334 break; 335 case ScoreType::BD: 336 vals.emplace_back(key, "BD", ""); 337 break; 338 case ScoreType::BDeu: 339 vals.emplace_back(key, "BDeu", ""); 340 break; 341 case ScoreType::LOG2LIKELIHOOD: 342 vals.emplace_back(key, "Log2Likelihood", ""); 343 break; 344 default: 345 vals.emplace_back(key, "(unknown)", "?"); 346 break; 347 } 348 } else { 349 key = "Correction"; 350 switch (kmode3Off2_) { 351 case CorrectedMutualInformation<>::KModeTypes::MDL: 352 vals.emplace_back(key, "MDL", ""); 353 break; 354 case CorrectedMutualInformation<>::KModeTypes::NML: 355 vals.emplace_back(key, "NML", ""); 356 break; 357 case CorrectedMutualInformation<>::KModeTypes::NoCorr: 358 vals.emplace_back(key, "No correction", ""); 359 break; 360 default: 361 vals.emplace_back(key, "(unknown)", "?"); 362 break; 363 } 364 } 365 366 367 key = "Prior"; 368 comment = checkScoreAprioriCompatibility(); 369 switch (aprioriType_) { 370 case AprioriType::NO_APRIORI: 371 vals.emplace_back(key, "-", comment); 372 break; 373 case AprioriType::DIRICHLET_FROM_DATABASE: 374 vals.emplace_back(key, "Dirichlet", comment); 375 vals.emplace_back("Dirichlet database", aprioriDbname_, ""); 376 break; 377 case AprioriType::BDEU: 378 vals.emplace_back(key, "BDEU", comment); 379 break; 380 case AprioriType::SMOOTHING: 381 vals.emplace_back(key, "Smoothing", comment); 382 break; 383 default: 384 vals.emplace_back(key, "(unknown)", "?"); 385 break; 386 } 387 388 if (aprioriType_ != AprioriType::NO_APRIORI) 389 vals.emplace_back("Prior weight", std::to_string(aprioriWeight_), ""); 390 391 if (databaseWeight() != double(nbRows())) { 392 vals.emplace_back("Database weight", std::to_string(databaseWeight()), ""); 393 } 394 395 if (epsilonEM_ > 0.0) { 396 comment = ""; 397 if (!hasMissingValues()) comment = "But no missing values in this database"; 398 vals.emplace_back("EM", "True", ""); 399 vals.emplace_back("EM epsilon", std::to_string(epsilonEM_), comment); 400 } 401 402 std::string res; 403 bool nofirst; 404 if (constraintIndegree_.maxIndegree() < std::numeric_limits< Size >::max()) { 405 vals.emplace_back("Constraint Max InDegree", 406 std::to_string(constraintIndegree_.maxIndegree()), 407 "Used only for score-based algorithms."); 408 } 409 if (!constraintForbiddenArcs_.arcs().empty()) { 410 res = "{"; 411 nofirst = false; 412 for (const auto& arc: constraintForbiddenArcs_.arcs()) { 413 if (nofirst) 414 res += ", "; 415 else 416 nofirst = true; 417 res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head()); 418 } 419 res += "}"; 420 vals.emplace_back("Constraint Forbidden Arcs", res, ""); 421 } 422 if (!constraintMandatoryArcs_.arcs().empty()) { 423 res = "{"; 424 nofirst = false; 425 for (const auto& arc: constraintMandatoryArcs_.arcs()) { 426 if (nofirst) 427 res += ", "; 428 else 429 nofirst = true; 430 res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head()); 431 } 432 res += "}"; 433 vals.emplace_back("Constraint Mandatory Arcs", res, ""); 434 } 435 if (!constraintPossibleEdges_.edges().empty()) { 436 res = "{"; 437 nofirst = false; 438 for (const auto& edge: constraintPossibleEdges_.edges()) { 439 if (nofirst) 440 res += ", "; 441 else 442 nofirst = true; 443 res += nameFromId(edge.first()) + "--" + nameFromId(edge.second()); 444 } 445 res += "}"; 446 vals.emplace_back("Constraint Possible Edges", 447 res, 448 "Used only for score-based algorithms."); 449 } 450 if (!constraintSliceOrder_.sliceOrder().empty()) { 451 res = "{"; 452 nofirst = false; 453 const auto& order = constraintSliceOrder_.sliceOrder(); 454 for (const auto& p: order) { 455 if (nofirst) 456 res += ", "; 457 else 458 nofirst = true; 459 res += nameFromId(p.first) + ":" + std::to_string(p.second); 460 } 461 res += "}"; 462 vals.emplace_back("Constraint Slice Order", res, "Used only for score-based algorithms."); 463 } 464 if (initialDag_.size() != 0) { 465 vals.emplace_back("Initial DAG", "True", initialDag_.toDot()); 466 } 467 468 return vals; 469 } 470 471 template < typename GUM_SCALAR > 472 INLINE std::ostream& operator<<(std::ostream& output, const BNLearner< GUM_SCALAR >& learner) { 473 output << learner.toString(); 474 return output; 475 } 476 477 } /* namespace learning */ 478 479 } /* namespace gum */ 480 481 #endif /* DOXYGEN_SHOULD_SKIP_THIS */ 482