1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** 23 * @file 24 * @brief Headers of the interface specifying functions to be implemented by any 25 * incremental learner. 26 * 27 * @author Jean-Christophe MAGNAN 28 */ 29 30 // ========================================================================= 31 #ifndef GUM_INCREMENTAL_GRAPH_LEARNER_H 32 #define GUM_INCREMENTAL_GRAPH_LEARNER_H 33 // ========================================================================= 34 // ========================================================================= 35 #include <agrum/tools/multidim/implementations/multiDimFunctionGraph.h> 36 // ========================================================================= 37 #include <agrum/FMDP/learning/core/templateStrategy.h> 38 #include <agrum/FMDP/learning/datastructure/IVisitableGraphLearner.h> 39 #include <agrum/FMDP/learning/datastructure/nodeDatabase.h> 40 // ========================================================================= 41 #include <agrum/tools/multidim/utils/FunctionGraphUtilities/link.h> 42 // ========================================================================= 43 44 namespace gum { 45 46 /** 47 * @class IncrementalGraphLearner incrementalGraphLearner.h 48 * <agrum/FMDP/learning/datastructure/incrementalGraphLearner> 49 * @brief 50 * @ingroup fmdp_group 51 * 52 * Abstract class for incrementaly learn a graphical representation of a 53 * function. 54 * Can handle both function of real values, and function explaining the 55 * behaviour 56 * of a variable given set of other variables (as typically in conditionnal 57 * probabilities) 58 * 59 * Maintains two graph in memory, one which is incrementaly updated and the 60 * other one 61 * which is updated on demand and is usable by the outside. 62 * 63 */ 64 template < TESTNAME AttributeSelection, bool isScalar = false > 65 class IncrementalGraphLearner: public IVisitableGraphLearner { 66 typedef typename ValueSelect< isScalar, double, Idx >::type ValueType; 67 68 public: 69 // ################################################################### 70 /// @name Constructor & destructor. 71 // ################################################################### 72 /// @{ 73 74 // ========================================================================== 75 /** 76 * Default constructor 77 * @param target : the output diagram usable by the outside 78 * @param attributesSet : set of variables from which we try to describe the 79 * learned function 80 * @param learnVariable : if we tried to learn a the behaviour of a variable 81 * given variable given another set of variables, this is the one. If we are 82 * learning a function of real value, this is just a computationnal trick 83 * (and is to be deprecated) 84 */ 85 // ========================================================================== 86 IncrementalGraphLearner(MultiDimFunctionGraph< double >* target, 87 Set< const DiscreteVariable* > attributesSet, 88 const DiscreteVariable* learnVariable); 89 90 // ========================================================================== 91 /// Default destructor 92 // ========================================================================== 93 virtual ~IncrementalGraphLearner(); 94 95 private: 96 // ========================================================================== 97 /// Template function dispatcher 98 // ========================================================================== _clearValue_()99 void _clearValue_() { _clearValue_(Int2Type< isScalar >()); } 100 101 // ========================================================================== 102 /// In the case where we're learning a function of real values 103 /// this has to be wiped out upon destruction (to be deprecated) 104 // ========================================================================== _clearValue_(Int2Type<true>)105 void _clearValue_(Int2Type< true >) { delete value_; } 106 107 // ========================================================================== 108 /// In case where we're learning function of variable behaviour, 109 /// this should do nothing 110 // ========================================================================== _clearValue_(Int2Type<false>)111 void _clearValue_(Int2Type< false >) {} 112 113 /// @} 114 115 116 // ################################################################### 117 /// @name New Observation insertion methods 118 // ################################################################### 119 /// @{ 120 public: 121 // ========================================================================== 122 /** 123 * Inserts a new observation 124 */ 125 // ========================================================================== 126 virtual void addObservation(const Observation* obs); 127 128 private: 129 // ========================================================================== 130 /** 131 * Get value assumed by studied variable for current observation 132 */ 133 // ========================================================================== _assumeValue_(const Observation * obs)134 void _assumeValue_(const Observation* obs) { _assumeValue_(obs, Int2Type< isScalar >()); } _assumeValue_(const Observation * obs,Int2Type<true>)135 void _assumeValue_(const Observation* obs, Int2Type< true >) { 136 if (!valueAssumed_.exists(obs->reward())) valueAssumed_ << obs->reward(); 137 } _assumeValue_(const Observation * obs,Int2Type<false>)138 void _assumeValue_(const Observation* obs, Int2Type< false >) { 139 if (!valueAssumed_.exists(obs->modality(value_))) valueAssumed_ << obs->modality(value_); 140 } 141 142 143 // ========================================================================== 144 /** 145 * Seek modality assumed in obs for given var 146 */ 147 // ========================================================================== _branchObs_(const Observation * obs,const DiscreteVariable * var)148 Idx _branchObs_(const Observation* obs, const DiscreteVariable* var) { 149 return _branchObs_(obs, var, Int2Type< isScalar >()); 150 } _branchObs_(const Observation * obs,const DiscreteVariable * var,Int2Type<true>)151 Idx _branchObs_(const Observation* obs, const DiscreteVariable* var, Int2Type< true >) { 152 return obs->rModality(var); 153 } _branchObs_(const Observation * obs,const DiscreteVariable * var,Int2Type<false>)154 Idx _branchObs_(const Observation* obs, const DiscreteVariable* var, Int2Type< false >) { 155 return obs->modality(var); 156 } 157 158 protected: 159 // ========================================================================== 160 /** 161 * Will update internal graph's NodeDatabase of given node with the new 162 * observation 163 * @param newObs 164 * @param currentNodeId 165 */ 166 // ========================================================================== updateNodeWithObservation_(const Observation * newObs,NodeId currentNodeId)167 virtual void updateNodeWithObservation_(const Observation* newObs, NodeId currentNodeId) { 168 nodeId2Database_[currentNodeId]->addObservation(newObs); 169 } 170 171 /// @} 172 173 // ################################################################### 174 /// @name Graph Structure update methods 175 // ################################################################### 176 /// @{ 177 178 public: 179 // ========================================================================== 180 /// If a new modality appears to exists for given variable, 181 /// call this method to turn every associated node to this variable into 182 /// leaf. 183 /// Graph has then indeed to be revised 184 // ========================================================================== 185 virtual void updateVar(const DiscreteVariable*); 186 187 // ========================================================================== 188 /// Updates the tree after a new observation has been added 189 // ========================================================================== 190 virtual void updateGraph() = 0; 191 192 protected: 193 // ========================================================================== 194 /** 195 * From the given sets of node, selects randomly one and installs it 196 * on given node. Chechks of course if node's current variable is not in 197 * that 198 * set first. 199 * @param nody : the node we update 200 * @param bestVars : the set of interessting vars to be installed here 201 */ 202 // ========================================================================== 203 void updateNode_(NodeId nody, Set< const DiscreteVariable* >& bestVars); 204 205 // ========================================================================== 206 /// Turns the given node into a leaf if not already so 207 // ========================================================================== 208 virtual void convertNode2Leaf_(NodeId); 209 210 // ========================================================================== 211 /// Installs given variable to the given node, ensuring that the variable 212 /// is not present in its subtree 213 // ========================================================================== 214 virtual void transpose_(NodeId, const DiscreteVariable*); 215 216 // ========================================================================== 217 /** 218 * inserts a new node in internal graph 219 * @param nDB : the associated database 220 * @param boundVar : the associated variable 221 * @return the newly created node's id 222 */ 223 // ========================================================================== 224 virtual NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar >* nDB, 225 const DiscreteVariable* boundVar); 226 227 // ========================================================================== 228 /** 229 * inserts a new internal node in internal graph 230 * @param nDB : the associated database 231 * @param boundVar : the associated variable 232 * @param sonsMap : a table giving node's sons node 233 * @return the newly created node's id 234 */ 235 // ========================================================================== 236 virtual NodeId insertInternalNode_(NodeDatabase< AttributeSelection, isScalar >* nDB, 237 const DiscreteVariable* boundVar, 238 NodeId* sonsMap); 239 240 // ========================================================================== 241 /** 242 * inserts a new leaf node in internal graohs 243 * @param nDB : the associated database 244 * @param boundVar : the associated variable 245 * @param obsSet : the set of observation this leaf retains 246 * @return the newly created node's id 247 */ 248 // ========================================================================== 249 virtual NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar >* nDB, 250 const DiscreteVariable* boundVar, 251 Set< const Observation* >* obsSet); 252 253 // ========================================================================== 254 /** 255 * Changes the associated variable of a node 256 * @param chgedNodeId : the node to change 257 * @param desiredVar : its new associated variable 258 */ 259 // ========================================================================== 260 virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable* desiredVar); 261 262 // ========================================================================== 263 /** 264 * Removes a node from the internal graph 265 * @param removedNodeId : the node to remove 266 */ 267 // ========================================================================== 268 virtual void removeNode_(NodeId removedNodeId); 269 270 /// @} 271 272 273 // ################################################################### 274 /// @name Function Graph Updating methods 275 // ################################################################### 276 /// @{ 277 public: 278 // ========================================================================== 279 /// Updates target to currently learned graph structure 280 // ========================================================================== 281 virtual void updateFunctionGraph() = 0; 282 283 /// @} 284 285 286 public: 287 // ========================================================================== 288 /// 289 // ========================================================================== size()290 Size size() { return nodeVarMap_.size(); } 291 292 293 // ################################################################### 294 /// @name Visit Methods 295 // ################################################################### 296 /// @{ 297 public: 298 // ========================================================================== 299 /// 300 // ========================================================================== root()301 NodeId root() const { return this->root_; } 302 303 // ========================================================================== 304 /// 305 // ========================================================================== isTerminal(NodeId ni)306 bool isTerminal(NodeId ni) const { return !this->nodeSonsMap_.exists(ni); } 307 308 // ========================================================================== 309 /// 310 // ========================================================================== nodeVar(NodeId ni)311 const DiscreteVariable* nodeVar(NodeId ni) const { return this->nodeVarMap_[ni]; } 312 313 // ========================================================================== 314 /// 315 // ========================================================================== nodeSon(NodeId ni,Idx modality)316 NodeId nodeSon(NodeId ni, Idx modality) const { return this->nodeSonsMap_[ni][modality]; } 317 318 // ========================================================================== 319 /// 320 // ========================================================================== nodeNbObservation(NodeId ni)321 Idx nodeNbObservation(NodeId ni) const { return this->nodeId2Database_[ni]->nbObservation(); } 322 323 // ========================================================================== 324 /// 325 // ========================================================================== insertSetOfVars(MultiDimFunctionGraph<double> * ret)326 virtual void insertSetOfVars(MultiDimFunctionGraph< double >* ret) const { 327 for (SetIteratorSafe< const DiscreteVariable* > varIter = setOfVars_.beginSafe(); 328 varIter != setOfVars_.endSafe(); 329 ++varIter) 330 ret->add(**varIter); 331 } 332 /// @} 333 334 protected: 335 /// @} 336 337 // ################################################################### 338 /// @name Model handling datastructures 339 // ################################################################### 340 /// @{ 341 342 // ========================================================================== 343 /// The source of nodeId 344 // ========================================================================== 345 NodeGraphPart model_; 346 347 // ========================================================================== 348 /// The root of the ordered tree 349 // ========================================================================== 350 NodeId root_; 351 352 // ========================================================================== 353 /// Gives for any node its associated variable 354 // ========================================================================== 355 HashTable< NodeId, const DiscreteVariable* > nodeVarMap_; 356 357 // ========================================================================== 358 /// A table giving for any node a table mapping to its son 359 /// idx is the modality of associated variable 360 // ========================================================================== 361 HashTable< NodeId, NodeId* > nodeSonsMap_; 362 363 // ========================================================================== 364 /// Associates to any variable the list of all nodes associated to 365 /// this variable 366 // ========================================================================== 367 HashTable< const DiscreteVariable*, LinkedList< NodeId >* > var2Node_; 368 369 // ========================================================================== 370 /// This hashtable binds every node to an associated NodeDatabase 371 /// which handles every observation that concerns that node 372 // ========================================================================== 373 HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar >* > nodeId2Database_; 374 375 // ========================================================================== 376 /// This hashtable binds to every leaf an associated set of all 377 /// hte observations compatible with it 378 // ========================================================================== 379 HashTable< NodeId, Set< const Observation* >* > leafDatabase_; 380 381 /// @} 382 383 384 /// The final diagram we're building 385 MultiDimFunctionGraph< double >* target_; 386 387 Set< const DiscreteVariable* > setOfVars_; 388 389 const DiscreteVariable* value_; 390 Sequence< ValueType > valueAssumed_; 391 392 bool needUpdate_; 393 }; 394 395 396 } /* namespace gum */ 397 398 #include <agrum/FMDP/learning/datastructure/incrementalGraphLearner_tpl.h> 399 400 #endif // GUM_INCREMENTAL_GRAPH_LEARNER_H 401