1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 /** 23 * @file 24 * @brief A class for sampling w.r.t. Dirichlet distributions. 25 * 26 * @author Christophe GONZALES(_at_AMU) and Pierre-Henri WUILLEMIN(_at_LIP6) 27 */ 28 #ifndef GUM_LEARNING_DIRICHLET_H 29 #define GUM_LEARNING_DIRICHLET_H 30 31 #include <random> 32 #include <vector> 33 34 #include <agrum/agrum.h> 35 #include <agrum/tools/core/utils_random.h> 36 37 namespace gum { 38 39 // ========================================================================= 40 // === DIRICHLET CLASS === 41 // ========================================================================= 42 43 /** 44 * @class Dirichlet 45 * @headerfile Dirichlet.h <agrum/tools/core/math/Dirichlet.h> 46 * @brief A class for sampling w.r.t. Dirichlet distributions. 47 * @ingroup math_group 48 */ 49 class Dirichlet { 50 public: 51 /// The parameter type. 52 using param_type = std::vector< float >; 53 54 /// The type for the samples generated. 55 using result_type = std::vector< float >; 56 57 // ========================================================================== 58 /// @name Constructors / Destructors 59 // ========================================================================== 60 /// @{ 61 62 /** 63 * @brief Default constructor. 64 * @param params The distribution parameters. 65 * @param seed The distribution seed. 66 */ 67 Dirichlet(const param_type& params, unsigned int seed = GUM_RANDOMSEED); 68 69 /** 70 * @brief Copy constructor. 71 * @param from The distribution to copy. 72 */ 73 Dirichlet(const Dirichlet& from); 74 75 /** 76 * @brief Move constructor. 77 * @param from The distribution to move. 78 */ 79 Dirichlet(Dirichlet&& from); 80 81 /** 82 * @brief Class destructor. 83 */ 84 ~Dirichlet(); 85 86 /// @} 87 // ========================================================================== 88 /// @name Operators 89 // ========================================================================== 90 /// @{ 91 92 /** 93 * @brief Copy operator. 94 * @param from The distribution to copy. 95 * @return Returns this gum::Dirichlet distribution. 96 */ 97 Dirichlet& operator=(const Dirichlet& from); 98 99 /** 100 * @brief Move operator. 101 * @param from The distribution to move. 102 * @return Returns this gum::Dirichlet distribution. 103 */ 104 Dirichlet& operator=(Dirichlet&& from); 105 106 /** 107 * @brief Returns a sample from the Dirichlet distribution. 108 * @return Returns a sample from the Dirichlet distribution. 109 */ 110 result_type operator()(); 111 112 /** 113 * @brief Returns a sample from the Dirichlet distribution. 114 * @param p An object representing the distribution's parameters, 115 * obtained by a call to gum::Dirichlet::param(const param_type&). 116 */ 117 result_type operator()(const param_type& p); 118 119 /** 120 * @brief Returns a sample from the Dirichlet distribution. 121 * 122 * @param generator A uniform random number generator object, used as the 123 * source of randomness. URNG shall be a uniform random number generator 124 * type, such as one of the standard generator classes. 125 * @param p An object representing the distribution's parameters, 126 * obtained by a call to gum::Dirichlet::param(const param_type&). 127 */ 128 template < class URNG > 129 result_type operator()(URNG& generator, const param_type& p); 130 131 /// @} 132 // ========================================================================== 133 /// @name Accessors / Modifiers 134 // ========================================================================== 135 /// @{ 136 137 /** 138 * @brief Returns the parameters of the distribution. 139 * @return Returns the parameters of the distribution. 140 */ 141 const param_type& param() const noexcept; 142 143 /** 144 * @brief Sets the parameters of the distribution. 145 * @param p An object representing the distribution's parameters, obtained 146 * by a call to member function param. 147 */ 148 void param(const param_type& p); 149 150 /** 151 * @brief Returns the greatest lower bound of the range of values returned 152 * by gum::Dirichlet::operator()(). 153 * @return Returns the greatest lower bound of the range of values returned 154 * by gum::Dirichlet::operator()(). 155 */ 156 float min() const noexcept; 157 158 /** 159 * @brief Returns the lowest higher bound of the range of values returned 160 * by gum::Dirichlet::operator()(). 161 * @return Returns the lowest higher bound of the range of values returned 162 * by gum::Dirichlet::operator()(). 163 */ 164 float max() const noexcept; 165 166 /// @} 167 168 private: 169 /// The random engine used by the unform random distribution. 170 std::default_random_engine _generator_; 171 172 /// The gamma distribution used to compute the Dirichlet unnormalized 173 /// samples. 174 std::gamma_distribution< float > _gamma_; 175 176 /// The parameters of the distribution. 177 param_type _params_; 178 }; 179 180 } /* namespace gum */ 181 182 // include the inlined functions if necessary 183 #ifndef GUM_NO_INLINE 184 # include <agrum/tools/core/math/Dirichlet_inl.h> 185 #endif /* GUM_NO_INLINE */ 186 187 // always include templates 188 #include <agrum/tools/core/math/Dirichlet_tpl.h> 189 190 #endif /* GUM_LEARNING_DIRICHLET_H */ 191