1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe
4  * GONZALES(_at_AMU) info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief The 3off2 algorithm
25  *
 * The Miic class implements the 3off2 algorithm as proposed by Affeldt et al.
 * in https://doi.org/10.1186/s12859-015-0856-x.
28  * It starts by eliminating edges that correspond to independent variables to
29  * build the skeleton of the graph, and then directs the remaining edges to get an
30  * essential graph. Latent variables can be detected using bi-directed arcs.
31  *
 * The variant MIIC is also implemented based on
 * https://doi.org/10.1371/journal.pcbi.1005662. Only the orientation phase differs
 * from 3off2, with a different ranking method and different propagation rules.
35  *
36  * @author Quentin FALCAND and Pierre-Henri WUILLEMIN(_at_LIP6) and Maria Virginia
37  * RUIZ CUEVAS
38  */
39 #ifndef GUM_LEARNING_3_OFF_2_H
40 #define GUM_LEARNING_3_OFF_2_H
41 
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <agrum/BN/BayesNet.h>
#include <agrum/config.h>
#include <agrum/tools/core/approximations/IApproximationSchemeConfiguration.h>
#include <agrum/tools/core/approximations/approximationScheme.h>
#include <agrum/tools/core/heap.h>
#include <agrum/tools/graphs/DAG.h>
#include <agrum/tools/graphs/mixedGraph.h>
#include <agrum/tools/stattests/correctedMutualInformation.h>
53 
54 namespace gum {
55 
56   namespace learning {
    /// Three nodes plus a vector of nodes.
    /// NOTE(review): presumably (x, y, candidate contributor z, conditioning set
    /// ui), as suggested by Miic::findBestContributor_ — confirm against the
    /// implementation.
    using CondThreePoints = std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >;
    /// A pointer to a CondThreePoints paired with a double; this is the element
    /// type of the `rank` heap handled by Miic. The ownership of the pointed
    /// tuple is not expressed by the type — see the implementation.
    using CondRanking     = std::pair< CondThreePoints*, double >;

    /// A triple of nodes (no conditioning set).
    using ThreePoints = std::tuple< NodeId, NodeId, NodeId >;
    /// A pointer to a ThreePoints paired with a double; element type of the
    /// vector returned by Miic::unshieldedTriples_.
    using Ranking     = std::pair< ThreePoints*, double >;

    /// A pointer to a ThreePoints followed by three doubles; element type of the
    /// vectors handled by Miic::unshieldedTriplesMiic_ and
    /// Miic::updateProbaTriples_.
    /// NOTE(review): the meaning of each double is not visible in this header —
    /// presumably a score followed by two orientation probabilities; confirm.
    using ProbabilisticRanking = std::tuple< ThreePoints*, double, double, double >;
64 
    /// Comparison functor over CondRanking elements. It is the comparator of the
    /// Heap< CondRanking, GreaterPairOn2nd > used as `rank` by Miic.
    /// NOTE(review): operator() is only declared here; judging by the name it
    /// presumably orders on the second member of the pair (the double) — confirm
    /// against the definition.
    class GreaterPairOn2nd {
      public:
      /// @param e1 left-hand ranking
      /// @param e2 right-hand ranking
      /// @return the result of the comparison (exact ordering defined out of line)
      bool operator()(const CondRanking& e1, const CondRanking& e2) const;
    };
69 
    /// Comparison functor over Ranking elements.
    /// NOTE(review): operator() is only declared here; judging by the name it
    /// presumably orders on the absolute value of the second member of the pair
    /// (the double) — confirm against the definition.
    class GreaterAbsPairOn2nd {
      public:
      /// @param e1 left-hand ranking
      /// @param e2 right-hand ranking
      /// @return the result of the comparison (exact ordering defined out of line)
      bool operator()(const Ranking& e1, const Ranking& e2) const;
    };
74 
    /// Comparison functor over ProbabilisticRanking elements.
    /// NOTE(review): operator() is only declared here; judging by the name it
    /// presumably orders on the last element of the tuple — confirm against the
    /// definition.
    class GreaterTupleOnLast {
      public:
      /// @param e1 left-hand ranking
      /// @param e2 right-hand ranking
      /// @return the result of the comparison (exact ordering defined out of line)
      bool operator()(const ProbabilisticRanking& e1, const ProbabilisticRanking& e2) const;
    };
79 
80     /**
81      * @class Miic
82      * @brief The miic learning algorithm
83      *
84      * The miic class implements the miic algorithm based on
85      * https://doi.org/10.1371/journal.pcbi.1005662.
86      * It starts by eliminating edges that correspond to independent variables to
87      * build the skeleton of the graph, and then directs the remaining edges to get
88      * an
89      * essential graph. Latent variables can be detected using bi-directed arcs.
90      *
91      * The variant 3off2 is also implemented as proposed by Affeldt and
92      * al. in https://doi.org/10.1186/s12859-015-0856-x.  Only the orientation
93      * phase differs from miic, with a different ranking method and different
94      * propagation rules.
95      *
96      * @ingroup learning_group
97      */
98     class Miic: public ApproximationScheme {
99       public:
100       // ##########################################################################
101       /// @name Constructors / Destructors
102       // ##########################################################################
103       /// @{
104 
105       /// default constructor
106       Miic();
107 
108       /// default constructor with maxLog
109       explicit Miic(int maxLog);
110 
111       /// copy constructor
112       Miic(const Miic& from);
113 
114       /// move constructor
115       Miic(Miic&& from);
116 
117       /// destructor
118       ~Miic() override;
119 
120       /// @}
121 
122       /// copy operator
123       Miic& operator=(const Miic& from);
124 
125       /// move operator
126       Miic& operator=(Miic&& from);
127 
128       // ##########################################################################
129       /// @name Accessors / Modifiers
130       // ##########################################################################
131       /// @{
132 
133 
134       /// learns the structure of an Essential Graph
135       /** @param mutualInformation A mutual information instance that will do the
136        * computations and has loaded the database.
137        * @param graph the MixedGraph we start from for the learning
138        * */
139       MixedGraph learnMixedStructure(CorrectedMutualInformation<>& mutualInformation,
140                                      MixedGraph                    graph);
141 
142       /// learns the structure of a Bayesian network, i.e. a DAG, by first learning
143       /// an Essential graph and then directing the remaining edges.
144       /** @param I A mutual information instance that will do the computations
145        * and has loaded the database
146        * @param graph the MixedGraph we start from for the learning
147        */
148       DAG learnStructure(CorrectedMutualInformation<>& I, MixedGraph graph);
149 
150       /// learns the structure and the parameters of a BN
151       /** @param selector A selector class that computes the best changes that
152        * can be applied and that enables the user to get them very easily.
153        * Typically, the selector is a GraphChangesSelector4DiGraph<SCORE,
154        * STRUCT_CONSTRAINT, GRAPH_CHANGES_GENERATOR>.
155        * @param estimator A estimator.
156        * @param names The variables names.
157        * @param modal the domain sizes of the random variables observed in the
158        * database
159        * @param translator The cell translator to use.
160        * @param initial_dag the DAG we start from for our learning */
161       template < typename GUM_SCALAR = double,
162                  typename GRAPH_CHANGES_SELECTOR,
163                  typename PARAM_ESTIMATOR >
164       BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR& selector,
165                                      PARAM_ESTIMATOR&        estimator,
166                                      DAG                     initial_dag = DAG());
167 
168       /// get the list of arcs hiding latent variables
169       const std::vector< Arc > latentVariables() const;
170 
171       /// Sets the orientation phase to follow the one of the MIIC algorithm
172       void setMiicBehaviour();
173 
174       /// Sets the orientation phase to follow the one of the 3off2 algorithm
175       void set3of2Behaviour();
176 
177       /// Set a ensemble of constraints for the orientation phase
178       void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints);
179       /// @}
180 
181       protected:
182       // ##########################################################################
183       /// @name Main phases
184       // ##########################################################################
185       /// @{
186 
187       /// Initiation phase
188       /**
189        * We go over all edges and test if the variables are marginally independent.
190        * If they are, the edge is deleted. If not, the best contributor is found.
191        *
192        * @param mutualInformation A mutual information instance that will do the
193        * computations and has loaded the database.
194        * @param graph the MixedGraph we start from for the learning
195        * @param sepSet the separation set for independent couples, here set to {}
196        * @param rank the heap of ranks of the algorithm
197        */
198       void initiation_(CorrectedMutualInformation<>& mutualInformation,
199                        MixedGraph&                   graph,
200                        HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
201                        Heap< CondRanking, GreaterPairOn2nd >&                           rank);
202 
203       /// Iteration phase
204       /**
205        * As long as we find important nodes for edges, we go over them to see if
206        * we can assess the conditional independence of the variables.
207        *
208        * @param mutualInformation A mutual information instance that will do the
209        * computations and has loaded the database.
210        * @param graph the MixedGraph returned from the previous phase
211        * @param sepSet the separation set for independent couples, built during
212        * the iterations of the phase
213        * @param rank the heap of ranks of the algorithm
214        */
215       void iteration_(CorrectedMutualInformation<>& mutualInformation,
216                       MixedGraph&                   graph,
217                       HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
218                       Heap< CondRanking, GreaterPairOn2nd >&                           rank);
219 
220       /// Orientation phase from the 3off2 algorithm, returns a CPDAG
221       /** @param mutualInformation A mutual information instance that will do the
222        * computations and has loaded the database.
223        * @param graph the MixedGraph returned from the previous phase
224        * @param sepSet the separation set for independent couples, built during
225        * the previous phase
226        */
227       void orientation3off2_(
228          CorrectedMutualInformation<>&                                          mutualInformation,
229          MixedGraph&                                                            graph,
230          const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet);
231 
232       /// Modified version of the orientation phase that tries to propagate
233       /// orientations from both orientations in case of a bidirected arc, not used
234       /** @param mutualInformation A mutual information instance that will do the
235        * computations and has loaded the database.
236        * @param graph the MixedGraph returned from the previous phase
237        * @param sepSet the separation set for independent couples, built during
238        * the previous phase
239        */
240       void orientationLatents_(
241          CorrectedMutualInformation<>&                                          mutualInformation,
242          MixedGraph&                                                            graph,
243          const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet);
244 
245       /// Orientation phase from the MIIC algorithm, returns a mixed graph that
246       /// may contain circles
247       /** @param mutualInformation A mutual information instance that will do the
248        * computations and has loaded the database.
249        * @param graph the MixedGraph returned from the previous phase
250        * @param sepSet the separation set for independent couples, built during
251        * the previous phase
252        */
253       void orientationMiic_(
254          CorrectedMutualInformation<>&                                          mutualInformation,
255          MixedGraph&                                                            graph,
256          const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet);
257       /// @}
258 
259       /// finds the best contributor node for a pair given a conditioning set
260       /**@param x first node
261        * @param y second node
262        * @param ui conditioning set
263        * @param mutualInformation A mutual information instance that will do the
264        * computations and has loaded the database.
265        * @param graph containing the assessed nodes
266        * @param rank the heap of ranks of the algorithm
267        */
268       void findBestContributor_(NodeId                                 x,
269                                 NodeId                                 y,
270                                 const std::vector< NodeId >&           ui,
271                                 const MixedGraph&                      graph,
272                                 CorrectedMutualInformation<>&          mutualInformation,
273                                 Heap< CondRanking, GreaterPairOn2nd >& rank);
274 
275       /// gets the list of unshielded triples in the graph in decreasing value of
276       ///|I'(x, y, z|{ui})|
277       /*@param graph graph in which to find the triples
278        *@param I mutual information object to compute the scores
279        *@param sep_set hashtable storing the separation sets for pairs of variables
280        */
281       std::vector< Ranking > unshieldedTriples_(
282          const MixedGraph&                                                      graph,
283          CorrectedMutualInformation<>&                                          mutualInformation,
284          const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet);
285 
286       /// gets the list of unshielded triples in the graph in decreasing value of
287       ///|I'(x, y, z|{ui})|, prepares the orientation matrix for MIIC
288       /*@param graph graph in which to find the triples
289        *@param I mutual information object to compute the scores
290        *@param sep_set hashtable storing the separation sets for pairs of variables
291        * @param marks hashtable containing the orientation marks for edges
292        */
293       std::vector< ProbabilisticRanking > unshieldedTriplesMiic_(
294          const MixedGraph&                                                      graph,
295          CorrectedMutualInformation<>&                                          mutualInformation,
296          const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
297          HashTable< std::pair< NodeId, NodeId >, char >&                        marks);
298 
299       /// Gets the orientation probabilities like MIIC for the orientation phase
300       /*@param graph graph in which to find the triples
301        *@param proba_triples probabilities for the different triples to update
302        */
303       std::vector< ProbabilisticRanking >
304          updateProbaTriples_(const MixedGraph&                   graph,
305                              std::vector< ProbabilisticRanking > probaTriples);
306 
307       /// Propagates the orientation from a node to its neighbours
308       /*@param dag graph in which to which to propagate arcs
309        *@param node node on which neighbours to propagate th orientation
310        *@param force : true if an orientation has always to be found.
311        */
312       bool propagatesRemainingOrientableEdges_(MixedGraph& graph, NodeId xj);
313 
314       /// heuristic for remaining edges when everything else has been tried
315       void propagatesOrientationInChainOfRemainingEdges_(MixedGraph& graph);
316 
317       protected:
318       bool isForbidenArc_(NodeId x, NodeId y) const;
319       bool isOrientable_(const MixedGraph& graph, NodeId xi, NodeId xj) const;
320 
321       private:
322       /// Fixes the maximum log that we accept in exponential computations
323       int _maxLog_ = 100;
324       /// an empty conditioning set
325       const std::vector< NodeId > _emptySet_;
326       /// an empty vector of arcs
327       std::vector< Arc > _latentCouples_;
328 
329       /// size of the database
330       Size _size_;
331       /// wether to use the miic algorithm or not
332       bool _useMiic_{false};
333 
334       /// Storing the propabilities for each arc set in the graph
335       ArcProperty< double > _arcProbas_;
336 
337       /// Initial marks for the orientation phase, used to convey constraints
338       HashTable< std::pair< NodeId, NodeId >, char > _initialMarks_;
339 
340       /** @brief checks for directed paths in a graph, considering double arcs like
341        * edges, not considering arc as a directed path.
342        * @param graph MixedGraph in which to search the path
343        * @param n1 tail of the path
344        * @param n2 head of the path
345        * @param countArc bool to know if we
346        */
347       static bool _existsNonTrivialDirectedPath_(const MixedGraph& graph, NodeId n1, NodeId n2);
348 
349       /** checks for directed paths in a graph, consider double arcs like edges
350        *@param graph MixedGraph in which to search the path
351        *@param n1 tail of the path
352        *@param n2 head of the path
353        */
354       static bool _existsDirectedPath_(const MixedGraph& graph, NodeId n1, NodeId n2);
355 
356       void _orientingVstructureMiic_(MixedGraph&                                     graph,
357                                      HashTable< std::pair< NodeId, NodeId >, char >& marks,
358                                      NodeId                                          x,
359                                      NodeId                                          y,
360                                      NodeId                                          z,
361                                      double                                          p1,
362                                      double                                          p2);
363 
364       void _propagatingOrientationMiic_(MixedGraph&                                     graph,
365                                         HashTable< std::pair< NodeId, NodeId >, char >& marks,
366                                         NodeId                                          x,
367                                         NodeId                                          y,
368                                         NodeId                                          z,
369                                         double                                          p1,
370                                         double                                          p2);
371 
372       bool _isNotLatentCouple_(NodeId x, NodeId y);
373     };
374 
375   } /* namespace learning */
376 
377 } /* namespace gum */
378 
379 /// always include templated methods
380 //#include <agrum/BN/learning/threeOffTwo_tpl.h>
381 
382 #endif /* GUM_LEARNING_3_OFF_2_H */
383