1 // Copyright 2018 Developers of the Rand project. 2 // Copyright 2013 The Rust Project Developers. 3 // 4 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 5 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license 6 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your 7 // option. This file may not be copied, modified, or distributed 8 // except according to those terms. 9 10 //! The dirichlet distribution. 11 12 use Rng; 13 use distributions::Distribution; 14 use distributions::gamma::Gamma; 15 16 /// The dirichelet distribution `Dirichlet(alpha)`. 17 /// 18 /// The Dirichlet distribution is a family of continuous multivariate 19 /// probability distributions parameterized by a vector alpha of positive reals. 20 /// It is a multivariate generalization of the beta distribution. 21 /// 22 /// # Example 23 /// 24 /// ``` 25 /// use rand::prelude::*; 26 /// use rand::distributions::Dirichlet; 27 /// 28 /// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]); 29 /// let samples = dirichlet.sample(&mut rand::thread_rng()); 30 /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); 31 /// ``` 32 33 #[derive(Clone, Debug)] 34 pub struct Dirichlet { 35 /// Concentration parameters (alpha) 36 alpha: Vec<f64>, 37 } 38 39 impl Dirichlet { 40 /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. 41 /// 42 /// # Panics 43 /// - if `alpha.len() < 2` 44 /// 45 #[inline] new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet46 pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet { 47 let a = alpha.into(); 48 assert!(a.len() > 1); 49 for i in 0..a.len() { 50 assert!(a[i] > 0.0); 51 } 52 53 Dirichlet { alpha: a } 54 } 55 56 /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. 57 /// 58 /// # Panics 59 /// - if `alpha <= 0.0` 60 /// - if `size < 2` 61 /// 62 #[inline] new_with_param(alpha: f64, size: usize) -> Dirichlet63 pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { 64 assert!(alpha > 0.0); 65 assert!(size > 1); 66 Dirichlet { 67 alpha: vec![alpha; size], 68 } 69 } 70 } 71 72 impl Distribution<Vec<f64>> for Dirichlet { sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64>73 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> { 74 let n = self.alpha.len(); 75 let mut samples = vec![0.0f64; n]; 76 let mut sum = 0.0f64; 77 78 for i in 0..n { 79 let g = Gamma::new(self.alpha[i], 1.0); 80 samples[i] = g.sample(rng); 81 sum += samples[i]; 82 } 83 let invacc = 1.0 / sum; 84 for i in 0..n { 85 samples[i] *= invacc; 86 } 87 samples 88 } 89 } 90 91 #[cfg(test)] 92 mod test { 93 use super::Dirichlet; 94 use distributions::Distribution; 95 96 #[test] test_dirichlet()97 fn test_dirichlet() { 98 let d = Dirichlet::new(vec![1.0, 2.0, 3.0]); 99 let mut rng = ::test::rng(221); 100 let samples = d.sample(&mut rng); 101 let _: Vec<f64> = samples 102 .into_iter() 103 .map(|x| { 104 assert!(x > 0.0); 105 x 106 }) 107 .collect(); 108 } 109 110 #[test] test_dirichlet_with_param()111 fn test_dirichlet_with_param() { 112 let alpha = 0.5f64; 113 let size = 2; 114 let d = Dirichlet::new_with_param(alpha, size); 115 let mut rng = ::test::rng(221); 116 let samples = d.sample(&mut rng); 117 let _: Vec<f64> = samples 118 .into_iter() 119 .map(|x| { 120 assert!(x > 0.0); 121 x 122 }) 123 .collect(); 124 } 125 126 #[test] 127 #[should_panic] test_dirichlet_invalid_length()128 fn test_dirichlet_invalid_length() { 129 Dirichlet::new_with_param(0.5f64, 1); 130 } 131 132 #[test] 133 #[should_panic] test_dirichlet_invalid_alpha()134 fn test_dirichlet_invalid_alpha() { 135 Dirichlet::new_with_param(0.0f64, 2); 136 } 137 } 138