1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 use Rng;
10 use distributions::Distribution;
11 use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
12 use ::core::cmp::PartialOrd;
13 use core::fmt;
14 
15 // Note that this whole module is only imported if feature="alloc" is enabled.
16 #[cfg(not(feature="std"))] use alloc::vec::Vec;
17 
18 /// A distribution using weighted sampling to pick a discretely selected
19 /// item.
20 ///
21 /// Sampling a `WeightedIndex` distribution returns the index of a randomly
22 /// selected element from the iterator used when the `WeightedIndex` was
23 /// created. The chance of a given element being picked is proportional to the
24 /// value of the element. The weights can use any type `X` for which an
25 /// implementation of [`Uniform<X>`] exists.
26 ///
27 /// # Performance
28 ///
29 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
30 /// size is the sum of the size of those objects, possibly plus some alignment.
31 ///
32 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
33 /// weights of type `X`, where `N` is the number of weights. However, since
34 /// `Vec` doesn't guarantee a particular growth strategy, additional memory
35 /// might be allocated but not used. Since the `WeightedIndex` object also
36 /// contains, this might cause additional allocations, though for primitive
37 /// types, ['Uniform<X>`] doesn't allocate any memory.
38 ///
39 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
40 /// `N` is the number of weights.
41 ///
42 /// Sampling from `WeightedIndex` will result in a single call to
43 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
44 /// will request a single value from the underlying [`RngCore`], though the
45 /// exact number depends on the implementaiton of `Uniform<X>::sample`.
46 ///
47 /// # Example
48 ///
49 /// ```
50 /// use rand::prelude::*;
51 /// use rand::distributions::WeightedIndex;
52 ///
53 /// let choices = ['a', 'b', 'c'];
54 /// let weights = [2,   1,   1];
55 /// let dist = WeightedIndex::new(&weights).unwrap();
56 /// let mut rng = thread_rng();
57 /// for _ in 0..100 {
58 ///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
59 ///     println!("{}", choices[dist.sample(&mut rng)]);
60 /// }
61 ///
62 /// let items = [('a', 0), ('b', 3), ('c', 7)];
63 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
64 /// for _ in 0..100 {
65 ///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
66 ///     println!("{}", items[dist2.sample(&mut rng)].0);
67 /// }
68 /// ```
69 ///
70 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
71 /// [`RngCore`]: rand_core::RngCore
72 #[derive(Debug, Clone)]
73 pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
74     cumulative_weights: Vec<X>,
75     weight_distribution: X::Sampler,
76 }
77 
78 impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
79     /// Creates a new a `WeightedIndex` [`Distribution`] using the values
80     /// in `weights`. The weights can use any type `X` for which an
81     /// implementation of [`Uniform<X>`] exists.
82     ///
83     /// Returns an error if the iterator is empty, if any weight is `< 0`, or
84     /// if its total value is 0.
85     ///
86     /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> where I: IntoIterator, I::Item: SampleBorrow<X>, X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default87     pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
88         where I: IntoIterator,
89               I::Item: SampleBorrow<X>,
90               X: for<'a> ::core::ops::AddAssign<&'a X> +
91                  Clone +
92                  Default {
93         let mut iter = weights.into_iter();
94         let mut total_weight: X = iter.next()
95                                       .ok_or(WeightedError::NoItem)?
96                                       .borrow()
97                                       .clone();
98 
99         let zero = <X as Default>::default();
100         if total_weight < zero {
101             return Err(WeightedError::NegativeWeight);
102         }
103 
104         let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
105         for w in iter {
106             if *w.borrow() < zero {
107                 return Err(WeightedError::NegativeWeight);
108             }
109             weights.push(total_weight.clone());
110             total_weight += w.borrow();
111         }
112 
113         if total_weight == zero {
114             return Err(WeightedError::AllWeightsZero);
115         }
116         let distr = X::Sampler::new(zero, total_weight);
117 
118         Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
119     }
120 }
121 
122 impl<X> Distribution<usize> for WeightedIndex<X> where
123     X: SampleUniform + PartialOrd {
sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize124     fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
125         use ::core::cmp::Ordering;
126         let chosen_weight = self.weight_distribution.sample(rng);
127         // Find the first item which has a weight *higher* than the chosen weight.
128         self.cumulative_weights.binary_search_by(
129             |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
130     }
131 }
132 
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_weightedindex() {
        const N_REPS: u32 = 5000;
        let mut rng = ::test::rng(700);
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Every bucket's observed count must be within 25% (relative) of its
        // expected count. A zero-weight bucket expects no hits, so any hit
        // there yields an infinite relative error and fails the check.
        let verify = |counts: [i32; 14]| {
            for (&weight, &count) in weights.iter().zip(counts.iter()) {
                let expected = (weight * N_REPS) as f32 / total_weight;
                let deviation = (count as f32 - expected).abs();
                let relative = if deviation == 0.0 {
                    deviation
                } else {
                    deviation / expected
                };
                assert!(relative <= 0.25);
            }
        };

        // Built from an owned `Vec`.
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        let mut counts = [0i32; 14];
        for _ in 0..N_REPS {
            counts[distr.sample(&mut rng)] += 1;
        }
        verify(counts);

        // Built from a slice.
        let distr = WeightedIndex::new(&weights[..]).unwrap();
        let mut counts = [0i32; 14];
        for _ in 0..N_REPS {
            counts[distr.sample(&mut rng)] += 1;
        }
        verify(counts);

        // Built from an iterator over references.
        let distr = WeightedIndex::new(weights.iter()).unwrap();
        let mut counts = [0i32; 14];
        for _ in 0..N_REPS {
            counts[distr.sample(&mut rng)] += 1;
        }
        verify(counts);

        // A single non-zero weight must always be the one chosen.
        for _ in 0..5 {
            assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut rng), 1);
            assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut rng), 0);
            assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut rng), 4);
        }

        // Invalid inputs are rejected with the matching error variant.
        assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
        assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
        assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
        assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
        assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
    }
}
192 
/// Reasons why [`WeightedIndex::new`] can refuse the supplied weights.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WeightedError {
    /// The iterator of weights was empty.
    NoItem,

    /// A weight below zero was supplied.
    NegativeWeight,

    /// Every supplied weight was zero.
    AllWeightsZero,
}
205 
206 impl WeightedError {
msg(&self) -> &str207     fn msg(&self) -> &str {
208         match *self {
209             WeightedError::NoItem => "No items found",
210             WeightedError::NegativeWeight => "Item has negative weight",
211             WeightedError::AllWeightsZero => "All items had weight zero",
212         }
213     }
214 }
215 
#[cfg(feature="std")]
impl ::std::error::Error for WeightedError {
    /// Returns the same fixed message as `WeightedError::msg`.
    fn description(&self) -> &str {
        WeightedError::msg(self)
    }

    // `cause` is intentionally left at the trait default, which reports no
    // underlying error; this error never wraps another one.
}
225 
226 impl fmt::Display for WeightedError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result227     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
228         write!(f, "{}", self.msg())
229     }
230 }
231