1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Headers of the interface specifying functions to be implemented by any
25  * incremental learner.
26  *
27  * @author Jean-Christophe MAGNAN
28  */
29 
30 // =========================================================================
31 #ifndef GUM_INCREMENTAL_GRAPH_LEARNER_H
32 #define GUM_INCREMENTAL_GRAPH_LEARNER_H
33 // =========================================================================
34 // =========================================================================
35 #include <agrum/tools/multidim/implementations/multiDimFunctionGraph.h>
36 // =========================================================================
37 #include <agrum/FMDP/learning/core/templateStrategy.h>
38 #include <agrum/FMDP/learning/datastructure/IVisitableGraphLearner.h>
39 #include <agrum/FMDP/learning/datastructure/nodeDatabase.h>
40 // =========================================================================
41 #include <agrum/tools/multidim/utils/FunctionGraphUtilities/link.h>
42 // =========================================================================
43 
44 namespace gum {
45 
46   /**
47    * @class IncrementalGraphLearner incrementalGraphLearner.h
48    * <agrum/FMDP/learning/datastructure/incrementalGraphLearner>
49    * @brief
50    * @ingroup fmdp_group
51    *
52    * Abstract class for incrementaly learn a graphical representation of a
53    * function.
54    * Can handle both function of real values, and function explaining the
55    * behaviour
56    * of a variable given set of other variables (as typically in conditionnal
57    * probabilities)
58    *
59    * Maintains two graph in memory, one which is incrementaly updated and the
60    * other one
61    * which is updated on demand and is usable by the outside.
62    *
63    */
64   template < TESTNAME AttributeSelection, bool isScalar = false >
65   class IncrementalGraphLearner: public IVisitableGraphLearner {
66     typedef typename ValueSelect< isScalar, double, Idx >::type ValueType;
67 
68     public:
69     // ###################################################################
70     /// @name Constructor & destructor.
71     // ###################################################################
72     /// @{
73 
74     // ==========================================================================
75     /**
76      * Default constructor
77      * @param target : the output diagram usable by the outside
78      * @param attributesSet : set of variables from which we try to describe the
79      * learned function
80      * @param learnVariable : if we tried to learn a the behaviour of a variable
81      * given variable given another set of variables, this is the one. If we are
82      * learning a function of real value, this is just a computationnal trick
83      * (and is to be deprecated)
84      */
85     // ==========================================================================
86     IncrementalGraphLearner(MultiDimFunctionGraph< double >* target,
87                             Set< const DiscreteVariable* >   attributesSet,
88                             const DiscreteVariable*          learnVariable);
89 
90     // ==========================================================================
91     /// Default destructor
92     // ==========================================================================
93     virtual ~IncrementalGraphLearner();
94 
95     private:
96     // ==========================================================================
97     /// Template function dispatcher
98     // ==========================================================================
_clearValue_()99     void _clearValue_() { _clearValue_(Int2Type< isScalar >()); }
100 
101     // ==========================================================================
102     /// In the case where we're learning a function of real values
103     /// this has to be wiped out upon destruction (to be deprecated)
104     // ==========================================================================
_clearValue_(Int2Type<true>)105     void _clearValue_(Int2Type< true >) { delete value_; }
106 
107     // ==========================================================================
108     /// In case where we're learning function of variable behaviour,
109     /// this should do nothing
110     // ==========================================================================
_clearValue_(Int2Type<false>)111     void _clearValue_(Int2Type< false >) {}
112 
113     /// @}
114 
115 
116     // ###################################################################
117     /// @name New Observation insertion methods
118     // ###################################################################
119     /// @{
120     public:
121     // ==========================================================================
122     /**
123      * Inserts a new observation
124      */
125     // ==========================================================================
126     virtual void addObservation(const Observation* obs);
127 
128     private:
129     // ==========================================================================
130     /**
131      * Get value assumed by studied variable for current observation
132      */
133     // ==========================================================================
_assumeValue_(const Observation * obs)134     void _assumeValue_(const Observation* obs) { _assumeValue_(obs, Int2Type< isScalar >()); }
_assumeValue_(const Observation * obs,Int2Type<true>)135     void _assumeValue_(const Observation* obs, Int2Type< true >) {
136       if (!valueAssumed_.exists(obs->reward())) valueAssumed_ << obs->reward();
137     }
_assumeValue_(const Observation * obs,Int2Type<false>)138     void _assumeValue_(const Observation* obs, Int2Type< false >) {
139       if (!valueAssumed_.exists(obs->modality(value_))) valueAssumed_ << obs->modality(value_);
140     }
141 
142 
143     // ==========================================================================
144     /**
145      * Seek modality assumed in obs for given var
146      */
147     // ==========================================================================
_branchObs_(const Observation * obs,const DiscreteVariable * var)148     Idx _branchObs_(const Observation* obs, const DiscreteVariable* var) {
149       return _branchObs_(obs, var, Int2Type< isScalar >());
150     }
_branchObs_(const Observation * obs,const DiscreteVariable * var,Int2Type<true>)151     Idx _branchObs_(const Observation* obs, const DiscreteVariable* var, Int2Type< true >) {
152       return obs->rModality(var);
153     }
_branchObs_(const Observation * obs,const DiscreteVariable * var,Int2Type<false>)154     Idx _branchObs_(const Observation* obs, const DiscreteVariable* var, Int2Type< false >) {
155       return obs->modality(var);
156     }
157 
158     protected:
159     // ==========================================================================
160     /**
161      * Will update internal graph's NodeDatabase of given node with the new
162      * observation
163      * @param newObs
164      * @param currentNodeId
165      */
166     // ==========================================================================
updateNodeWithObservation_(const Observation * newObs,NodeId currentNodeId)167     virtual void updateNodeWithObservation_(const Observation* newObs, NodeId currentNodeId) {
168       nodeId2Database_[currentNodeId]->addObservation(newObs);
169     }
170 
171     /// @}
172 
173     // ###################################################################
174     /// @name Graph Structure update methods
175     // ###################################################################
176     /// @{
177 
178     public:
179     // ==========================================================================
180     /// If a new modality appears to exists for given variable,
181     /// call this method to turn every associated node to this variable into
182     /// leaf.
183     /// Graph has then indeed to be revised
184     // ==========================================================================
185     virtual void updateVar(const DiscreteVariable*);
186 
187     // ==========================================================================
188     /// Updates the tree after a new observation has been added
189     // ==========================================================================
190     virtual void updateGraph() = 0;
191 
192     protected:
193     // ==========================================================================
194     /**
195      * From the given sets of node, selects randomly one and installs it
196      * on given node. Chechks of course if node's current variable is not in
197      * that
198      * set first.
199      * @param nody : the node we update
200      * @param bestVars : the set of interessting vars to be installed here
201      */
202     // ==========================================================================
203     void updateNode_(NodeId nody, Set< const DiscreteVariable* >& bestVars);
204 
205     // ==========================================================================
206     /// Turns the given node into a leaf if not already so
207     // ==========================================================================
208     virtual void convertNode2Leaf_(NodeId);
209 
210     // ==========================================================================
211     /// Installs given variable to the given node, ensuring that the variable
212     /// is not present in its subtree
213     // ==========================================================================
214     virtual void transpose_(NodeId, const DiscreteVariable*);
215 
216     // ==========================================================================
217     /**
218      * inserts a new node in internal graph
219      * @param nDB : the associated database
220      * @param boundVar : the associated variable
221      * @return the newly created node's id
222      */
223     // ==========================================================================
224     virtual NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar >* nDB,
225                                const DiscreteVariable*                       boundVar);
226 
227     // ==========================================================================
228     /**
229      * inserts a new internal node in internal graph
230      * @param nDB : the associated database
231      * @param boundVar : the associated variable
232      * @param sonsMap : a table giving node's sons node
233      * @return the newly created node's id
234      */
235     // ==========================================================================
236     virtual NodeId insertInternalNode_(NodeDatabase< AttributeSelection, isScalar >* nDB,
237                                        const DiscreteVariable*                       boundVar,
238                                        NodeId*                                       sonsMap);
239 
240     // ==========================================================================
241     /**
242      * inserts a new leaf node in internal graohs
243      * @param nDB : the associated database
244      * @param boundVar : the associated variable
245      * @param obsSet : the set of observation this leaf retains
246      * @return the newly created node's id
247      */
248     // ==========================================================================
249     virtual NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar >* nDB,
250                                    const DiscreteVariable*                       boundVar,
251                                    Set< const Observation* >*                    obsSet);
252 
253     // ==========================================================================
254     /**
255      * Changes the associated variable of a node
256      * @param chgedNodeId : the node to change
257      * @param desiredVar : its new associated variable
258      */
259     // ==========================================================================
260     virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable* desiredVar);
261 
262     // ==========================================================================
263     /**
264      * Removes a node from the internal graph
265      * @param removedNodeId : the node to remove
266      */
267     // ==========================================================================
268     virtual void removeNode_(NodeId removedNodeId);
269 
270     /// @}
271 
272 
273     // ###################################################################
274     /// @name Function Graph Updating methods
275     // ###################################################################
276     /// @{
277     public:
278     // ==========================================================================
279     /// Updates target to currently learned graph structure
280     // ==========================================================================
281     virtual void updateFunctionGraph() = 0;
282 
283     /// @}
284 
285 
286     public:
287     // ==========================================================================
288     ///
289     // ==========================================================================
size()290     Size size() { return nodeVarMap_.size(); }
291 
292 
293     // ###################################################################
294     /// @name Visit Methods
295     // ###################################################################
296     /// @{
297     public:
298     // ==========================================================================
299     ///
300     // ==========================================================================
root()301     NodeId root() const { return this->root_; }
302 
303     // ==========================================================================
304     ///
305     // ==========================================================================
isTerminal(NodeId ni)306     bool isTerminal(NodeId ni) const { return !this->nodeSonsMap_.exists(ni); }
307 
308     // ==========================================================================
309     ///
310     // ==========================================================================
nodeVar(NodeId ni)311     const DiscreteVariable* nodeVar(NodeId ni) const { return this->nodeVarMap_[ni]; }
312 
313     // ==========================================================================
314     ///
315     // ==========================================================================
nodeSon(NodeId ni,Idx modality)316     NodeId nodeSon(NodeId ni, Idx modality) const { return this->nodeSonsMap_[ni][modality]; }
317 
318     // ==========================================================================
319     ///
320     // ==========================================================================
nodeNbObservation(NodeId ni)321     Idx nodeNbObservation(NodeId ni) const { return this->nodeId2Database_[ni]->nbObservation(); }
322 
323     // ==========================================================================
324     ///
325     // ==========================================================================
insertSetOfVars(MultiDimFunctionGraph<double> * ret)326     virtual void insertSetOfVars(MultiDimFunctionGraph< double >* ret) const {
327       for (SetIteratorSafe< const DiscreteVariable* > varIter = setOfVars_.beginSafe();
328            varIter != setOfVars_.endSafe();
329            ++varIter)
330         ret->add(**varIter);
331     }
332     /// @}
333 
334     protected:
335     /// @}
336 
337     // ###################################################################
338     /// @name Model handling datastructures
339     // ###################################################################
340     /// @{
341 
342     // ==========================================================================
343     /// The source of nodeId
344     // ==========================================================================
345     NodeGraphPart model_;
346 
347     // ==========================================================================
348     /// The root of the ordered tree
349     // ==========================================================================
350     NodeId root_;
351 
352     // ==========================================================================
353     /// Gives for any node its associated variable
354     // ==========================================================================
355     HashTable< NodeId, const DiscreteVariable* > nodeVarMap_;
356 
357     // ==========================================================================
358     /// A table giving for any node a table mapping to its son
359     /// idx is the modality of associated variable
360     // ==========================================================================
361     HashTable< NodeId, NodeId* > nodeSonsMap_;
362 
363     // ==========================================================================
364     /// Associates to any variable the list of all nodes associated to
365     /// this variable
366     // ==========================================================================
367     HashTable< const DiscreteVariable*, LinkedList< NodeId >* > var2Node_;
368 
369     // ==========================================================================
370     /// This hashtable binds every node to an associated NodeDatabase
371     /// which handles every observation that concerns that node
372     // ==========================================================================
373     HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar >* > nodeId2Database_;
374 
375     // ==========================================================================
376     /// This hashtable binds to every leaf an associated set of all
377     /// hte observations compatible with it
378     // ==========================================================================
379     HashTable< NodeId, Set< const Observation* >* > leafDatabase_;
380 
381     /// @}
382 
383 
384     /// The final diagram we're building
385     MultiDimFunctionGraph< double >* target_;
386 
387     Set< const DiscreteVariable* > setOfVars_;
388 
389     const DiscreteVariable* value_;
390     Sequence< ValueType >   valueAssumed_;
391 
392     bool needUpdate_;
393   };
394 
395 
396 } /* namespace gum */
397 
398 #include <agrum/FMDP/learning/datastructure/incrementalGraphLearner_tpl.h>
399 
400 #endif   // GUM_INCREMENTAL_GRAPH_LEARNER_H
401