/**
 *
 *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
 *   info_at_agrum_dot_org
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library. If not, see <http://www.gnu.org/licenses/>.
 *
 */


/** @file
 * @brief The class that computes countings of observations from the database.
 *
 * This class is the one to be called by scores and independence tests to
 * compute countings of observations from tabular databases.
 *
 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6)
 */
#ifndef GUM_LEARNING_RECORD_COUNTER_H
#define GUM_LEARNING_RECORD_COUNTER_H

#include <vector>
#include <utility>
#include <sstream>
#include <string>

#include <agrum/agrum.h>
#include <agrum/tools/core/bijection.h>
#include <agrum/tools/core/sequence.h>
#include <agrum/tools/core/OMPThreads.h>
#include <agrum/tools/core/threadData.h>
#include <agrum/tools/graphs/DAG.h>
#include <agrum/tools/database/DBRowGeneratorParser.h>
#include <agrum/tools/stattests/idCondSet.h>


namespace gum {

  namespace learning {

    /** @class RecordCounter
     * @brief The class that computes countings of observations from the database.
     * @headerfile recordCounter.h <agrum/tools/stattests/recordCounter.h>
     * @ingroup learning_scores
     *
     * This class is the one to be called by scores and independence tests to
     * compute the countings of observations from tabular datasets they need.
     * The countings are performed the following way:
     * when asked for the countings over a set X = {X_1,...,X_n} of
     * variables, the RecordCounter first checks whether it already contains
     * some countings over a set Y of variables containing X. If this is the
     * case, then it extracts from the countings over Y those over X (this is
     * usually way faster than determining the countings by parsing the database).
     * Otherwise, it determines the countings over X by parsing the database
     * in parallel. Only the result of the last database-parsed countings
     * is available for the subset counting determination. As an example, if
     * we create a RecordCounter and ask it the countings over {A,B,C}, it will
     * parse the database and provide the countings. Then, if we ask it countings
     * over B, it will use the table over {A,B,C} to produce the countings we
     * look for. Then, asking for countings over {A,C} will be performed the same
     * way. Now, asking countings over {B,C,D} will require another database
     * parsing. Finally, if we ask for countings over A, a new database parsing
     * will be performed because only the countings over {B,C,D} are now contained
     * in the RecordCounter.
     *
     * @tparam ALLOC the allocator template used to instantiate the containers
     * stored in and returned by the RecordCounter (std::allocator by default).
     *
     * @warning The countings are only meaningful for discrete variables (see
     * the documentation of method counts() and its check_discrete_vars
     * parameter).
     *
     * @par Here is an example of how to use the RecordCounter class:
     * @code
     * // here, write the code to construct your database, e.g.:
     * gum::learning::DBInitializerFromCSV<> initializer( "file.csv" );
     * const auto& var_names = initializer.variableNames();
     * const std::size_t nb_vars = var_names.size();
     * gum::learning::DBTranslatorSet<> translator_set;
     * gum::learning::DBTranslator4ContinuousVariable<> translator;
     * for (std::size_t i = 0; i < nb_vars; ++i) {
     *   translator_set.insertTranslator(translator, i);
     * }
     * gum::learning::DatabaseTable<> database(translator_set);
     *
     * // create the parser of the database
     * gum::learning::DBRowGeneratorSet<> genset;
     * gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
     *
     * // create the record counter
     * gum::learning::RecordCounter<> counter(parser);
     *
     * // get the counts:
     * gum::learning::IdCondSet<> ids ( 0, std::vector<gum::NodeId> {2,1} );
     * const std::vector< double >& counts1 = counter.counts ( ids );
     *
     * // change the rows from which we compute the counts:
     * // they should now be made on rows [500,600) U [1050,1125) U [100,150)
     * std::vector<std::pair<std::size_t,std::size_t>> new_ranges
     *   { std::pair<std::size_t,std::size_t>(500,600),
     *     std::pair<std::size_t,std::size_t>(1050,1125),
     *     std::pair<std::size_t,std::size_t>(100,150) };
     * counter.setRanges ( new_ranges );
     * const std::vector< double >& counts2 = counter.counts ( ids );
     * @endcode
     */
    template < template < typename > class ALLOC = std::allocator >
    class RecordCounter {
      public:
      /// type for the allocators passed in arguments of methods
      using allocator_type = ALLOC< NodeId >;

      // ##########################################################################
      /// @name Constructors / Destructors
      // ##########################################################################
      /// @{

