1 // Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Functions for randomly accessing and sampling sequences.
12
13 use super::Rng;
14
// This module is only enabled when either std or alloc is available.
16 // BTreeMap is not as fast in tests, but better than nothing.
17 #[cfg(feature="std")] use std::collections::HashMap;
18 #[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap;
19
20 #[cfg(not(feature="std"))] use alloc::Vec;
21
22 /// Randomly sample `amount` elements from a finite iterator.
23 ///
24 /// The following can be returned:
25 /// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random.
26 /// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the
27 /// length of `iterable` was less than `amount`. This is considered an error since exactly
28 /// `amount` elements is typically expected.
29 ///
30 /// This implementation uses `O(len(iterable))` time and `O(amount)` memory.
31 ///
32 /// # Example
33 ///
34 /// ```rust
35 /// use rand::{thread_rng, seq};
36 ///
37 /// let mut rng = thread_rng();
38 /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap();
39 /// println!("{:?}", sample);
40 /// ```
sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>> where I: IntoIterator<Item=T>, R: Rng,41 pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
42 where I: IntoIterator<Item=T>,
43 R: Rng,
44 {
45 let mut iter = iterable.into_iter();
46 let mut reservoir = Vec::with_capacity(amount);
47 reservoir.extend(iter.by_ref().take(amount));
48
49 // Continue unless the iterator was exhausted
50 //
51 // note: this prevents iterators that "restart" from causing problems.
52 // If the iterator stops once, then so do we.
53 if reservoir.len() == amount {
54 for (i, elem) in iter.enumerate() {
55 let k = rng.gen_range(0, i + 1 + amount);
56 if let Some(spot) = reservoir.get_mut(k) {
57 *spot = elem;
58 }
59 }
60 Ok(reservoir)
61 } else {
62 // Don't hang onto extra memory. There is a corner case where
63 // `amount` was much less than `len(iterable)`.
64 reservoir.shrink_to_fit();
65 Err(reservoir)
66 }
67 }
68
69 /// Randomly sample exactly `amount` values from `slice`.
70 ///
71 /// The values are non-repeating and in random order.
72 ///
73 /// This implementation uses `O(amount)` time and memory.
74 ///
75 /// Panics if `amount > slice.len()`
76 ///
77 /// # Example
78 ///
79 /// ```rust
80 /// use rand::{thread_rng, seq};
81 ///
82 /// let mut rng = thread_rng();
83 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
84 /// println!("{:?}", seq::sample_slice(&mut rng, &values, 3));
85 /// ```
sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T> where R: Rng, T: Clone86 pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T>
87 where R: Rng,
88 T: Clone
89 {
90 let indices = sample_indices(rng, slice.len(), amount);
91
92 let mut out = Vec::with_capacity(amount);
93 out.extend(indices.iter().map(|i| slice[*i].clone()));
94 out
95 }
96
97 /// Randomly sample exactly `amount` references from `slice`.
98 ///
99 /// The references are non-repeating and in random order.
100 ///
101 /// This implementation uses `O(amount)` time and memory.
102 ///
103 /// Panics if `amount > slice.len()`
104 ///
105 /// # Example
106 ///
107 /// ```rust
108 /// use rand::{thread_rng, seq};
109 ///
110 /// let mut rng = thread_rng();
111 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
112 /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3));
113 /// ```
sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng114 pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T>
115 where R: Rng
116 {
117 let indices = sample_indices(rng, slice.len(), amount);
118
119 let mut out = Vec::with_capacity(amount);
120 out.extend(indices.iter().map(|i| &slice[*i]));
121 out
122 }
123
124 /// Randomly sample exactly `amount` indices from `0..length`.
125 ///
126 /// The values are non-repeating and in random order.
127 ///
128 /// This implementation uses `O(amount)` time and memory.
129 ///
130 /// This method is used internally by the slice sampling methods, but it can sometimes be useful to
131 /// have the indices themselves so this is provided as an alternative.
132 ///
133 /// Panics if `amount > length`
sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng,134 pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
135 where R: Rng,
136 {
137 if amount > length {
138 panic!("`amount` must be less than or equal to `slice.len()`");
139 }
140
141 // We are going to have to allocate at least `amount` for the output no matter what. However,
142 // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since
143 // it inserts an element for every loop.
144 //
145 // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory.
146 // In fact, benchmarks show the inplace version is faster for length up to about 20 times
147 // faster than amount.
148 //
149 // TODO: there is probably even more fine-tuning that can be done here since
150 // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice,
151 // and a trade off could probably be made between memory/cpu, since hashmap operations
152 // are slower than array index swapping.
153 if amount >= length / 20 {
154 sample_indices_inplace(rng, length, amount)
155 } else {
156 sample_indices_cache(rng, length, amount)
157 }
158 }
159
160 /// Sample an amount of indices using an inplace partial fisher yates method.
161 ///
162 /// This allocates the entire `length` of indices and randomizes only the first `amount`.
163 /// It then truncates to `amount` and returns.
164 ///
165 /// This is better than using a HashMap "cache" when `amount >= length / 2` since it does not
166 /// require allocating an extra cache and is much faster.
sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng,167 fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
168 where R: Rng,
169 {
170 debug_assert!(amount <= length);
171 let mut indices: Vec<usize> = Vec::with_capacity(length);
172 indices.extend(0..length);
173 for i in 0..amount {
174 let j: usize = rng.gen_range(i, length);
175 let tmp = indices[i];
176 indices[i] = indices[j];
177 indices[j] = tmp;
178 }
179 indices.truncate(amount);
180 debug_assert_eq!(indices.len(), amount);
181 indices
182 }
183
184
185 /// This method performs a partial fisher-yates on a range of indices using a HashMap
186 /// as a cache to record potential collisions.
187 ///
188 /// The cache avoids allocating the entire `length` of values. This is especially useful when
189 /// `amount <<< length`, i.e. select 3 non-repeating from 1_000_000
sample_indices_cache<R>( rng: &mut R, length: usize, amount: usize, ) -> Vec<usize> where R: Rng,190 fn sample_indices_cache<R>(
191 rng: &mut R,
192 length: usize,
193 amount: usize,
194 ) -> Vec<usize>
195 where R: Rng,
196 {
197 debug_assert!(amount <= length);
198 #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount);
199 #[cfg(not(feature="std"))] let mut cache = BTreeMap::new();
200 let mut out = Vec::with_capacity(amount);
201 for i in 0..amount {
202 let j: usize = rng.gen_range(i, length);
203
204 // equiv: let tmp = slice[i];
205 let tmp = match cache.get(&i) {
206 Some(e) => *e,
207 None => i,
208 };
209
210 // equiv: slice[i] = slice[j];
211 let x = match cache.get(&j) {
212 Some(x) => *x,
213 None => j,
214 };
215
216 // equiv: slice[j] = tmp;
217 cache.insert(j, tmp);
218
219 // note that in the inplace version, slice[i] is automatically "returned" value
220 out.push(x);
221 }
222 debug_assert_eq!(out.len(), amount);
223 out
224 }
225
#[cfg(test)]
mod test {
    use super::*;
    use {thread_rng, XorShiftRng, SeedableRng};

    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = thread_rng();
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` is the half-open range `min_val..max_val`, so every element is below `max_val`.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }

    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = thread_rng();

        // sample 0 items
        assert_eq!(sample_slice(&mut r, empty, 0), vec![]);
        assert_eq!(sample_slice(&mut r, &[42, 2, 42], 0), vec![]);

        // sample 1 item
        assert_eq!(sample_slice(&mut r, &[42], 1), vec![42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(v == vec![42, 133] || v == vec![133, 42]);

        assert_eq!(sample_indices_inplace(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 1), vec![0]);

        assert_eq!(sample_indices_cache(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 1), vec![0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        let ratio_42 = num_42 as f64 / total as f64;
        // Each value should be chosen about half the time. With 1000 trials the standard
        // deviation of the ratio is ~0.016, so [0.4, 0.6] is a very generous window.
        // Both bounds must hold, hence `&&`; `||` would make this check always pass.
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = thread_rng();

        for length in 1usize..max_range {
            let amount = r.gen_range(0, length);
            let seed: [u32; 4] = [
                r.next_u32(), r.next_u32(), r.next_u32(), r.next_u32()
            ];

            println!("Selecting indices: len={}, amount={}, seed={:?}", length, amount, seed);

            // assert that the two index methods give exactly the same result
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            // assert the basics work
            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}
338