1 // Copyright 2018 Developers of the Rand project.
2 // Copyright 2016-2017 The Rust Project Developers.
3 //
4 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7 // option. This file may not be copied, modified, or distributed
8 // except according to those terms.
9 
10 //! The binomial distribution.
11 #![allow(deprecated)]
12 #![allow(clippy::all)]
13 
14 use crate::distributions::{Distribution, Uniform};
15 use crate::Rng;
16 
/// The binomial distribution `Binomial(n, p)`.
///
/// This is the distribution of the number of successes in `n` independent
/// trials that each succeed with probability `p`. Its probability mass
/// function is:
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `0 <= k <= n`.
#[deprecated(since = "0.7.0", note = "moved to rand_distr crate")]
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    /// Number of trials.
    n: u64,
    /// Probability of success on each trial (`Binomial::new` enforces
    /// `0 <= p <= 1`).
    p: f64,
}
29 
30 impl Binomial {
31     /// Construct a new `Binomial` with the given shape parameters `n` (number
32     /// of trials) and `p` (probability of success).
33     ///
34     /// Panics if `p < 0` or `p > 1`.
new(n: u64, p: f64) -> Binomial35     pub fn new(n: u64, p: f64) -> Binomial {
36         assert!(p >= 0.0, "Binomial::new called with p < 0");
37         assert!(p <= 1.0, "Binomial::new called with p > 1");
38         Binomial { n, p }
39     }
40 }
41 
/// Convert an `f64` to an `i64`, truncating toward zero.
///
/// # Panics
///
/// Panics if `x` is NaN or `x >= 2^63`, i.e. when it would overflow `i64`
/// on the positive side.
///
/// Values at or below `i64::MIN` (including `-inf`) are *not* rejected: the
/// `as` cast saturates them to `i64::MIN`, so negative overflow yields a
/// negative result instead of a panic.
// std has no `TryFrom<f64>` impl for `i64`, so the range check stays manual.
fn f64_to_i64(x: f64) -> i64 {
    // `i64::MAX as f64` rounds up to exactly 2^63, so `x < 2^63` guarantees
    // the truncated value fits. NaN fails the comparison and panics too.
    assert!(
        x < (::std::i64::MAX as f64),
        "f64_to_i64: {} is NaN or too large for i64",
        x
    );
    x as i64
}
48 
49 impl Distribution<u64> for Binomial {
sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u6450     fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
51         // Handle these values directly.
52         if self.p == 0.0 {
53             return 0;
54         } else if self.p == 1.0 {
55             return self.n;
56         }
57 
58         // The binomial distribution is symmetrical with respect to p -> 1-p,
59         // k -> n-k switch p so that it is less than 0.5 - this allows for lower
60         // expected values we will just invert the result at the end
61         let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p };
62 
63         let result;
64         let q = 1. - p;
65 
66         // For small n * min(p, 1 - p), the BINV algorithm based on the inverse
67         // transformation of the binomial distribution is efficient. Otherwise,
68         // the BTPE algorithm is used.
69         //
70         // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
71         // random variate generation. Commun. ACM 31, 2 (February 1988),
72         // 216-222. http://dx.doi.org/10.1145/42372.42381
73 
74         // Threshold for prefering the BINV algorithm. The paper suggests 10,
75         // Ranlib uses 30, and GSL uses 14.
76         const BINV_THRESHOLD: f64 = 10.;
77 
78         if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) {
79             // Use the BINV algorithm.
80             let s = p / q;
81             let a = ((self.n + 1) as f64) * s;
82             let mut r = q.powi(self.n as i32);
83             let mut u: f64 = rng.gen();
84             let mut x = 0;
85             while u > r as f64 {
86                 u -= r;
87                 x += 1;
88                 r *= a / (x as f64) - s;
89             }
90             result = x;
91         } else {
92             // Use the BTPE algorithm.
93 
94             // Threshold for using the squeeze algorithm. This can be freely
95             // chosen based on performance. Ranlib and GSL use 20.
96             const SQUEEZE_THRESHOLD: i64 = 20;
97 
98             // Step 0: Calculate constants as functions of `n` and `p`.
99             let n = self.n as f64;
100             let np = n * p;
101             let npq = np * q;
102             let f_m = np + p;
103             let m = f64_to_i64(f_m);
104             // radius of triangle region, since height=1 also area of region
105             let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
106             // tip of triangle
107             let x_m = (m as f64) + 0.5;
108             // left edge of triangle
109             let x_l = x_m - p1;
110             // right edge of triangle
111             let x_r = x_m + p1;
112             let c = 0.134 + 20.5 / (15.3 + (m as f64));
113             // p1 + area of parallelogram region
114             let p2 = p1 * (1. + 2. * c);
115 
116             fn lambda(a: f64) -> f64 {
117                 a * (1. + 0.5 * a)
118             }
119 
120             let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
121             let lambda_r = lambda((x_r - f_m) / (x_r * q));
122             // p1 + area of left tail
123             let p3 = p2 + c / lambda_l;
124             // p1 + area of right tail
125             let p4 = p3 + c / lambda_r;
126 
127             // return value
128             let mut y: i64;
129 
130             let gen_u = Uniform::new(0., p4);
131             let gen_v = Uniform::new(0., 1.);
132 
133             loop {
134                 // Step 1: Generate `u` for selecting the region. If region 1 is
135                 // selected, generate a triangularly distributed variate.
136                 let u = gen_u.sample(rng);
137                 let mut v = gen_v.sample(rng);
138                 if !(u > p1) {
139                     y = f64_to_i64(x_m - p1 * v + u);
140                     break;
141                 }
142 
143                 if !(u > p2) {
144                     // Step 2: Region 2, parallelograms. Check if region 2 is
145                     // used. If so, generate `y`.
146                     let x = x_l + (u - p1) / c;
147                     v = v * c + 1.0 - (x - x_m).abs() / p1;
148                     if v > 1. {
149                         continue;
150                     } else {
151                         y = f64_to_i64(x);
152                     }
153                 } else if !(u > p3) {
154                     // Step 3: Region 3, left exponential tail.
155                     y = f64_to_i64(x_l + v.ln() / lambda_l);
156                     if y < 0 {
157                         continue;
158                     } else {
159                         v *= (u - p2) * lambda_l;
160                     }
161                 } else {
162                     // Step 4: Region 4, right exponential tail.
163                     y = f64_to_i64(x_r - v.ln() / lambda_r);
164                     if y > 0 && (y as u64) > self.n {
165                         continue;
166                     } else {
167                         v *= (u - p3) * lambda_r;
168                     }
169                 }
170 
171                 // Step 5: Acceptance/rejection comparison.
172 
173                 // Step 5.0: Test for appropriate method of evaluating f(y).
174                 let k = (y - m).abs();
175                 if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
176                     // Step 5.1: Evaluate f(y) via the recursive relationship. Start the
177                     // search from the mode.
178                     let s = p / q;
179                     let a = s * (n + 1.);
180                     let mut f = 1.0;
181                     if m < y {
182                         let mut i = m;
183                         loop {
184                             i += 1;
185                             f *= a / (i as f64) - s;
186                             if i == y {
187                                 break;
188                             }
189                         }
190                     } else if m > y {
191                         let mut i = y;
192                         loop {
193                             i += 1;
194                             f /= a / (i as f64) - s;
195                             if i == m {
196                                 break;
197                             }
198                         }
199                     }
200                     if v > f {
201                         continue;
202                     } else {
203                         break;
204                     }
205                 }
206 
207                 // Step 5.2: Squeezing. Check the value of ln(v) againts upper and
208                 // lower bound of ln(f(y)).
209                 let k = k as f64;
210                 let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
211                 let t = -0.5 * k * k / npq;
212                 let alpha = v.ln();
213                 if alpha < t - rho {
214                     break;
215                 }
216                 if alpha > t + rho {
217                     continue;
218                 }
219 
220                 // Step 5.3: Final acceptance/rejection test.
221                 let x1 = (y + 1) as f64;
222                 let f1 = (m + 1) as f64;
223                 let z = (f64_to_i64(n) + 1 - m) as f64;
224                 let w = (f64_to_i64(n) - y + 1) as f64;
225 
226                 fn stirling(a: f64) -> f64 {
227                     let a2 = a * a;
228                     (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
229                 }
230 
231                 if alpha
232                         > x_m * (f1 / x1).ln()
233                         + (n - (m as f64) + 0.5) * (z / w).ln()
234                         + ((y - m) as f64) * (w * p / (x1 * q)).ln()
235                         // We use the signs from the GSL implementation, which are
236                         // different than the ones in the reference. According to
237                         // the GSL authors, the new signs were verified to be
238                         // correct by one of the original designers of the
239                         // algorithm.
240                         + stirling(f1)
241                         + stirling(z)
242                         - stirling(x1)
243                         - stirling(w)
244                 {
245                     continue;
246                 }
247 
248                 break;
249             }
250             assert!(y >= 0);
251             result = y as u64;
252         }
253 
254         // Invert the result for p < 0.5.
255         if p != self.p {
256             self.n - result
257         } else {
258             result
259         }
260     }
261 }
262 
#[cfg(test)]
mod test {
    use super::Binomial;
    use crate::distributions::Distribution;
    use crate::Rng;