      /// constructor restricting the countings to a given set of row ranges
      /** @param parser the parser used to parse the database
       * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
       * indices. The countings are then performed only on the union of the
       * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when performing
       * cross validation tasks, in which part of the database should be ignored.
       * An empty set of ranges is equivalent to an interval [X,Y) ranging over
       * the whole database.
       * @param nodeId2columns a mapping from the ids of the nodes in the
       * graphical model to the corresponding column in the DatabaseTable
       * parsed by the parser. This enables estimating from a database in
       * which variable A corresponds to the 2nd column the parameters of a BN
       * in which variable A has a NodeId of 5. An empty nodeId2columns
       * bijection means that the mapping is an identity, i.e., the value of a
       * NodeId is equal to the index of the column in the DatabaseTable.
       * @param alloc the allocator used to allocate the structures within the
       * RecordCounter.
       * @warning If nodeId2columns is not empty, then only the counts over the
       * ids belonging to this bijection can be computed: applying method
       * counts() over other ids will raise exception NotFound. */
      RecordCounter(const DBRowGeneratorParser< ALLOC >& parser,
                    const std::vector< std::pair< std::size_t, std::size_t >,
                                       ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
                    const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns
                    = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
                    const allocator_type& alloc = allocator_type());

      /// constructor taking no row ranges
      /** @param parser the parser used to parse the database
       * @param nodeId2columns a mapping from the ids of the nodes in the
       * graphical model to the corresponding column in the DatabaseTable
       * parsed by the parser. This enables estimating from a database in
       * which variable A corresponds to the 2nd column the parameters of a BN
       * in which variable A has a NodeId of 5. An empty nodeId2columns
       * bijection means that the mapping is an identity, i.e., the value of a
       * NodeId is equal to the index of the column in the DatabaseTable.
       * @param alloc the allocator used to allocate the structures within the
       * RecordCounter.
       * @warning If nodeId2columns is not empty, then only the counts over the
       * ids belonging to this bijection can be computed: applying method
       * counts() over other ids will raise exception NotFound. */
      RecordCounter(const DBRowGeneratorParser< ALLOC >& parser,
                    const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns
                    = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
                    const allocator_type& alloc = allocator_type());

      /// copy constructor
      RecordCounter(const RecordCounter< ALLOC >& from);

      /// copy constructor with a given allocator
      RecordCounter(const RecordCounter< ALLOC >& from, const allocator_type& alloc);

      /// move constructor
      RecordCounter(RecordCounter< ALLOC >&& from);

      /// move constructor with a given allocator
      RecordCounter(RecordCounter< ALLOC >&& from, const allocator_type& alloc);

      /// virtual copy constructor
      virtual RecordCounter< ALLOC >* clone() const;

      /// virtual copy constructor with a given allocator
      virtual RecordCounter< ALLOC >* clone(const allocator_type& alloc) const;

      /// destructor
      virtual ~RecordCounter();

      /// @}


      // ##########################################################################
      /// @name Operators
      // ##########################################################################

      /// @{

      /// copy operator
      RecordCounter< ALLOC >& operator=(const RecordCounter< ALLOC >& from);

      /// move operator
      RecordCounter< ALLOC >& operator=(RecordCounter< ALLOC >&& from);

      /// @}


      // ##########################################################################
      /// @name Accessors / Modifiers
      // ##########################################################################

      /// @{

      /// clears all the last database-parsed countings from memory
      void clear();

      /// changes the max number of threads used to parse the database
      /** @param nb the new maximal number of threads.
       * @note This setter is declared const, so it can be called on const
       * RecordCounters; the related field _max_nb_threads_ is declared
       * mutable. */
      void setMaxNbThreads(const std::size_t nb) const;

      /// returns the number of threads used to parse the database
      std::size_t nbThreads() const;

      /** @brief changes the minimum number of rows a thread should process in a
       * multithreading context
       *
       * When Method counts executes several threads to perform countings on the
       * rows of the database, the MinNbRowsPerThread indicates how many rows each
       * thread should at least process. This is used to compute the number of
       * threads actually run. This number is equal to the min between the max
       * number of threads allowed and the number of records in the database
       * divided by nb.
       *
       * @param nb the minimum number of rows each thread should process.
       * @note This setter is declared const, so it can be called on const
       * RecordCounters; the related field _min_nb_rows_per_thread_ is declared
       * mutable. */
      void setMinNbRowsPerThread(const std::size_t nb) const;

      /// returns the minimum of rows that each thread should process
      std::size_t minNbRowsPerThread() const;

