1 // Copyright 2018 Developers of the Rand project.
2 // Copyright 2013 The Rust Project Developers.
3 //
4 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7 // option. This file may not be copied, modified, or distributed
8 // except according to those terms.
9 
//! The Dirichlet distribution.
11 
12 use Rng;
13 use distributions::Distribution;
14 use distributions::gamma::Gamma;
15 
/// The Dirichlet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution is a family of continuous multivariate
/// probability distributions parameterized by a vector `alpha` of positive
/// reals. It is a multivariate generalization of the beta distribution.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::Dirichlet;
///
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]);
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[derive(Clone, Debug)]
pub struct Dirichlet {
    /// Concentration parameters (alpha). The constructors guarantee at least
    /// two entries, all strictly positive.
    alpha: Vec<f64>,
}

impl Dirichlet {
    /// Construct a new `Dirichlet` with the given concentration parameters `alpha`.
    ///
    /// # Panics
    /// - if `alpha.len() < 2`
    /// - if any element of `alpha` is not strictly positive (this includes NaN)
    #[inline]
    pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet {
        let a = alpha.into();
        assert!(
            a.len() > 1,
            "Dirichlet::new: at least two alpha parameters are required"
        );
        // `x > 0.0` is false for NaN, so NaN parameters are rejected as well.
        assert!(
            a.iter().all(|&x| x > 0.0),
            "Dirichlet::new: all alpha parameters must be positive"
        );

        Dirichlet { alpha: a }
    }

    /// Construct a symmetric `Dirichlet` with `size` concentration parameters,
    /// each equal to `alpha`.
    ///
    /// # Panics
    /// - if `alpha` is not strictly positive (this includes NaN)
    /// - if `size < 2`
    #[inline]
    pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
        assert!(
            alpha > 0.0,
            "Dirichlet::new_with_param: alpha must be positive"
        );
        assert!(
            size > 1,
            "Dirichlet::new_with_param: size must be at least 2"
        );
        Dirichlet {
            alpha: vec![alpha; size],
        }
    }
}
71 
72 impl Distribution<Vec<f64>> for Dirichlet {
sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64>73     fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
74         let n = self.alpha.len();
75         let mut samples = vec![0.0f64; n];
76         let mut sum = 0.0f64;
77 
78         for i in 0..n {
79             let g = Gamma::new(self.alpha[i], 1.0);
80             samples[i] = g.sample(rng);
81             sum += samples[i];
82         }
83         let invacc = 1.0 / sum;
84         for i in 0..n {
85             samples[i] *= invacc;
86         }
87         samples
88     }
89 }
90 
#[cfg(test)]
mod test {
    use super::Dirichlet;
    use distributions::Distribution;

    /// Assert that `samples` is a valid point on the probability simplex.
    ///
    /// It must have exactly `len` components, each strictly positive, summing
    /// to 1 up to floating-point rounding.
    fn check_simplex(samples: &[f64], len: usize) {
        assert_eq!(samples.len(), len);
        for &x in samples {
            assert!(x > 0.0, "component {} is not positive", x);
        }
        let sum: f64 = samples.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "samples sum to {}, expected 1", sum);
    }

    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new(vec![1.0, 2.0, 3.0]);
        let mut rng = ::test::rng(221);
        let samples = d.sample(&mut rng);
        check_simplex(&samples, 3);
    }

    #[test]
    fn test_dirichlet_with_param() {
        let alpha = 0.5f64;
        let size = 2;
        let d = Dirichlet::new_with_param(alpha, size);
        let mut rng = ::test::rng(221);
        let samples = d.sample(&mut rng);
        check_simplex(&samples, size);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_invalid_length() {
        Dirichlet::new_with_param(0.5f64, 1);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_invalid_alpha() {
        Dirichlet::new_with_param(0.0f64, 2);
    }

    // `Dirichlet::new` must enforce the same preconditions as `new_with_param`.
    #[test]
    #[should_panic]
    fn test_dirichlet_new_invalid_length() {
        Dirichlet::new(vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_new_invalid_alpha() {
        Dirichlet::new(vec![1.0, 0.0]);
    }
}
138