    /// Draws 1000 samples from `Binomial(n, p)` and checks that the sample
    /// mean and variance are close to `n*p` and `n*p*(1-p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p);

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = binomial.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!(
            (mean - expected_mean).abs() < expected_mean / 50.0,
            "mean: {}, expected_mean: {}",
            mean,
            expected_mean
        );

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!(
            (variance - expected_variance).abs() < expected_variance / 10.0,
            "variance: {}, expected_variance: {}",
            variance,
            expected_variance
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0)), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0)), 20);
    }

    /// Every sample must lie in `0..=n`. This covers `n = 0`, the BINV path
    /// (including the mirrored `p > 0.5` case) and the BTPE path.
    #[test]
    fn test_binomial_stays_in_range() {
        let mut rng = crate::test::rng(353);
        for &(n, p) in &[(0, 0.3), (1, 0.5), (5, 0.9), (20, 0.3), (150, 0.1), (70, 0.6)] {
            let binomial = Binomial::new(n, p);
            for _ in 0..100 {
                assert!(binomial.sample(&mut rng) <= n);
            }
        }
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_p_neg() {
        Binomial::new(20, -10.0);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_p_gt_one() {
        Binomial::new(20, 2.0);
    }

    #[test]
    #[should_panic]
    fn test_binomial_invalid_p_nan() {
        Binomial::new(20, ::std::f64::NAN);
    }
}
322