1 // Copyright 2018 Developers of the Rand project. 2 // 3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license 5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your 6 // option. This file may not be copied, modified, or distributed 7 // except according to those terms. 8 9 //! The Bernoulli distribution. 10 11 use Rng; 12 use distributions::Distribution; 13 14 /// The Bernoulli distribution. 15 /// 16 /// This is a special case of the Binomial distribution where `n = 1`. 17 /// 18 /// # Example 19 /// 20 /// ```rust 21 /// use rand::distributions::{Bernoulli, Distribution}; 22 /// 23 /// let d = Bernoulli::new(0.3); 24 /// let v = d.sample(&mut rand::thread_rng()); 25 /// println!("{} is from a Bernoulli distribution", v); 26 /// ``` 27 /// 28 /// # Precision 29 /// 30 /// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`), 31 /// so only probabilities that are multiples of 2<sup>-64</sup> can be 32 /// represented. 33 #[derive(Clone, Copy, Debug)] 34 pub struct Bernoulli { 35 /// Probability of success, relative to the maximal integer. 36 p_int: u64, 37 } 38 39 // To sample from the Bernoulli distribution we use a method that compares a 40 // random `u64` value `v < (p * 2^64)`. 41 // 42 // If `p == 1.0`, the integer `v` to compare against can not represented as a 43 // `u64`. We manually set it to `u64::MAX` instead (2^64 - 1 instead of 2^64). 44 // Note that value of `p < 1.0` can never result in `u64::MAX`, because an 45 // `f64` only has 53 bits of precision, and the next largest value of `p` will 46 // result in `2^64 - 2048`. 47 // 48 // Also there is a 100% theoretical concern: if someone consistenly wants to 49 // generate `true` using the Bernoulli distribution (i.e. by using a probability 50 // of `1.0`), just using `u64::MAX` is not enough. On average it would return 51 // false once every 2^64 iterations. Some people apparently care about this 52 // case. 53 // 54 // That is why we special-case `u64::MAX` to always return `true`, without using 55 // the RNG, and pay the performance price for all uses that *are* reasonable. 56 // Luckily, if `new()` and `sample` are close, the compiler can optimize out the 57 // extra check. 58 const ALWAYS_TRUE: u64 = ::core::u64::MAX; 59 60 // This is just `2.0.powi(64)`, but written this way because it is not available 61 // in `no_std` mode. 62 const SCALE: f64 = 2.0 * (1u64 << 63) as f64; 63 64 impl Bernoulli { 65 /// Construct a new `Bernoulli` with the given probability of success `p`. 66 /// 67 /// # Panics 68 /// 69 /// If `p < 0` or `p > 1`. 70 /// 71 /// # Precision 72 /// 73 /// For `p = 1.0`, the resulting distribution will always generate true. 74 /// For `p = 0.0`, the resulting distribution will always generate false. 75 /// 76 /// This method is accurate for any input `p` in the range `[0, 1]` which is 77 /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of 78 /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.) 79 #[inline] new(p: f64) -> Bernoulli80 pub fn new(p: f64) -> Bernoulli { 81 if p < 0.0 || p >= 1.0 { 82 if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } } 83 panic!("Bernoulli::new not called with 0.0 <= p <= 1.0"); 84 } 85 Bernoulli { p_int: (p * SCALE) as u64 } 86 } 87 88 /// Construct a new `Bernoulli` with the probability of success of 89 /// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return 90 /// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`. 91 /// 92 /// If `numerator == denominator` then the returned `Bernoulli` will always 93 /// return `true`. If `numerator == 0` it will always return `false`. 94 /// 95 /// # Panics 96 /// 97 /// If `denominator == 0` or `numerator > denominator`. 98 /// 99 #[inline] from_ratio(numerator: u32, denominator: u32) -> Bernoulli100 pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli { 101 assert!(numerator <= denominator); 102 if numerator == denominator { 103 return Bernoulli { p_int: ::core::u64::MAX } 104 } 105 let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64; 106 Bernoulli { p_int } 107 } 108 } 109 110 impl Distribution<bool> for Bernoulli { 111 #[inline] sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool112 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool { 113 // Make sure to always return true for p = 1.0. 114 if self.p_int == ALWAYS_TRUE { return true; } 115 let v: u64 = rng.gen(); 116 v < self.p_int 117 } 118 } 119 120 #[cfg(test)] 121 mod test { 122 use Rng; 123 use distributions::Distribution; 124 use super::Bernoulli; 125 126 #[test] test_trivial()127 fn test_trivial() { 128 let mut r = ::test::rng(1); 129 let always_false = Bernoulli::new(0.0); 130 let always_true = Bernoulli::new(1.0); 131 for _ in 0..5 { 132 assert_eq!(r.sample::<bool, _>(&always_false), false); 133 assert_eq!(r.sample::<bool, _>(&always_true), true); 134 assert_eq!(Distribution::<bool>::sample(&always_false, &mut r), false); 135 assert_eq!(Distribution::<bool>::sample(&always_true, &mut r), true); 136 } 137 } 138 139 #[test] test_average()140 fn test_average() { 141 const P: f64 = 0.3; 142 const NUM: u32 = 3; 143 const DENOM: u32 = 10; 144 let d1 = Bernoulli::new(P); 145 let d2 = Bernoulli::from_ratio(NUM, DENOM); 146 const N: u32 = 100_000; 147 148 let mut sum1: u32 = 0; 149 let mut sum2: u32 = 0; 150 let mut rng = ::test::rng(2); 151 for _ in 0..N { 152 if d1.sample(&mut rng) { 153 sum1 += 1; 154 } 155 if d2.sample(&mut rng) { 156 sum2 += 1; 157 } 158 } 159 let avg1 = (sum1 as f64) / (N as f64); 160 assert!((avg1 - P).abs() < 5e-3); 161 162 let avg2 = (sum2 as f64) / (N as f64); 163 assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3); 164 } 165 } 166