/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe
 *   GONZALES(_at_AMU) info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */


/**
 * @file
 * @brief The 3off2 algorithm
 *
 * The Miic class implements the 3off2 algorithm as proposed by Affeldt et
 * al. in https://doi.org/10.1186/s12859-015-0856-x.
 * It starts by eliminating edges that correspond to independent variables to
 * build the skeleton of the graph, and then directs the remaining edges to get an
 * essential graph. Latent variables can be detected using bi-directed arcs.
 *
 * The variant MIIC is also implemented based on
 * https://doi.org/10.1371/journal.pcbi.1005662. Only the orientation phase differs
 * from 3off2, with a different ranking method and different propagation rules.
35 * 36 * @author Quentin FALCAND and Pierre-Henri WUILLEMIN(_at_LIP6) and Maria Virginia 37 * RUIZ CUEVAS 38 */ 39 #ifndef GUM_LEARNING_3_OFF_2_H 40 #define GUM_LEARNING_3_OFF_2_H 41 42 #include <string> 43 #include <vector> 44 45 #include <agrum/BN/BayesNet.h> 46 #include <agrum/config.h> 47 #include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h> 48 #include <agrum/tools/core/approximations/approximationScheme.h> 49 #include <agrum/tools/core/heap.h> 50 #include <agrum/tools/graphs/DAG.h> 51 #include <agrum/tools/graphs/mixedGraph.h> 52 #include <agrum/tools/stattests/correctedMutualInformation.h> 53 54 namespace gum { 55 56 namespace learning { 57 using CondThreePoints = std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >; 58 using CondRanking = std::pair< CondThreePoints*, double >; 59 60 using ThreePoints = std::tuple< NodeId, NodeId, NodeId >; 61 using Ranking = std::pair< ThreePoints*, double >; 62 63 using ProbabilisticRanking = std::tuple< ThreePoints*, double, double, double >; 64 65 class GreaterPairOn2nd { 66 public: 67 bool operator()(const CondRanking& e1, const CondRanking& e2) const; 68 }; 69 70 class GreaterAbsPairOn2nd { 71 public: 72 bool operator()(const Ranking& e1, const Ranking& e2) const; 73 }; 74 75 class GreaterTupleOnLast { 76 public: 77 bool operator()(const ProbabilisticRanking& e1, const ProbabilisticRanking& e2) const; 78 }; 79 80 /** 81 * @class Miic 82 * @brief The miic learning algorithm 83 * 84 * The miic class implements the miic algorithm based on 85 * https://doi.org/10.1371/journal.pcbi.1005662. 86 * It starts by eliminating edges that correspond to independent variables to 87 * build the skeleton of the graph, and then directs the remaining edges to get 88 * an 89 * essential graph. Latent variables can be detected using bi-directed arcs. 90 * 91 * The variant 3off2 is also implemented as proposed by Affeldt and 92 * al. in https://doi.org/10.1186/s12859-015-0856-x. 
Only the orientation 93 * phase differs from miic, with a different ranking method and different 94 * propagation rules. 95 * 96 * @ingroup learning_group 97 */ 98 class Miic: public ApproximationScheme { 99 public: 100 // ########################################################################## 101 /// @name Constructors / Destructors 102 // ########################################################################## 103 /// @{ 104 105 /// default constructor 106 Miic(); 107 108 /// default constructor with maxLog 109 explicit Miic(int maxLog); 110 111 /// copy constructor 112 Miic(const Miic& from); 113 114 /// move constructor 115 Miic(Miic&& from); 116 117 /// destructor 118 ~Miic() override; 119 120 /// @} 121 122 /// copy operator 123 Miic& operator=(const Miic& from); 124 125 /// move operator 126 Miic& operator=(Miic&& from); 127 128 // ########################################################################## 129 /// @name Accessors / Modifiers 130 // ########################################################################## 131 /// @{ 132 133 134 /// learns the structure of an Essential Graph 135 /** @param mutualInformation A mutual information instance that will do the 136 * computations and has loaded the database. 137 * @param graph the MixedGraph we start from for the learning 138 * */ 139 MixedGraph learnMixedStructure(CorrectedMutualInformation<>& mutualInformation, 140 MixedGraph graph); 141 142 /// learns the structure of a Bayesian network, i.e. a DAG, by first learning 143 /// an Essential graph and then directing the remaining edges. 
144 /** @param I A mutual information instance that will do the computations 145 * and has loaded the database 146 * @param graph the MixedGraph we start from for the learning 147 */ 148 DAG learnStructure(CorrectedMutualInformation<>& I, MixedGraph graph); 149 150 /// learns the structure and the parameters of a BN 151 /** @param selector A selector class that computes the best changes that 152 * can be applied and that enables the user to get them very easily. 153 * Typically, the selector is a GraphChangesSelector4DiGraph<SCORE, 154 * STRUCT_CONSTRAINT, GRAPH_CHANGES_GENERATOR>. 155 * @param estimator A estimator. 156 * @param names The variables names. 157 * @param modal the domain sizes of the random variables observed in the 158 * database 159 * @param translator The cell translator to use. 160 * @param initial_dag the DAG we start from for our learning */ 161 template < typename GUM_SCALAR = double, 162 typename GRAPH_CHANGES_SELECTOR, 163 typename PARAM_ESTIMATOR > 164 BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR& selector, 165 PARAM_ESTIMATOR& estimator, 166 DAG initial_dag = DAG()); 167 168 /// get the list of arcs hiding latent variables 169 const std::vector< Arc > latentVariables() const; 170 171 /// Sets the orientation phase to follow the one of the MIIC algorithm 172 void setMiicBehaviour(); 173 174 /// Sets the orientation phase to follow the one of the 3off2 algorithm 175 void set3of2Behaviour(); 176 177 /// Set a ensemble of constraints for the orientation phase 178 void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints); 179 /// @} 180 181 protected: 182 // ########################################################################## 183 /// @name Main phases 184 // ########################################################################## 185 /// @{ 186 187 /// Initiation phase 188 /** 189 * We go over all edges and test if the variables are marginally independent. 190 * If they are, the edge is deleted. 
If not, the best contributor is found. 191 * 192 * @param mutualInformation A mutual information instance that will do the 193 * computations and has loaded the database. 194 * @param graph the MixedGraph we start from for the learning 195 * @param sepSet the separation set for independent couples, here set to {} 196 * @param rank the heap of ranks of the algorithm 197 */ 198 void initiation_(CorrectedMutualInformation<>& mutualInformation, 199 MixedGraph& graph, 200 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet, 201 Heap< CondRanking, GreaterPairOn2nd >& rank); 202 203 /// Iteration phase 204 /** 205 * As long as we find important nodes for edges, we go over them to see if 206 * we can assess the conditional independence of the variables. 207 * 208 * @param mutualInformation A mutual information instance that will do the 209 * computations and has loaded the database. 210 * @param graph the MixedGraph returned from the previous phase 211 * @param sepSet the separation set for independent couples, built during 212 * the iterations of the phase 213 * @param rank the heap of ranks of the algorithm 214 */ 215 void iteration_(CorrectedMutualInformation<>& mutualInformation, 216 MixedGraph& graph, 217 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet, 218 Heap< CondRanking, GreaterPairOn2nd >& rank); 219 220 /// Orientation phase from the 3off2 algorithm, returns a CPDAG 221 /** @param mutualInformation A mutual information instance that will do the 222 * computations and has loaded the database. 
223 * @param graph the MixedGraph returned from the previous phase 224 * @param sepSet the separation set for independent couples, built during 225 * the previous phase 226 */ 227 void orientation3off2_( 228 CorrectedMutualInformation<>& mutualInformation, 229 MixedGraph& graph, 230 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet); 231 232 /// Modified version of the orientation phase that tries to propagate 233 /// orientations from both orientations in case of a bidirected arc, not used 234 /** @param mutualInformation A mutual information instance that will do the 235 * computations and has loaded the database. 236 * @param graph the MixedGraph returned from the previous phase 237 * @param sepSet the separation set for independent couples, built during 238 * the previous phase 239 */ 240 void orientationLatents_( 241 CorrectedMutualInformation<>& mutualInformation, 242 MixedGraph& graph, 243 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet); 244 245 /// Orientation phase from the MIIC algorithm, returns a mixed graph that 246 /// may contain circles 247 /** @param mutualInformation A mutual information instance that will do the 248 * computations and has loaded the database. 249 * @param graph the MixedGraph returned from the previous phase 250 * @param sepSet the separation set for independent couples, built during 251 * the previous phase 252 */ 253 void orientationMiic_( 254 CorrectedMutualInformation<>& mutualInformation, 255 MixedGraph& graph, 256 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet); 257 /// @} 258 259 /// finds the best contributor node for a pair given a conditioning set 260 /**@param x first node 261 * @param y second node 262 * @param ui conditioning set 263 * @param mutualInformation A mutual information instance that will do the 264 * computations and has loaded the database. 
265 * @param graph containing the assessed nodes 266 * @param rank the heap of ranks of the algorithm 267 */ 268 void findBestContributor_(NodeId x, 269 NodeId y, 270 const std::vector< NodeId >& ui, 271 const MixedGraph& graph, 272 CorrectedMutualInformation<>& mutualInformation, 273 Heap< CondRanking, GreaterPairOn2nd >& rank); 274 275 /// gets the list of unshielded triples in the graph in decreasing value of 276 ///|I'(x, y, z|{ui})| 277 /*@param graph graph in which to find the triples 278 *@param I mutual information object to compute the scores 279 *@param sep_set hashtable storing the separation sets for pairs of variables 280 */ 281 std::vector< Ranking > unshieldedTriples_( 282 const MixedGraph& graph, 283 CorrectedMutualInformation<>& mutualInformation, 284 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet); 285 286 /// gets the list of unshielded triples in the graph in decreasing value of 287 ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC 288 /*@param graph graph in which to find the triples 289 *@param I mutual information object to compute the scores 290 *@param sep_set hashtable storing the separation sets for pairs of variables 291 * @param marks hashtable containing the orientation marks for edges 292 */ 293 std::vector< ProbabilisticRanking > unshieldedTriplesMiic_( 294 const MixedGraph& graph, 295 CorrectedMutualInformation<>& mutualInformation, 296 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet, 297 HashTable< std::pair< NodeId, NodeId >, char >& marks); 298 299 /// Gets the orientation probabilities like MIIC for the orientation phase 300 /*@param graph graph in which to find the triples 301 *@param proba_triples probabilities for the different triples to update 302 */ 303 std::vector< ProbabilisticRanking > 304 updateProbaTriples_(const MixedGraph& graph, 305 std::vector< ProbabilisticRanking > probaTriples); 306 307 /// Propagates the orientation from a node to its 
neighbours 308 /*@param dag graph in which to which to propagate arcs 309 *@param node node on which neighbours to propagate th orientation 310 *@param force : true if an orientation has always to be found. 311 */ 312 bool propagatesRemainingOrientableEdges_(MixedGraph& graph, NodeId xj); 313 314 /// heuristic for remaining edges when everything else has been tried 315 void propagatesOrientationInChainOfRemainingEdges_(MixedGraph& graph); 316 317 protected: 318 bool isForbidenArc_(NodeId x, NodeId y) const; 319 bool isOrientable_(const MixedGraph& graph, NodeId xi, NodeId xj) const; 320 321 private: 322 /// Fixes the maximum log that we accept in exponential computations 323 int _maxLog_ = 100; 324 /// an empty conditioning set 325 const std::vector< NodeId > _emptySet_; 326 /// an empty vector of arcs 327 std::vector< Arc > _latentCouples_; 328 329 /// size of the database 330 Size _size_; 331 /// wether to use the miic algorithm or not 332 bool _useMiic_{false}; 333 334 /// Storing the propabilities for each arc set in the graph 335 ArcProperty< double > _arcProbas_; 336 337 /// Initial marks for the orientation phase, used to convey constraints 338 HashTable< std::pair< NodeId, NodeId >, char > _initialMarks_; 339 340 /** @brief checks for directed paths in a graph, considering double arcs like 341 * edges, not considering arc as a directed path. 
342 * @param graph MixedGraph in which to search the path 343 * @param n1 tail of the path 344 * @param n2 head of the path 345 * @param countArc bool to know if we 346 */ 347 static bool _existsNonTrivialDirectedPath_(const MixedGraph& graph, NodeId n1, NodeId n2); 348 349 /** checks for directed paths in a graph, consider double arcs like edges 350 *@param graph MixedGraph in which to search the path 351 *@param n1 tail of the path 352 *@param n2 head of the path 353 */ 354 static bool _existsDirectedPath_(const MixedGraph& graph, NodeId n1, NodeId n2); 355 356 void _orientingVstructureMiic_(MixedGraph& graph, 357 HashTable< std::pair< NodeId, NodeId >, char >& marks, 358 NodeId x, 359 NodeId y, 360 NodeId z, 361 double p1, 362 double p2); 363 364 void _propagatingOrientationMiic_(MixedGraph& graph, 365 HashTable< std::pair< NodeId, NodeId >, char >& marks, 366 NodeId x, 367 NodeId y, 368 NodeId z, 369 double p1, 370 double p2); 371 372 bool _isNotLatentCouple_(NodeId x, NodeId y); 373 }; 374 375 } /* namespace learning */ 376 377 } /* namespace gum */ 378 379 /// always include templated methods 380 //#include <agrum/BN/learning/threeOffTwo_tpl.h> 381 382 #endif /* GUM_LEARNING_3_OFF_2_H */ 383