      /// returns the counts over all the variables in an IdCondSet
      /** @param ids the idset of the variables over which we perform countings.
       * @param check_discrete_vars The record counter can only produce correct
       * results on sets of discrete variables. By default, the method does not
       * check whether the variables corresponding to the IdCondSet are actually
       * discrete. If check_discrete_vars is set to true, then this check is
       * performed before computing the counting vector. In this case, if a
       * variable is not discrete, a TypeError exception is raised.
       * @return a vector containing the multidimensional contingency table
       * over all the variables corresponding to the ids passed in argument
       * (both at the left hand side and right hand side of the conditioning
       * bar of the IdCondSet). The first dimension is that of the first variable
       * in the IdCondSet, i.e., when its value increases by 1, the offset in the
       * output vector also increases by 1. The second dimension is that of the
       * second variable in the IdCondSet, i.e., when its value increases by 1, the
       * offset in the output vector increases by the domain size of the first
       * variable. For the third variable, the offset corresponds to the product
       * of the domain sizes of the first two variables, and so on.
       * @warning The vector returned by the function may differ from one
       * call to another. So, care must be taken. E.g., a code like:
       * @code
       * const std::vector< double, ALLOC<double> >&
       *   counts = counter.counts(ids);
       * counts = counter.counts(other_ids);
       * @endcode
       * may be erroneous because the two calls to method counts() may
       * return references to different vectors. The correct way of using method
       * counts() is always to call it declaring a new reference variable:
       * @code
       * const std::vector< double, ALLOC<double> >& counts =
       *   counter.counts(ids);
       * const std::vector< double, ALLOC<double> >& other_counts =
       *   counter.counts(other_ids);
       * @endcode
       * @throw TypeError is raised if check_discrete_vars is set to true (i.e.,
       * we check that all variables in the IdCondSet are discrete) and if at least
       * one variable is not of a discrete nature.
       */
      const std::vector< double, ALLOC< double > >& counts(const IdCondSet< ALLOC >& ids,
                                                           const bool check_discrete_vars = false);

      /// sets new ranges to perform the countings
      /** @param new_ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's
       * rows indices. The countings are then performed only on the union of the
       * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g., when performing
       * cross validation tasks, in which part of the database should be ignored.
       * An empty set of ranges is equivalent to an interval [X,Y) ranging over
       * the whole database.
       * @tparam XALLOC the allocator template of the vector passed in argument;
       * it need not be the same as the ALLOC of the RecordCounter. */
      template < template < typename > class XALLOC >
      void setRanges(
         const std::vector< std::pair< std::size_t, std::size_t >,
                            XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);

      /// reset the ranges to the one range corresponding to the whole database
      void clearRanges();

      /// returns the current ranges
      const std::vector< std::pair< std::size_t, std::size_t >,
                         ALLOC< std::pair< std::size_t, std::size_t > > >&
         ranges() const;

      /// assign a new Bayes net to all the counter's generators depending on a BN
      /** Typically, generators based on EM or K-means depend on a model to
       * compute correctly their outputs. Method setBayesNet enables to
       * update their BN model.
       * @tparam GUM_SCALAR the scalar type parameterizing the BayesNet.
       * @param new_bn the Bayes net to assign to the generators. */
      template < typename GUM_SCALAR >
      void setBayesNet(const BayesNet< GUM_SCALAR >& new_bn);

      /// returns the allocator used
      allocator_type getAllocator() const;

      /// returns the mapping from ids to column positions in the database
      /** @warning An empty nodeId2Columns bijection means that the mapping is
       * an identity, i.e., the value of a NodeId is equal to the index of the
       * column in the DatabaseTable. */
      const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2Columns() const;

      /// returns the database on which we perform the counts
      const DatabaseTable< ALLOC >& database() const;

      /// @}


#ifndef DOXYGEN_SHOULD_SKIP_THIS

      private:
      // the parsers used by the threads
      std::vector< ThreadData< DBRowGeneratorParser< ALLOC > >,
                   ALLOC< ThreadData< DBRowGeneratorParser< ALLOC > > > >
         _parsers_;

      // the set of ranges of the database's rows indices over which the user
      // wishes to perform the countings
      std::vector< std::pair< std::size_t, std::size_t >,
                   ALLOC< std::pair< std::size_t, std::size_t > > >
         _ranges_;

      // the ranges actually used by the threads: there is a hopefully clever
      // algorithm that splits the rows ranges into another set of ranges that
      // are assigned to the threads. For instance, if the database has 1000
      // rows and there are 10 threads, each one will be assigned a set of 100
      // rows. These sets are precisely what are stored in the field below.
      // The field is mutable, hence it can be updated from const methods.
      mutable std::vector< std::pair< std::size_t, std::size_t >,
                           ALLOC< std::pair< std::size_t, std::size_t > > >
         _thread_ranges_;

