1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief A class for sampling w.r.t. Dirichlet distributions.
25  *
26  * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6)
27  */
28 #ifndef GUM_LEARNING_DIRICHLET_H
29 #define GUM_LEARNING_DIRICHLET_H
30 
31 #include <random>
32 #include <vector>
33 
34 #include <agrum/agrum.h>
35 #include <agrum/tools/core/utils_random.h>
36 
37 namespace gum {
38 
39   // =========================================================================
40   // ===                          DIRICHLET CLASS                          ===
41   // =========================================================================
42 
  /**
   * @class Dirichlet
   * @headerfile Dirichlet.h <agrum/tools/core/math/Dirichlet.h>
   * @brief A class for sampling w.r.t. Dirichlet distributions.
   * @ingroup math_group
   *
   * The interface mirrors the standard <random> distribution style:
   * - a param_type;
   * - a result_type;
   * - call operators with or without an explicit generator;
   * - param() getter/setter;
   * - min()/max().
   *
   * Samples are built from unnormalized draws of an internal
   * std::gamma_distribution< float > (see _gamma_).
   */
  class Dirichlet {
    public:
    /// The parameter type: a vector of float parameters.
    using param_type = std::vector< float >;

    /// The type for the samples generated: a vector of float values.
    using result_type = std::vector< float >;

    // ==========================================================================
    /// @name Constructors / Destructors
    // ==========================================================================
    /// @{

    /**
     * @brief Constructs a Dirichlet distribution with the given parameters.
     *
     * This is not a default constructor: the parameters are mandatory.
     *
     * NOTE(review): this constructor is callable with a single argument and is
     * not declared explicit, so a param_type implicitly converts to a
     * Dirichlet (e.g. `Dirichlet d = params;`). Consider making it explicit
     * once callers are checked.
     *
     * @param params The distribution parameters.
     * @param seed The seed of the distribution's random engine (defaults to
     * GUM_RANDOMSEED).
     */
    Dirichlet(const param_type& params, unsigned int seed = GUM_RANDOMSEED);

    /**
     * @brief Copy constructor.
     * @param from The distribution to copy.
     */
    Dirichlet(const Dirichlet& from);

    /**
     * @brief Move constructor.
     *
     * NOTE(review): not declared noexcept. Containers such as
     * std::vector< Dirichlet > will therefore copy rather than move elements
     * on reallocation. Adding noexcept requires updating the definition in
     * Dirichlet_inl.h as well.
     *
     * @param from The distribution to move.
     */
    Dirichlet(Dirichlet&& from);

    /**
     * @brief Class destructor.
     */
    ~Dirichlet();

    /// @}
    // ==========================================================================
    /// @name Operators
    // ==========================================================================
    /// @{

    /**
     * @brief Copy operator.
     * @param from The distribution to copy.
     * @return Returns this gum::Dirichlet distribution.
     */
    Dirichlet& operator=(const Dirichlet& from);

    /**
     * @brief Move operator.
     *
     * NOTE(review): not declared noexcept (see the move constructor).
     *
     * @param from The distribution to move.
     * @return Returns this gum::Dirichlet distribution.
     */
    Dirichlet& operator=(Dirichlet&& from);

    /**
     * @brief Returns a sample from the Dirichlet distribution.
     *
     * Non-const: drawing a sample advances the internal random state.
     *
     * @return Returns a sample from the Dirichlet distribution.
     */
    result_type operator()();

    /**
     * @brief Returns a sample from the Dirichlet distribution with parameters p.
     *
     * Presumably the stored parameters (see param()) are left unchanged and p
     * is used only for this draw; confirm in Dirichlet_inl.h.
     *
     * @param p An object representing the distribution's parameters, e.g. one
     * obtained by a call to gum::Dirichlet::param().
     * @return Returns a sample from the Dirichlet distribution.
     */
    result_type operator()(const param_type& p);

    /**
     * @brief Returns a sample from the Dirichlet distribution with parameters
     * p, using an external random number generator.
     *
     * @param generator A uniform random number generator object, used as the
     * source of randomness. URNG shall be a uniform random number generator
     * type, such as one of the standard generator classes.
     * @param p An object representing the distribution's parameters, e.g. one
     * obtained by a call to gum::Dirichlet::param().
     * @return Returns a sample from the Dirichlet distribution.
     */
    template < class URNG >
    result_type operator()(URNG& generator, const param_type& p);

    /// @}
    // ==========================================================================
    /// @name Accessors / Modifiers
    // ==========================================================================
    /// @{

    /**
     * @brief Returns the parameters of the distribution.
     * @return Returns a const reference to the parameters of the distribution.
     */
    const param_type& param() const noexcept;

    /**
     * @brief Sets the parameters of the distribution.
     * @param p An object representing the distribution's parameters, e.g. one
     * obtained by a call to the getter gum::Dirichlet::param().
     */
    void param(const param_type& p);

    /**
     * @brief Returns the greatest lower bound of the range of values returned
     * by gum::Dirichlet::operator()().
     *
     * Presumably 0, since each sample component is a normalized gamma draw;
     * TODO confirm in Dirichlet_inl.h.
     *
     * @return Returns the greatest lower bound of the range of values returned
     * by gum::Dirichlet::operator()().
     */
    float min() const noexcept;

    /**
     * @brief Returns the least upper bound of the range of values returned
     * by gum::Dirichlet::operator()().
     *
     * Presumably 1, since each sample component is a normalized gamma draw;
     * TODO confirm in Dirichlet_inl.h.
     *
     * @return Returns the least upper bound of the range of values returned
     * by gum::Dirichlet::operator()().
     */
    float max() const noexcept;

    /// @}

    private:
    // NOTE: the declaration order below is also the construction order of
    // these members; keep it in sync with the constructors' init lists.

    /// The internal random engine: the source of randomness for the
    /// parameterless sampling operators (presumably fed to _gamma_ — confirm
    /// in Dirichlet_inl.h).
    std::default_random_engine _generator_;

    /// The gamma distribution used to compute the Dirichlet unnormalized
    /// samples.
    std::gamma_distribution< float > _gamma_;

    /// The parameters of the distribution.
    param_type _params_;
  };
179 
180 } /* namespace gum */
181 
182 // include the inlined functions if necessary
183 #ifndef GUM_NO_INLINE
184 #  include <agrum/tools/core/math/Dirichlet_inl.h>
185 #endif /* GUM_NO_INLINE */
186 
187 // always include templates
188 #include <agrum/tools/core/math/Dirichlet_tpl.h>
189 
190 #endif /* GUM_LEARNING_DIRICHLET_H */
191