1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** @file 23 * @brief the base class for estimating parameters of CPTs 24 * 25 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6) 26 */ 27 #ifndef GUM_LEARNING_PARAM_ESTIMATOR_H 28 #define GUM_LEARNING_PARAM_ESTIMATOR_H 29 30 #include <type_traits> 31 32 #include <agrum/agrum.h> 33 #include <agrum/tools/database/databaseTable.h> 34 #include <agrum/BN/learning/aprioris/apriori.h> 35 #include <agrum/tools/stattests/recordCounter.h> 36 #include <agrum/tools/multidim/potential.h> 37 38 namespace gum { 39 40 namespace learning { 41 42 43 /** @class ParamEstimator 44 * @brief The base class for estimating parameters of CPTs 45 * @headerfile paramEstimator.h <agrum/BN/learning/paramUtils/paramEstimator.h> 46 * @ingroup learning_param_utils 47 */ 48 template < template < typename > class ALLOC = std::allocator > 49 class ParamEstimator { 50 public: 51 /// type for the allocators passed in arguments of methods 52 using allocator_type = ALLOC< NodeId >; 53 54 // ########################################################################## 55 /// @name Constructors / Destructors 56 // ########################################################################## 57 /// @{ 58 59 /// default constructor 60 /** @param parser the parser used to parse the database 61 * @param external_apriori An apriori that we add to the computation 62 * of the score 63 * @param score_internal_apriori The apriori within the score used 64 * to learn the data structure (might be a NoApriori) 65 * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows 66 * indices. The countings are then performed only on the union of the 67 * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing 68 * cross validation tasks, in which part of the database should be ignored. 69 * An empty set of ranges is equivalent to an interval [X,Y) ranging over 70 * the whole database. 71 * @param nodeId2Columns a mapping from the ids of the nodes in the 72 * graphical model to the corresponding column in the DatabaseTable 73 * parsed by the parser. This enables estimating from a database in 74 * which variable A corresponds to the 2nd column the parameters of a BN 75 * in which variable A has a NodeId of 5. An empty nodeId2Columns 76 * bijection means that the mapping is an identity, i.e., the value of a 77 * NodeId is equal to the index of the column in the DatabaseTable. 78 * @param alloc the allocator used to allocate the structures within the 79 * Score. 80 * @warning If nodeId2columns is not empty, then only the scores over the 81 * ids belonging to this bijection can be computed: applying method 82 * score() over other ids will raise exception NotFound. */ 83 ParamEstimator(const DBRowGeneratorParser< ALLOC >& parser, 84 const Apriori< ALLOC >& external_apriori, 85 const Apriori< ALLOC >& _score_internal_apriori, 86 const std::vector< std::pair< std::size_t, std::size_t >, 87 ALLOC< std::pair< std::size_t, std::size_t > > >& ranges, 88 const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns 89 = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), 90 const allocator_type& alloc = allocator_type()); 91 92 /// default constructor 93 /** @param parser the parser used to parse the database 94 * @param external_apriori An apriori that we add to the computation 95 * of the score 96 * @param score_internal_apriori The apriori within the score used 97 * to learn the data structure (might be a NoApriori) 98 * @param nodeId2Columns a mapping from the ids of the nodes in the 99 * graphical model to the corresponding column in the DatabaseTable 100 * parsed by the parser. This enables estimating from a database in 101 * which variable A corresponds to the 2nd column the parameters of a BN 102 * in which variable A has a NodeId of 5. An empty nodeId2Columns 103 * bijection means that the mapping is an identity, i.e., the value of a 104 * NodeId is equal to the index of the column in the DatabaseTable. 105 * @param alloc the allocator used to allocate the structures within the 106 * Score. 107 * @warning If nodeId2columns is not empty, then only the scores over the 108 * ids belonging to this bijection can be computed: applying method 109 * score() over other ids will raise exception NotFound. */ 110 ParamEstimator(const DBRowGeneratorParser< ALLOC >& parser, 111 const Apriori< ALLOC >& external_apriori, 112 const Apriori< ALLOC >& _score_internal_apriori, 113 const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns 114 = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(), 115 const allocator_type& alloc = allocator_type()); 116 117 /// copy constructor 118 ParamEstimator(const ParamEstimator< ALLOC >& from); 119 120 /// copy constructor with a given allocator 121 ParamEstimator(const ParamEstimator< ALLOC >& from, const allocator_type& alloc); 122 123 /// move constructor 124 ParamEstimator(ParamEstimator< ALLOC >&& from); 125 126 /// move constructor with a given allocator 127 ParamEstimator(ParamEstimator< ALLOC >&& from, const allocator_type& alloc); 128 129 /// virtual copy constructor 130 virtual ParamEstimator< ALLOC >* clone() const = 0; 131 132 /// virtual copy constructor with a given allocator 133 virtual ParamEstimator< ALLOC >* clone(const allocator_type& alloc) const = 0; 134 135 /// destructor 136 virtual ~ParamEstimator(); 137 138 /// @} 139 140 141 // ########################################################################## 142 /// @name Accessors / Modifiers 143 // ########################################################################## 144 /// @{ 145 146 /// clears all the data structures from memory 147 virtual void clear(); 148 149 /// changes the max number of threads used to parse the database 150 virtual void setMaxNbThreads(std::size_t nb) const; 151 152 /// returns the number of threads used to parse the database 153 virtual std::size_t nbThreads() const; 154 155 /** @brief changes the number min of rows a thread should process in a 156 * multithreading context 157 * 158 * When computing score, several threads are used by record counters to 159 * perform countings on the rows of the database, the MinNbRowsPerThread 160 * method indicates how many rows each thread should at least process. 161 * This is used to compute the number of threads actually run. This number 162 * is equal to the min between the max number of threads allowed and the 163 * number of records in the database divided by nb. */ 164 virtual void setMinNbRowsPerThread(const std::size_t nb) const; 165 166 /// returns the minimum of rows that each thread should process 167 virtual std::size_t minNbRowsPerThread() const; 168 169 /// sets new ranges to perform the countings used by the parameter estimator 170 /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows 171 * indices. The countings are then performed only on the union of the 172 * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing 173 * cross validation tasks, in which part of the database should be ignored. 174 * An empty set of ranges is equivalent to an interval [X,Y) ranging over 175 * the whole database. */ 176 template < template < typename > class XALLOC > 177 void setRanges( 178 const std::vector< std::pair< std::size_t, std::size_t >, 179 XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges); 180 181 /// reset the ranges to the one range corresponding to the whole database 182 void clearRanges(); 183 184 /// returns the current ranges 185 const std::vector< std::pair< std::size_t, std::size_t >, 186 ALLOC< std::pair< std::size_t, std::size_t > > >& 187 ranges() const; 188 189 /// returns the CPT's parameters corresponding to a given target node 190 std::vector< double, ALLOC< double > > parameters(const NodeId target_node); 191 192 /// returns the CPT's parameters corresponding to a given nodeset 193 /** The vector contains the parameters of an n-dimensional CPT. The 194 * distribution of the dimensions of the CPT within the vector is as 195 * follows: 196 * first, there is the target node, then the conditioning nodes (in the 197 * order in which they were specified). */ 198 virtual std::vector< double, ALLOC< double > > 199 parameters(const NodeId target_node, 200 const std::vector< NodeId, ALLOC< NodeId > >& conditioning_nodes) 201 = 0; 202 203 /// sets the CPT's parameters corresponding to a given Potential 204 /** The potential is assumed to be a conditional probability, the first 205 * variable of its variablesSequence() being the target variable, the 206 * other ones being on the right side of the conditioning bar. */ 207 template < typename GUM_SCALAR > 208 void setParameters(const NodeId target_node, 209 const std::vector< NodeId, ALLOC< NodeId > >& conditioning_nodes, 210 Potential< GUM_SCALAR >& pot); 211 212 /// returns the mapping from ids to column positions in the database 213 /** @warning An empty nodeId2Columns bijection means that the mapping is 214 * an identity, i.e., the value of a NodeId is equal to the index of the 215 * column in the DatabaseTable. */ 216 const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2Columns() const; 217 218 /// returns the database on which we perform the counts 219 const DatabaseTable< ALLOC >& database() const; 220 221 /// assign a new Bayes net to all the counter's generators depending on a BN 222 /** Typically, generators based on EM or K-means depend on a model to 223 * compute correctly their outputs. Method setBayesNet enables to 224 * update their BN model. */ 225 template < typename GUM_SCALAR > 226 void setBayesNet(const BayesNet< GUM_SCALAR >& new_bn); 227 228 /// returns the allocator used by the score 229 allocator_type getAllocator() const; 230 231 /// @} 232 233 protected: 234 /// an external a priori 235 Apriori< ALLOC >* external_apriori_{nullptr}; 236 237 /** @brief if a score was used for learning the structure of the PGM, this 238 * is the a priori internal to the score */ 239 Apriori< ALLOC >* score_internal_apriori_{nullptr}; 240 241 /// the record counter used to parse the database 242 RecordCounter< ALLOC > counter_; 243 244 /// an empty vector of nodes, used for empty conditioning 245 const std::vector< NodeId, ALLOC< NodeId > > empty_nodevect_; 246 247 248 /// copy operator 249 ParamEstimator< ALLOC >& operator=(const ParamEstimator< ALLOC >& from); 250 251 /// move operator 252 ParamEstimator< ALLOC >& operator=(ParamEstimator< ALLOC >&& from); 253 254 private: 255 #ifndef DOXYGEN_SHOULD_SKIP_THIS 256 257 /** @brief check the coherency between the parameters passed to 258 * the setParameters functions */ 259 template < typename GUM_SCALAR > 260 void _checkParameters_(const NodeId target_node, 261 const std::vector< NodeId, ALLOC< NodeId > >& conditioning_nodes, 262 Potential< GUM_SCALAR >& pot); 263 264 // sets the CPT's parameters corresponding to a given Potential 265 // when the potential belongs to a BayesNet<GUM_SCALAR> when 266 // GUM_SCALAR is different from a double 267 template < typename GUM_SCALAR > 268 typename std::enable_if< !std::is_same< GUM_SCALAR, double >::value, void >::type 269 _setParameters_(const NodeId target_node, 270 const std::vector< NodeId, ALLOC< NodeId > >& conditioning_nodes, 271 Potential< GUM_SCALAR >& pot); 272 273 // sets the CPT's parameters corresponding to a given Potential 274 // when the potential belongs to a BayesNet<GUM_SCALAR> when 275 // GUM_SCALAR is equal to double (the code is optimized for doubles) 276 template < typename GUM_SCALAR > 277 typename std::enable_if< std::is_same< GUM_SCALAR, double >::value, void >::type 278 _setParameters_(const NodeId target_node, 279 const std::vector< NodeId, ALLOC< NodeId > >& conditioning_nodes, 280 Potential< GUM_SCALAR >& pot); 281 282 #endif /* DOXYGEN_SHOULD_SKIP_THIS */ 283 }; 284 285 } /* namespace learning */ 286 287 } /* namespace gum */ 288 289 /// include the template implementation 290 #include <agrum/BN/learning/paramUtils/paramEstimator_tpl.h> 291 292 #endif /* GUM_LEARNING_PARAM_ESTIMATOR_H */ 293