1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /** @file
23  * @brief The class that computes countings of observations from the database.
24  *
25  * This class is the one to be called by scores and independence tests to
26  * compute countings of observations from tabular databases.
27  *
28  * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6)
29  */
30 #ifndef GUM_LEARNING_RECORD_COUNTER_H
31 #define GUM_LEARNING_RECORD_COUNTER_H
32 
33 #include <vector>
34 #include <utility>
35 #include <sstream>
36 #include <string>
37 
38 #include <agrum/agrum.h>
39 #include <agrum/tools/core/bijection.h>
40 #include <agrum/tools/core/sequence.h>
41 #include <agrum/tools/core/OMPThreads.h>
42 #include <agrum/tools/core/threadData.h>
43 #include <agrum/tools/graphs/DAG.h>
44 #include <agrum/tools/database/DBRowGeneratorParser.h>
45 #include <agrum/tools/stattests/idCondSet.h>
46 
47 
48 namespace gum {
49 
50   namespace learning {
51 
52     /** @class RecordCounter
53      * @brief The class that computes countings of observations from the database.
54      * @headerfile recordCounter.h <agrum/tools/stattests/recordCounter.h>
55      * @ingroup learning_scores
56      *
57      * This class is the one to be called by scores and independence tests to
58      * compute the countings of observations from tabular datasets they need.
59      * The countings are performed the following way:
60      * when asked for the countings over a set X = {X_1,...,X_n} of
61      * variables, the RecordCounter first checks whether it already contains
62      * some countings over a set Y of variables containing X. If this is the
63      * case, then it extracts from the countings over Y those over X (this is
64      * usually way faster than determining the countings by parsing the database).
65      * Otherwise, it determines the countings over X by parsing in a parallel
66      * way the database. Only the result of the last database-parsed countings
67      * is available for the subset counting determination. As an example, if
68      * we create a RecordCounter and ask it the countings over {A,B,C}, it will
69      * parse the database and provide the countings. Then, if we ask it countings
70      * over B, it will use the table over {A,B,C} to produce the countings we
71      * look for. Then, asking for countings over {A,C} will be performed the same
72      * way. Now, asking countings over {B,C,D} will require another database
73      * parsing. Finally, if we ask for countings over A, a new database parsing
74      * will be performed because only the countings over {B,C,D} are now contained
75      * in the RecordCounter.
76      *
77      * @par Here is an example of how to use the RecordCounter class:
78      * @code
79      * // here, write the code to construct your database, e.g.:
80      * gum::learning::DBInitializerFromCSV<> initializer( "file.csv" );
81      * const auto& var_names = initializer.variableNames();
82      * const std::size_t nb_vars = var_names.size();
83      * gum::learning::DBTranslatorSet<> translator_set;
84      * gum::learning::DBTranslator4ContinuousVariable<> translator;
85      * for (std::size_t i = 0; i < nb_vars; ++i) {
86      *   translator_set.insertTranslator(translator, i);
87      * }
88      * gum::learning::DatabaseTable<> database(translator_set);
89      *
90      * // create the parser of the database
91      * gum::learning::DBRowGeneratorSet<> genset;
92      * gum::learning::DBRowGeneratorParser<> parser(database.handler(), genset);
93      *
94      * // create the record counter
95      * gum::learning::RecordCounter<> counter(parser);
96      *
97      * // get the counts:
98      * gum::learning::IdCondSet<> ids ( 0, std::vector<gum::NodeId> {2,1} );
99      * const std::vector< double >& counts1 = counter.counts ( ids );
100      *
101      * // change the rows from which we compute the counts:
102      * // they should now be made on rows [500,600) U [1050,1125) U [100,150)
103      * std::vector<std::pair<std::size_t,std::size_t>> new_ranges
104      *   { std::pair<std::size_t,std::size_t>(500,600),
105      *     std::pair<std::size_t,std::size_t>(1050,1125),
106      *     std::pair<std::size_t,std::size_t>(100,150) };
107      * counter.setRanges ( new_ranges );
108      * const std::vector< double >& counts2 = counter.counts ( ids );
109      * @endcode
110      */
111     template < template < typename > class ALLOC = std::allocator >
112     class RecordCounter {
113       public:
114       /// type for the allocators passed in arguments of methods
115       using allocator_type = ALLOC< NodeId >;
116 
117       // ##########################################################################
118       /// @name Constructors / Destructors
119       // ##########################################################################
120       /// @{
121 
122       /// default constructor
123       /** @param parser the parser used to parse the database
124        * @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
125        * indices. The countings are then performed only on the union of the
126        * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
127        * cross validation tasks, in which part of the database should be ignored.
128        * An empty set of ranges is equivalent to an interval [X,Y) ranging over
129        * the whole database.
130        * @param nodeId2columns a mapping from the ids of the nodes in the
131        * graphical model to the corresponding column in the DatabaseTable
132        * parsed by the parser. This enables estimating from a database in
133        * which variable A corresponds to the 2nd column the parameters of a BN
134        * in which variable A has a NodeId of 5. An empty nodeId2columns
135        * bijection means that the mapping is an identity, i.e., the value of a
136        * NodeId is equal to the index of the column in the DatabaseTable.
137        * @param alloc the allocator used to allocate the structures within the
138        * RecordCounter.
139        * @warning If nodeId2columns is not empty, then only the counts over the
140        * ids belonging to this bijection can be computed: applying method
141        * counts() over other ids will raise exception NotFound. */
142       RecordCounter(const DBRowGeneratorParser< ALLOC >&                                 parser,
143                     const std::vector< std::pair< std::size_t, std::size_t >,
144                                        ALLOC< std::pair< std::size_t, std::size_t > > >& ranges,
145                     const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns
146                     = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
147                     const allocator_type& alloc = allocator_type());
148 
149       /// default constructor (no row ranges specified)
150       /** @param parser the parser used to parse the database
151        * @param nodeId2columns a mapping from the ids of the nodes in the
152        * graphical model to the corresponding column in the DatabaseTable
153        * parsed by the parser. This enables estimating from a database in
154        * which variable A corresponds to the 2nd column the parameters of a BN
155        * in which variable A has a NodeId of 5. An empty nodeId2columns
156        * bijection means that the mapping is an identity, i.e., the value of a
157        * NodeId is equal to the index of the column in the DatabaseTable.
158        * @param alloc the allocator used to allocate the structures within the
159        * RecordCounter.
160        * @warning If nodeId2columns is not empty, then only the counts over the
161        * ids belonging to this bijection can be computed: applying method
162        * counts() over other ids will raise exception NotFound. */
163       RecordCounter(const DBRowGeneratorParser< ALLOC >&                          parser,
164                     const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2columns
165                     = Bijection< NodeId, std::size_t, ALLOC< std::size_t > >(),
166                     const allocator_type& alloc = allocator_type());
167 
168       /// copy constructor
169       RecordCounter(const RecordCounter< ALLOC >& from);
170 
171       /// copy constructor with a given allocator
172       RecordCounter(const RecordCounter< ALLOC >& from, const allocator_type& alloc);
173 
174       /// move constructor
175       RecordCounter(RecordCounter< ALLOC >&& from);
176 
177       /// move constructor with a given allocator
178       RecordCounter(RecordCounter< ALLOC >&& from, const allocator_type& alloc);
179 
180       /// virtual copy constructor
181       virtual RecordCounter< ALLOC >* clone() const;
182 
183       /// virtual copy constructor with a given allocator
184       virtual RecordCounter< ALLOC >* clone(const allocator_type& alloc) const;
185 
186       /// destructor
187       virtual ~RecordCounter();
188 
189       /// @}
190 
191 
192       // ##########################################################################
193       /// @name Operators
194       // ##########################################################################
195 
196       /// @{
197 
198       /// copy operator
199       RecordCounter< ALLOC >& operator=(const RecordCounter< ALLOC >& from);
200 
201       /// move operator
202       RecordCounter< ALLOC >& operator=(RecordCounter< ALLOC >&& from);
203 
204       /// @}
205 
206 
207       // ##########################################################################
208       /// @name Accessors / Modifiers
209       // ##########################################################################
210 
211       /// @{
212 
213       /// clears all the last database-parsed countings from memory
214       void clear();
215 
216       /// changes the max number of threads used to parse the database
217       void setMaxNbThreads(const std::size_t nb) const;
218 
219       /// returns the number of threads used to parse the database
220       std::size_t nbThreads() const;
221 
222       /** @brief changes the minimum number of rows a thread should process in a
223        * multithreading context
224        *
225        * When Method counts executes several threads to perform countings on the
226        * rows of the database, the MinNbRowsPerThread indicates how many rows each
227        * thread should at least process. This is used to compute the number of
228        * threads actually run. This number is equal to the min between the max
229        * number of threads allowed and the number of records in the database
230        * divided by nb. */
231       void setMinNbRowsPerThread(const std::size_t nb) const;
232 
233       /// returns the minimum of rows that each thread should process
234       std::size_t minNbRowsPerThread() const;
235 
236       /// returns the counts over all the variables in an IdCondSet
237       /** @param ids the idset of the variables over which we perform countings.
238        * @param check_discrete_vars The record counter can only produce correct
239        * results on sets of discrete variables. By default, the method does not
240        * check whether the variables corresponding to the IdCondSet are actually
241        * discrete. If check_discrete_vars is set to true, then this check is
242        * performed before computing the counting vector. In this case, if a
243        * variable is not discrete, a TypeError exception is raised.
244        * @return a vector containing the multidimensional contingency table
245        * over all the variables corresponding to the ids passed in argument
246        * (both at the left hand side and right hand side of the conditioning
247        * bar of the IdCondSet). The first dimension is that of the first variable
248        * in the IdCondSet, i.e., when its value increases by 1, the offset in the
249        * output vector also increases by 1. The second dimension is that of the
250        * second variable in the IdCondSet, i.e., when its value increases by 1, the
251        * offset in the output vector increases by the domain size of the first
252        * variable. For the third variable, the offset corresponds to the product
253        * of the domain sizes of the first two variables, and so on.
254        * @warning The vector returned by the function may differ from one
255        * call to another. So, care must be taken. E.g., a code like:
256        * @code
257        * const std::vector< double, ALLOC<double> >&
258        * counts = counter.counts(ids);
259        * counts = counter.counts(other_ids);
260        * @endcode
261        * may be erroneous because the two calls to method counts() may
262        * return references to different vectors. The correct way of using method
263        * counts() is always to call it declaring a new reference variable:
264        * @code
265        * const std::vector< double, ALLOC<double> >& counts =
266        *   counter.counts(ids);
267        * const std::vector< double, ALLOC<double> >& other_counts =
268        *   counter.counts(other_ids);
269        * @endcode
270        * @throw TypeError is raised if check_discrete_vars is set to true (i.e.,
271        * we check that all variables in the IdCondSet are discrete) and if at least
272        * one variable is not of a discrete nature.
273        */
274       const std::vector< double, ALLOC< double > >& counts(const IdCondSet< ALLOC >& ids,
275                                                            const bool check_discrete_vars = false);
276 
277       /// sets new ranges to perform the countings
278       /** @param ranges a set of pairs {(X1,Y1),...,(Xn,Yn)} of database's rows
279        * indices. The countings are then performed only on the union of the
280        * rows [Xi,Yi), i in {1,...,n}. This is useful, e.g, when performing
281        * cross validation tasks, in which part of the database should be ignored.
282        * An empty set of ranges is equivalent to an interval [X,Y) ranging over
283        * the whole database. */
284       template < template < typename > class XALLOC >
285       void setRanges(
286          const std::vector< std::pair< std::size_t, std::size_t >,
287                             XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges);
288 
289       /// reset the ranges to the one range corresponding to the whole database
290       void clearRanges();
291 
292       /// returns the current ranges
293       const std::vector< std::pair< std::size_t, std::size_t >,
294                          ALLOC< std::pair< std::size_t, std::size_t > > >&
295          ranges() const;
296 
297       /// assigns a new Bayes net to all the counter's generators depending on a BN
298       /** Typically, generators based on EM or K-means depend on a model to
299        * compute correctly their outputs. Method setBayesNet enables updating
300        * their BN model. */
301       template < typename GUM_SCALAR >
302       void setBayesNet(const BayesNet< GUM_SCALAR >& new_bn);
303 
304       /// returns the allocator used
305       allocator_type getAllocator() const;
306 
307       /// returns the mapping from ids to column positions in the database
308       /** @warning An empty nodeId2Columns bijection means that the mapping is
309        * an identity, i.e., the value of a NodeId is equal to the index of the
310        * column in the DatabaseTable. */
311       const Bijection< NodeId, std::size_t, ALLOC< std::size_t > >& nodeId2Columns() const;
312 
313       /// returns the database on which we perform the counts
314       const DatabaseTable< ALLOC >& database() const;
315 
316       /// @}
317 
318 
319 #ifndef DOXYGEN_SHOULD_SKIP_THIS
320 
321       private:
322       // the parsers used by the threads
323       std::vector< ThreadData< DBRowGeneratorParser< ALLOC > >,
324                    ALLOC< ThreadData< DBRowGeneratorParser< ALLOC > > > >
325          _parsers_;
326 
327       // the set of ranges of the database's rows indices over which the user
328       // wishes to perform the countings
329       std::vector< std::pair< std::size_t, std::size_t >,
330                    ALLOC< std::pair< std::size_t, std::size_t > > >
331          _ranges_;
332 
333       // the ranges actually used by the threads: there is a hopefully clever
334       // algorithm that splits the rows ranges into another set of ranges that
335       // are assigned to the threads. For instance, if the database has 1000
336       // rows and there are 10 threads, each one will be assigned a set of 100
337       // rows. These sets are precisely what are stored in the field below
338       mutable std::vector< std::pair< std::size_t, std::size_t >,
339                            ALLOC< std::pair< std::size_t, std::size_t > > >
340          _thread_ranges_;
341 
342       // the mapping from the NodeIds of the variables to the indices of the
343       // columns in the database
344       Bijection< NodeId, std::size_t, ALLOC< std::size_t > > _nodeId2columns_;
345 
346       // the last database-parsed countings
347       std::vector< double, ALLOC< double > > _last_DB_countings_;
348 
349       // the ids of the nodes for the last database-parsed countings
350       IdCondSet< ALLOC > _last_DB_ids_;
351 
352       // the last countings deduced from  _last_DB_countings_
353       std::vector< double, ALLOC< double > > _last_nonDB_countings_;
354 
355       // the ids of the nodes of last countings deduced from  _last_DB_countings_
356       IdCondSet< ALLOC > _last_nonDB_ids_;
357 
358       // the maximal number of threads that the record counter can use
359       mutable std::size_t _max_nb_threads_{std::size_t(gum::getMaxNumberOfThreads())};
360 
361       // the min number of rows that a thread should process in a
362       // multithreading context
363       mutable std::size_t _min_nb_rows_per_thread_{100};
364 
365       // returns a mapping from the nodes ids to the columns of the database
366       // for a given sequence of ids. This is especially convenient when
367       //  _nodeId2columns_ is empty (which means that there is an identity mapping)
368       HashTable< NodeId, std::size_t > _getNodeIds2Columns_(const IdCondSet< ALLOC >& ids) const;
369 
370       /// extracts some new countings from previously computed ones
371       std::vector< double, ALLOC< double > >&
372          _extractFromCountings_(const IdCondSet< ALLOC >&                     subset_ids,
373                                 const IdCondSet< ALLOC >&                     superset_ids,
374                                 const std::vector< double, ALLOC< double > >& superset_vect);
375 
376       /// parse the database to produce new countings
377       std::vector< double, ALLOC< double > >& _countFromDatabase_(const IdCondSet< ALLOC >& ids);
378 
379       /// the method used by threads to produce countings by parsing the database
380       void _threadedCount_(
381          const std::size_t                                                    range_begin,
382          const std::size_t                                                    range_end,
383          DBRowGeneratorParser< ALLOC >&                                       parser,
384          const std::vector< std::pair< std::size_t, std::size_t >,
385                             ALLOC< std::pair< std::size_t, std::size_t > > >& cols_and_offsets,
386          std::vector< double, ALLOC< double > >&                              countings);
387 
388       /// checks that the ranges passed in argument are ok or raises an exception
389       /** A range is ok if its upper bound is strictly higher than its lower
390        * bound and the latter is also lower than or equal to the number of rows
391        * in the database. */
392       template < template < typename > class XALLOC >
393       void _checkRanges_(
394          const std::vector< std::pair< std::size_t, std::size_t >,
395                             XALLOC< std::pair< std::size_t, std::size_t > > >& new_ranges) const;
396 
397       /// checks that the variables of the idset passed in argument are discrete
398       /** @throw TypeError is raised if at least one variable in ids is
399        * of a continuous nature. */
400       void _checkDiscreteVariables_(const IdCondSet< ALLOC >& ids) const;
401 
402       /// compute and raise the exception when some variables are continuous
403       /** This method is used by  _checkDiscreteVariables_ to determine the
404        * appropriate message to include in the TypeError exception raised when
405        * some variables over which we should perform countings are continuous. */
406       void _raiseCheckException_(
407          const std::vector< std::string, ALLOC< std::string > >& bad_vars) const;
408 
409       /// sets the ranges within which each thread will perform its computations
410       void _dispatchRangesToThreads_();
411 
412 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
413     };
414 
415   } /* namespace learning */
416 
417 } /* namespace gum */
418 
419 /// always include the templated implementations
420 #include <agrum/tools/stattests/recordCounter_tpl.h>
421 
422 #endif /* GUM_LEARNING_RECORD_COUNTER_H */
423