      // the mapping from the NodeIds of the variables to the indices of the
      // columns in the database
      Bijection< NodeId, std::size_t, ALLOC< std::size_t > > _nodeId2columns_;

      // the last database-parsed countings
      std::vector< double, ALLOC< double > > _last_DB_countings_;

      // the ids of the nodes for the last database-parsed countings
      IdCondSet< ALLOC > _last_DB_ids_;

      // the last countings deduced from _last_DB_countings_
      std::vector< double, ALLOC< double > > _last_nonDB_countings_;

      // the ids of the nodes of last countings deduced from _last_DB_countings_
      IdCondSet< ALLOC > _last_nonDB_ids_;

      // the maximal number of threads that the record counter can use
      // (initialized with the value returned by gum::getMaxNumberOfThreads())
      mutable std::size_t _max_nb_threads_{std::size_t(gum::getMaxNumberOfThreads())};

      // the min number of rows that a thread should process in a
      // multithreading context (100 by default)
      mutable std::size_t _min_nb_rows_per_thread_{100};

      // returns a mapping from the nodes ids to the columns of the database
      // for a given sequence of ids. This is especially convenient when
      // _nodeId2columns_ is empty (which means that there is an identity mapping)
      HashTable< NodeId, std::size_t > _getNodeIds2Columns_(const IdCondSet< ALLOC >& ids) const;

      /// extracts some new countings from previously computed ones
      /** @param subset_ids the ids over which the new countings are requested
       * @param superset_ids the ids over which superset_vect was computed
       * (presumably a superset of subset_ids, as described in the class
       * documentation -- to be confirmed in recordCounter_tpl.h)
       * @param superset_vect the countings previously computed over
       * superset_ids */
      std::vector< double, ALLOC< double > >&
         _extractFromCountings_(const IdCondSet< ALLOC >& subset_ids,
                                const IdCondSet< ALLOC >& superset_ids,
                                const std::vector< double, ALLOC< double > >& superset_vect);

      /// parse the database to produce new countings
      std::vector< double, ALLOC< double > >& _countFromDatabase_(const IdCondSet< ALLOC >& ids);

      /// the method used by threads to produce countings by parsing the database
      /** @param range_begin the first bound of the range of rows to process
       * @param range_end the second bound of the range of rows to process
       * (presumably exclusive, like the user-specified ranges [Xi,Yi) --
       * to be confirmed in recordCounter_tpl.h)
       * @param parser the parser used by the thread to read the rows
       * @param cols_and_offsets presumably pairs (column in the database,
       * offset in the counting vector) for the variables to count -- to be
       * confirmed in recordCounter_tpl.h
       * @param countings the vector in which the countings are stored */
      void _threadedCount_(
         const std::size_t range_begin,
         const std::size_t range_end,
         DBRowGeneratorParser< ALLOC >& parser,
         const std::vector< std::pair< std::size_t, std::size_t >,
                            ALLOC< std::pair< std::size_t, std::size_t > > >& cols_and_offsets,
         std::vector< double, ALLOC< double > >& countings);

      /// checks that the ranges passed in argument are ok or raise an exception
      /** A range is ok if its upper bound is strictly higher than its lower
       * bound and the latter is also lower than or equal to the number of rows
       * in the database. */
      template < template < typename > class XALLOC >
      void _checkRanges_(
         const std::vector< std::pair< std::size_t, std::size_t >,
                            XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges) const;

      /// checks that all the variables of an IdCondSet are discrete
      /** @throw TypeError is raised if at least one variable in ids is
       * of a continuous nature. */
      void _checkDiscreteVariables_(const IdCondSet< ALLOC >& ids) const;

      /// compute and raise the exception when some variables are continuous
      /** This method is used by _checkDiscreteVariables_ to determine the
       * appropriate message to include in the TypeError exception raised when
       * some variables over which we should perform countings are continuous.
       * @param bad_vars presumably the names of the continuous variables
       * detected -- to be confirmed in recordCounter_tpl.h */
      void _raiseCheckException_(
         const std::vector< std::string, ALLOC< std::string > >& bad_vars) const;

      /// sets the ranges within which each thread will perform its computations
      void _dispatchRangesToThreads_();

#endif /* DOXYGEN_SHOULD_SKIP_THIS */
    };

  } /* namespace learning */

} /* namespace gum */

/// always include the templated implementations
#include <agrum/tools/stattests/recordCounter_tpl.h>

#endif /* GUM_LEARNING_RECORD_COUNTER_H */