1 // Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10 
11 //! Functions for randomly accessing and sampling sequences.
12 
13 use super::Rng;
14 
// This module is only enabled when either std or alloc is available.
// Without std there is no HashMap, so fall back to BTreeMap: it is slower in tests,
// but better than nothing.
17 #[cfg(feature="std")] use std::collections::HashMap;
18 #[cfg(not(feature="std"))] use alloc::btree_map::BTreeMap;
19 
20 #[cfg(not(feature="std"))] use alloc::Vec;
21 
22 /// Randomly sample `amount` elements from a finite iterator.
23 ///
24 /// The following can be returned:
25 /// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random.
26 /// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the
27 ///   length of `iterable` was less than `amount`. This is considered an error since exactly
28 ///   `amount` elements is typically expected.
29 ///
30 /// This implementation uses `O(len(iterable))` time and `O(amount)` memory.
31 ///
32 /// # Example
33 ///
34 /// ```rust
35 /// use rand::{thread_rng, seq};
36 ///
37 /// let mut rng = thread_rng();
38 /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap();
39 /// println!("{:?}", sample);
40 /// ```
sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>> where I: IntoIterator<Item=T>, R: Rng,41 pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
42     where I: IntoIterator<Item=T>,
43           R: Rng,
44 {
45     let mut iter = iterable.into_iter();
46     let mut reservoir = Vec::with_capacity(amount);
47     reservoir.extend(iter.by_ref().take(amount));
48 
49     // Continue unless the iterator was exhausted
50     //
51     // note: this prevents iterators that "restart" from causing problems.
52     // If the iterator stops once, then so do we.
53     if reservoir.len() == amount {
54         for (i, elem) in iter.enumerate() {
55             let k = rng.gen_range(0, i + 1 + amount);
56             if let Some(spot) = reservoir.get_mut(k) {
57                 *spot = elem;
58             }
59         }
60         Ok(reservoir)
61     } else {
62         // Don't hang onto extra memory. There is a corner case where
63         // `amount` was much less than `len(iterable)`.
64         reservoir.shrink_to_fit();
65         Err(reservoir)
66     }
67 }
68 
69 /// Randomly sample exactly `amount` values from `slice`.
70 ///
71 /// The values are non-repeating and in random order.
72 ///
73 /// This implementation uses `O(amount)` time and memory.
74 ///
75 /// Panics if `amount > slice.len()`
76 ///
77 /// # Example
78 ///
79 /// ```rust
80 /// use rand::{thread_rng, seq};
81 ///
82 /// let mut rng = thread_rng();
83 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
84 /// println!("{:?}", seq::sample_slice(&mut rng, &values, 3));
85 /// ```
sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T> where R: Rng, T: Clone86 pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T>
87     where R: Rng,
88           T: Clone
89 {
90     let indices = sample_indices(rng, slice.len(), amount);
91 
92     let mut out = Vec::with_capacity(amount);
93     out.extend(indices.iter().map(|i| slice[*i].clone()));
94     out
95 }
96 
97 /// Randomly sample exactly `amount` references from `slice`.
98 ///
99 /// The references are non-repeating and in random order.
100 ///
101 /// This implementation uses `O(amount)` time and memory.
102 ///
103 /// Panics if `amount > slice.len()`
104 ///
105 /// # Example
106 ///
107 /// ```rust
108 /// use rand::{thread_rng, seq};
109 ///
110 /// let mut rng = thread_rng();
111 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
112 /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3));
113 /// ```
sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng114 pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T>
115     where R: Rng
116 {
117     let indices = sample_indices(rng, slice.len(), amount);
118 
119     let mut out = Vec::with_capacity(amount);
120     out.extend(indices.iter().map(|i| &slice[*i]));
121     out
122 }
123 
124 /// Randomly sample exactly `amount` indices from `0..length`.
125 ///
126 /// The values are non-repeating and in random order.
127 ///
128 /// This implementation uses `O(amount)` time and memory.
129 ///
130 /// This method is used internally by the slice sampling methods, but it can sometimes be useful to
131 /// have the indices themselves so this is provided as an alternative.
132 ///
133 /// Panics if `amount > length`
sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng,134 pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
135     where R: Rng,
136 {
137     if amount > length {
138         panic!("`amount` must be less than or equal to `slice.len()`");
139     }
140 
141     // We are going to have to allocate at least `amount` for the output no matter what. However,
142     // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since
143     // it inserts an element for every loop.
144     //
145     // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory.
146     // In fact, benchmarks show the inplace version is faster for length up to about 20 times
147     // faster than amount.
148     //
149     // TODO: there is probably even more fine-tuning that can be done here since
150     // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice,
151     // and a trade off could probably be made between memory/cpu, since hashmap operations
152     // are slower than array index swapping.
153     if amount >= length / 20 {
154         sample_indices_inplace(rng, length, amount)
155     } else {
156         sample_indices_cache(rng, length, amount)
157     }
158 }
159 
160 /// Sample an amount of indices using an inplace partial fisher yates method.
161 ///
162 /// This allocates the entire `length` of indices and randomizes only the first `amount`.
163 /// It then truncates to `amount` and returns.
164 ///
165 /// This is better than using a HashMap "cache" when `amount >= length / 2` since it does not
166 /// require allocating an extra cache and is much faster.
sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng,167 fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
168     where R: Rng,
169 {
170     debug_assert!(amount <= length);
171     let mut indices: Vec<usize> = Vec::with_capacity(length);
172     indices.extend(0..length);
173     for i in 0..amount {
174         let j: usize = rng.gen_range(i, length);
175         let tmp = indices[i];
176         indices[i] = indices[j];
177         indices[j] = tmp;
178     }
179     indices.truncate(amount);
180     debug_assert_eq!(indices.len(), amount);
181     indices
182 }
183 
184 
185 /// This method performs a partial fisher-yates on a range of indices using a HashMap
186 /// as a cache to record potential collisions.
187 ///
188 /// The cache avoids allocating the entire `length` of values. This is especially useful when
189 /// `amount <<< length`, i.e. select 3 non-repeating from 1_000_000
sample_indices_cache<R>( rng: &mut R, length: usize, amount: usize, ) -> Vec<usize> where R: Rng,190 fn sample_indices_cache<R>(
191     rng: &mut R,
192     length: usize,
193     amount: usize,
194 ) -> Vec<usize>
195     where R: Rng,
196 {
197     debug_assert!(amount <= length);
198     #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount);
199     #[cfg(not(feature="std"))] let mut cache = BTreeMap::new();
200     let mut out = Vec::with_capacity(amount);
201     for i in 0..amount {
202         let j: usize = rng.gen_range(i, length);
203 
204         // equiv: let tmp = slice[i];
205         let tmp = match cache.get(&i) {
206             Some(e) => *e,
207             None => i,
208         };
209 
210         // equiv: slice[i] = slice[j];
211         let x = match cache.get(&j) {
212             Some(x) => *x,
213             None => j,
214         };
215 
216         // equiv: slice[j] = tmp;
217         cache.insert(j, tmp);
218 
219         // note that in the inplace version, slice[i] is automatically "returned" value
220         out.push(x);
221     }
222     debug_assert_eq!(out.len(), amount);
223     out
224 }
225 
#[cfg(test)]
mod test {
    use super::*;
    use {thread_rng, XorShiftRng, SeedableRng};

    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = thread_rng();
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `min_val..max_val` is half-open, so every sampled value is strictly below `max_val`.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }
    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = thread_rng();

        // sample 0 items
        assert_eq!(sample_slice(&mut r, empty, 0), vec![]);
        assert_eq!(sample_slice(&mut r, &[42, 2, 42], 0), vec![]);

        // sample 1 item
        assert_eq!(sample_slice(&mut r, &[42], 1), vec![42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(v == vec![42, 133] || v == vec![133, 42]);

        assert_eq!(sample_indices_inplace(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_inplace(&mut r, 1, 1), vec![0]);

        assert_eq!(sample_indices_cache(&mut r, 0, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 0), vec![]);
        assert_eq!(sample_indices_cache(&mut r, 1, 1), vec![0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        let ratio_42 = num_42 as f64 / total as f64;
        // Both outcomes should be roughly equally likely. The bounds must be combined with
        // `&&`; with `||` the check could never fail.
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = thread_rng();

        for length in 1usize..max_range {
            let amount = r.gen_range(0, length);
            let seed: [u32; 4] = [
                r.next_u32(), r.next_u32(), r.next_u32(), r.next_u32()
            ];

            println!("Selecting indices: len={}, amount={}, seed={:?}", length, amount, seed);

            // assert that the two index methods give exactly the same result
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            // assert the basics work
            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().map(|v| v).collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}
338