1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
10 #define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
11 
12 #include <__algorithm/upper_bound.h>
13 #include <__config>
14 #include <__random/is_valid.h>
15 #include <__random/uniform_real_distribution.h>
16 #include <cstddef>
17 #include <iosfwd>
18 #include <numeric>
19 #include <vector>
20 
21 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
22 #  pragma GCC system_header
23 #endif
24 
25 _LIBCPP_PUSH_MACROS
26 #include <__undef_macros>
27 
28 _LIBCPP_BEGIN_NAMESPACE_STD
29 
30 template<class _IntType = int>
31 class _LIBCPP_TEMPLATE_VIS discrete_distribution
32 {
33     static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type");
34 public:
35     // types
36     typedef _IntType result_type;
37 
38     class _LIBCPP_TEMPLATE_VIS param_type
39     {
40         vector<double> __p_;
41     public:
42         typedef discrete_distribution distribution_type;
43 
44         _LIBCPP_INLINE_VISIBILITY
45         param_type() {}
46         template<class _InputIterator>
47             _LIBCPP_INLINE_VISIBILITY
48             param_type(_InputIterator __f, _InputIterator __l)
49             : __p_(__f, __l) {__init();}
50 #ifndef _LIBCPP_CXX03_LANG
51         _LIBCPP_INLINE_VISIBILITY
52         param_type(initializer_list<double> __wl)
53             : __p_(__wl.begin(), __wl.end()) {__init();}
54 #endif // _LIBCPP_CXX03_LANG
55         template<class _UnaryOperation>
56             param_type(size_t __nw, double __xmin, double __xmax,
57                        _UnaryOperation __fw);
58 
59         vector<double> probabilities() const;
60 
61         friend _LIBCPP_INLINE_VISIBILITY
62             bool operator==(const param_type& __x, const param_type& __y)
63             {return __x.__p_ == __y.__p_;}
64         friend _LIBCPP_INLINE_VISIBILITY
65             bool operator!=(const param_type& __x, const param_type& __y)
66             {return !(__x == __y);}
67 
68     private:
69         void __init();
70 
71         friend class discrete_distribution;
72 
73         template <class _CharT, class _Traits, class _IT>
74         friend
75         basic_ostream<_CharT, _Traits>&
76         operator<<(basic_ostream<_CharT, _Traits>& __os,
77                    const discrete_distribution<_IT>& __x);
78 
79         template <class _CharT, class _Traits, class _IT>
80         friend
81         basic_istream<_CharT, _Traits>&
82         operator>>(basic_istream<_CharT, _Traits>& __is,
83                    discrete_distribution<_IT>& __x);
84     };
85 
86 private:
87     param_type __p_;
88 
89 public:
90     // constructor and reset functions
91     _LIBCPP_INLINE_VISIBILITY
92     discrete_distribution() {}
93     template<class _InputIterator>
94         _LIBCPP_INLINE_VISIBILITY
95         discrete_distribution(_InputIterator __f, _InputIterator __l)
96             : __p_(__f, __l) {}
97 #ifndef _LIBCPP_CXX03_LANG
98     _LIBCPP_INLINE_VISIBILITY
99     discrete_distribution(initializer_list<double> __wl)
100         : __p_(__wl) {}
101 #endif // _LIBCPP_CXX03_LANG
102     template<class _UnaryOperation>
103         _LIBCPP_INLINE_VISIBILITY
104         discrete_distribution(size_t __nw, double __xmin, double __xmax,
105                               _UnaryOperation __fw)
106         : __p_(__nw, __xmin, __xmax, __fw) {}
107     _LIBCPP_INLINE_VISIBILITY
108     explicit discrete_distribution(const param_type& __p)
109         : __p_(__p) {}
110     _LIBCPP_INLINE_VISIBILITY
111     void reset() {}
112 
113     // generating functions
114     template<class _URNG>
115         _LIBCPP_INLINE_VISIBILITY
116         result_type operator()(_URNG& __g)
117         {return (*this)(__g, __p_);}
118     template<class _URNG> result_type operator()(_URNG& __g, const param_type& __p);
119 
120     // property functions
121     _LIBCPP_INLINE_VISIBILITY
122     vector<double> probabilities() const {return __p_.probabilities();}
123 
124     _LIBCPP_INLINE_VISIBILITY
125     param_type param() const {return __p_;}
126     _LIBCPP_INLINE_VISIBILITY
127     void param(const param_type& __p) {__p_ = __p;}
128 
129     _LIBCPP_INLINE_VISIBILITY
130     result_type min() const {return 0;}
131     _LIBCPP_INLINE_VISIBILITY
132     result_type max() const {return __p_.__p_.size();}
133 
134     friend _LIBCPP_INLINE_VISIBILITY
135         bool operator==(const discrete_distribution& __x,
136                         const discrete_distribution& __y)
137         {return __x.__p_ == __y.__p_;}
138     friend _LIBCPP_INLINE_VISIBILITY
139         bool operator!=(const discrete_distribution& __x,
140                         const discrete_distribution& __y)
141         {return !(__x == __y);}
142 
143     template <class _CharT, class _Traits, class _IT>
144     friend
145     basic_ostream<_CharT, _Traits>&
146     operator<<(basic_ostream<_CharT, _Traits>& __os,
147                const discrete_distribution<_IT>& __x);
148 
149     template <class _CharT, class _Traits, class _IT>
150     friend
151     basic_istream<_CharT, _Traits>&
152     operator>>(basic_istream<_CharT, _Traits>& __is,
153                discrete_distribution<_IT>& __x);
154 };
155 
156 template<class _IntType>
157 template<class _UnaryOperation>
158 discrete_distribution<_IntType>::param_type::param_type(size_t __nw,
159                                                         double __xmin,
160                                                         double __xmax,
161                                                         _UnaryOperation __fw)
162 {
163     if (__nw > 1)
164     {
165         __p_.reserve(__nw - 1);
166         double __d = (__xmax - __xmin) / __nw;
167         double __d2 = __d / 2;
168         for (size_t __k = 0; __k < __nw; ++__k)
169             __p_.push_back(__fw(__xmin + __k * __d + __d2));
170         __init();
171     }
172 }
173 
174 template<class _IntType>
175 void
176 discrete_distribution<_IntType>::param_type::__init()
177 {
178     if (!__p_.empty())
179     {
180         if (__p_.size() > 1)
181         {
182             double __s = _VSTD::accumulate(__p_.begin(), __p_.end(), 0.0);
183             for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i)
184                 *__i /= __s;
185             vector<double> __t(__p_.size() - 1);
186             _VSTD::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin());
187             swap(__p_, __t);
188         }
189         else
190         {
191             __p_.clear();
192             __p_.shrink_to_fit();
193         }
194     }
195 }
196 
197 template<class _IntType>
198 vector<double>
199 discrete_distribution<_IntType>::param_type::probabilities() const
200 {
201     size_t __n = __p_.size();
202     vector<double> __p(__n+1);
203     _VSTD::adjacent_difference(__p_.begin(), __p_.end(), __p.begin());
204     if (__n > 0)
205         __p[__n] = 1 - __p_[__n-1];
206     else
207         __p[0] = 1;
208     return __p;
209 }
210 
211 template<class _IntType>
212 template<class _URNG>
213 _IntType
214 discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p)
215 {
216     static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "");
217     uniform_real_distribution<double> __gen;
218     return static_cast<_IntType>(
219            _VSTD::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) -
220                                                               __p.__p_.begin());
221 }
222 
223 template <class _CharT, class _Traits, class _IT>
224 basic_ostream<_CharT, _Traits>&
225 operator<<(basic_ostream<_CharT, _Traits>& __os,
226            const discrete_distribution<_IT>& __x)
227 {
228     __save_flags<_CharT, _Traits> __lx(__os);
229     typedef basic_ostream<_CharT, _Traits> _OStream;
230     __os.flags(_OStream::dec | _OStream::left | _OStream::fixed |
231                _OStream::scientific);
232     _CharT __sp = __os.widen(' ');
233     __os.fill(__sp);
234     size_t __n = __x.__p_.__p_.size();
235     __os << __n;
236     for (size_t __i = 0; __i < __n; ++__i)
237         __os << __sp << __x.__p_.__p_[__i];
238     return __os;
239 }
240 
241 template <class _CharT, class _Traits, class _IT>
242 basic_istream<_CharT, _Traits>&
243 operator>>(basic_istream<_CharT, _Traits>& __is,
244            discrete_distribution<_IT>& __x)
245 {
246     __save_flags<_CharT, _Traits> __lx(__is);
247     typedef basic_istream<_CharT, _Traits> _Istream;
248     __is.flags(_Istream::dec | _Istream::skipws);
249     size_t __n;
250     __is >> __n;
251     vector<double> __p(__n);
252     for (size_t __i = 0; __i < __n; ++__i)
253         __is >> __p[__i];
254     if (!__is.fail())
255         swap(__x.__p_.__p_, __p);
256     return __is;
257 }
258 
259 _LIBCPP_END_NAMESPACE_STD
260 
261 _LIBCPP_POP_MACROS
262 
263 #endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H
264