1 // Copyright 2018 Developers of the Rand project. 2 // 3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license 5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your 6 // option. This file may not be copied, modified, or distributed 7 // except according to those terms. 8 9 use Rng; 10 use distributions::Distribution; 11 use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; 12 use ::core::cmp::PartialOrd; 13 use core::fmt; 14 15 // Note that this whole module is only imported if feature="alloc" is enabled. 16 #[cfg(not(feature="std"))] use alloc::vec::Vec; 17 18 /// A distribution using weighted sampling to pick a discretely selected 19 /// item. 20 /// 21 /// Sampling a `WeightedIndex` distribution returns the index of a randomly 22 /// selected element from the iterator used when the `WeightedIndex` was 23 /// created. The chance of a given element being picked is proportional to the 24 /// value of the element. The weights can use any type `X` for which an 25 /// implementation of [`Uniform<X>`] exists. 26 /// 27 /// # Performance 28 /// 29 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its 30 /// size is the sum of the size of those objects, possibly plus some alignment. 31 /// 32 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1` 33 /// weights of type `X`, where `N` is the number of weights. However, since 34 /// `Vec` doesn't guarantee a particular growth strategy, additional memory 35 /// might be allocated but not used. Since the `WeightedIndex` object also 36 /// contains, this might cause additional allocations, though for primitive 37 /// types, ['Uniform<X>`] doesn't allocate any memory. 38 /// 39 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where 40 /// `N` is the number of weights. 41 /// 42 /// Sampling from `WeightedIndex` will result in a single call to 43 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically 44 /// will request a single value from the underlying [`RngCore`], though the 45 /// exact number depends on the implementaiton of `Uniform<X>::sample`. 46 /// 47 /// # Example 48 /// 49 /// ``` 50 /// use rand::prelude::*; 51 /// use rand::distributions::WeightedIndex; 52 /// 53 /// let choices = ['a', 'b', 'c']; 54 /// let weights = [2, 1, 1]; 55 /// let dist = WeightedIndex::new(&weights).unwrap(); 56 /// let mut rng = thread_rng(); 57 /// for _ in 0..100 { 58 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' 59 /// println!("{}", choices[dist.sample(&mut rng)]); 60 /// } 61 /// 62 /// let items = [('a', 0), ('b', 3), ('c', 7)]; 63 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); 64 /// for _ in 0..100 { 65 /// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' 66 /// println!("{}", items[dist2.sample(&mut rng)].0); 67 /// } 68 /// ``` 69 /// 70 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform 71 /// [`RngCore`]: rand_core::RngCore 72 #[derive(Debug, Clone)] 73 pub struct WeightedIndex<X: SampleUniform + PartialOrd> { 74 cumulative_weights: Vec<X>, 75 weight_distribution: X::Sampler, 76 } 77 78 impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { 79 /// Creates a new a `WeightedIndex` [`Distribution`] using the values 80 /// in `weights`. The weights can use any type `X` for which an 81 /// implementation of [`Uniform<X>`] exists. 82 /// 83 /// Returns an error if the iterator is empty, if any weight is `< 0`, or 84 /// if its total value is 0. 85 /// 86 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> where I: IntoIterator, I::Item: SampleBorrow<X>, X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default87 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> 88 where I: IntoIterator, 89 I::Item: SampleBorrow<X>, 90 X: for<'a> ::core::ops::AddAssign<&'a X> + 91 Clone + 92 Default { 93 let mut iter = weights.into_iter(); 94 let mut total_weight: X = iter.next() 95 .ok_or(WeightedError::NoItem)? 96 .borrow() 97 .clone(); 98 99 let zero = <X as Default>::default(); 100 if total_weight < zero { 101 return Err(WeightedError::NegativeWeight); 102 } 103 104 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0); 105 for w in iter { 106 if *w.borrow() < zero { 107 return Err(WeightedError::NegativeWeight); 108 } 109 weights.push(total_weight.clone()); 110 total_weight += w.borrow(); 111 } 112 113 if total_weight == zero { 114 return Err(WeightedError::AllWeightsZero); 115 } 116 let distr = X::Sampler::new(zero, total_weight); 117 118 Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) 119 } 120 } 121 122 impl<X> Distribution<usize> for WeightedIndex<X> where 123 X: SampleUniform + PartialOrd { sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize124 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { 125 use ::core::cmp::Ordering; 126 let chosen_weight = self.weight_distribution.sample(rng); 127 // Find the first item which has a weight *higher* than the chosen weight. 128 self.cumulative_weights.binary_search_by( 129 |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err() 130 } 131 } 132 133 #[cfg(test)] 134 mod test { 135 use super::*; 136 137 #[test] test_weightedindex()138 fn test_weightedindex() { 139 let mut r = ::test::rng(700); 140 const N_REPS: u32 = 5000; 141 let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; 142 let total_weight = weights.iter().sum::<u32>() as f32; 143 144 let verify = |result: [i32; 14]| { 145 for (i, count) in result.iter().enumerate() { 146 let exp = (weights[i] * N_REPS) as f32 / total_weight; 147 let mut err = (*count as f32 - exp).abs(); 148 if err != 0.0 { 149 err /= exp; 150 } 151 assert!(err <= 0.25); 152 } 153 }; 154 155 // WeightedIndex from vec 156 let mut chosen = [0i32; 14]; 157 let distr = WeightedIndex::new(weights.to_vec()).unwrap(); 158 for _ in 0..N_REPS { 159 chosen[distr.sample(&mut r)] += 1; 160 } 161 verify(chosen); 162 163 // WeightedIndex from slice 164 chosen = [0i32; 14]; 165 let distr = WeightedIndex::new(&weights[..]).unwrap(); 166 for _ in 0..N_REPS { 167 chosen[distr.sample(&mut r)] += 1; 168 } 169 verify(chosen); 170 171 // WeightedIndex from iterator 172 chosen = [0i32; 14]; 173 let distr = WeightedIndex::new(weights.iter()).unwrap(); 174 for _ in 0..N_REPS { 175 chosen[distr.sample(&mut r)] += 1; 176 } 177 verify(chosen); 178 179 for _ in 0..5 { 180 assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); 181 assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); 182 assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); 183 } 184 185 assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); 186 assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); 187 assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight); 188 assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); 189 assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); 190 } 191 } 192 193 /// Error type returned from `WeightedIndex::new`. 194 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 195 pub enum WeightedError { 196 /// The provided iterator contained no items. 197 NoItem, 198 199 /// A weight lower than zero was used. 200 NegativeWeight, 201 202 /// All items in the provided iterator had a weight of zero. 203 AllWeightsZero, 204 } 205 206 impl WeightedError { msg(&self) -> &str207 fn msg(&self) -> &str { 208 match *self { 209 WeightedError::NoItem => "No items found", 210 WeightedError::NegativeWeight => "Item has negative weight", 211 WeightedError::AllWeightsZero => "All items had weight zero", 212 } 213 } 214 } 215 216 #[cfg(feature="std")] 217 impl ::std::error::Error for WeightedError { description(&self) -> &str218 fn description(&self) -> &str { 219 self.msg() 220 } cause(&self) -> Option<&::std::error::Error>221 fn cause(&self) -> Option<&::std::error::Error> { 222 None 223 } 224 } 225 226 impl fmt::Display for WeightedError { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result227 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 228 write!(f, "{}", self.msg()) 229 } 230 } 231