1 // Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // https://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10 
11 //! Functions for randomly accessing and sampling sequences.
12 
13 use super::Rng;
14 
// This module is only enabled when either `std` or `alloc` is available.
16 #[cfg(all(feature="alloc", not(feature="std")))] use alloc::vec::Vec;
17 // BTreeMap is not as fast in tests, but better than nothing.
18 #[cfg(feature="std")] use std::collections::HashMap;
19 #[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeMap;
20 
21 /// Randomly sample `amount` elements from a finite iterator.
22 ///
23 /// The following can be returned:
24 ///
25 /// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random.
26 /// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the
27 ///   length of `iterable` was less than `amount`. This is considered an error since exactly
28 ///   `amount` elements is typically expected.
29 ///
30 /// This implementation uses `O(len(iterable))` time and `O(amount)` memory.
31 ///
32 /// # Example
33 ///
34 /// ```
35 /// use rand::{thread_rng, seq};
36 ///
37 /// let mut rng = thread_rng();
38 /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap();
39 /// println!("{:?}", sample);
40 /// ```
sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>> where I: IntoIterator<Item=T>, R: Rng + ?Sized,41 pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
42     where I: IntoIterator<Item=T>,
43           R: Rng + ?Sized,
44 {
45     let mut iter = iterable.into_iter();
46     let mut reservoir = Vec::with_capacity(amount);
47     reservoir.extend(iter.by_ref().take(amount));
48 
49     // Continue unless the iterator was exhausted
50     //
51     // note: this prevents iterators that "restart" from causing problems.
52     // If the iterator stops once, then so do we.
53     if reservoir.len() == amount {
54         for (i, elem) in iter.enumerate() {
55             let k = rng.gen_range(0, i + 1 + amount);
56             if let Some(spot) = reservoir.get_mut(k) {
57                 *spot = elem;
58             }
59         }
60         Ok(reservoir)
61     } else {
62         // Don't hang onto extra memory. There is a corner case where
63         // `amount` was much less than `len(iterable)`.
64         reservoir.shrink_to_fit();
65         Err(reservoir)
66     }
67 }
68 
69 /// Randomly sample exactly `amount` values from `slice`.
70 ///
71 /// The values are non-repeating and in random order.
72 ///
73 /// This implementation uses `O(amount)` time and memory.
74 ///
75 /// Panics if `amount > slice.len()`
76 ///
77 /// # Example
78 ///
79 /// ```
80 /// use rand::{thread_rng, seq};
81 ///
82 /// let mut rng = thread_rng();
83 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
84 /// println!("{:?}", seq::sample_slice(&mut rng, &values, 3));
85 /// ```
sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T> where R: Rng + ?Sized, T: Clone86 pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T>
87     where R: Rng + ?Sized,
88           T: Clone
89 {
90     let indices = sample_indices(rng, slice.len(), amount);
91 
92     let mut out = Vec::with_capacity(amount);
93     out.extend(indices.iter().map(|i| slice[*i].clone()));
94     out
95 }
96 
97 /// Randomly sample exactly `amount` references from `slice`.
98 ///
99 /// The references are non-repeating and in random order.
100 ///
101 /// This implementation uses `O(amount)` time and memory.
102 ///
103 /// Panics if `amount > slice.len()`
104 ///
105 /// # Example
106 ///
107 /// ```
108 /// use rand::{thread_rng, seq};
109 ///
110 /// let mut rng = thread_rng();
111 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
112 /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3));
113 /// ```
sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized114 pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T>
115     where R: Rng + ?Sized
116 {
117     let indices = sample_indices(rng, slice.len(), amount);
118 
119     let mut out = Vec::with_capacity(amount);
120     out.extend(indices.iter().map(|i| &slice[*i]));
121     out
122 }
123 
124 /// Randomly sample exactly `amount` indices from `0..length`.
125 ///
126 /// The values are non-repeating and in random order.
127 ///
128 /// This implementation uses `O(amount)` time and memory.
129 ///
130 /// This method is used internally by the slice sampling methods, but it can sometimes be useful to
131 /// have the indices themselves so this is provided as an alternative.
132 ///
133 /// Panics if `amount > length`
sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng + ?Sized,134 pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
135     where R: Rng + ?Sized,
136 {
137     if amount > length {
138         panic!("`amount` must be less than or equal to `slice.len()`");
139     }
140 
141     // We are going to have to allocate at least `amount` for the output no matter what. However,
142     // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since
143     // it inserts an element for every loop.
144     //
145     // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory.
146     // In fact, benchmarks show the inplace version is faster for length up to about 20 times
147     // faster than amount.
148     //
149     // TODO: there is probably even more fine-tuning that can be done here since
150     // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice,
151     // and a trade off could probably be made between memory/cpu, since hashmap operations
152     // are slower than array index swapping.
153     if amount >= length / 20 {
154         sample_indices_inplace(rng, length, amount)
155     } else {
156         sample_indices_cache(rng, length, amount)
157     }
158 }
159 
160 /// Sample an amount of indices using an inplace partial fisher yates method.
161 ///
162 /// This allocates the entire `length` of indices and randomizes only the first `amount`.
163 /// It then truncates to `amount` and returns.
164 ///
165 /// This is better than using a `HashMap` "cache" when `amount >= length / 2`
166 /// since it does not require allocating an extra cache and is much faster.
sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng + ?Sized,167 fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
168     where R: Rng + ?Sized,
169 {
170     debug_assert!(amount <= length);
171     let mut indices: Vec<usize> = Vec::with_capacity(length);
172     indices.extend(0..length);
173     for i in 0..amount {
174         let j: usize = rng.gen_range(i, length);
175         indices.swap(i, j);
176     }
177     indices.truncate(amount);
178     debug_assert_eq!(indices.len(), amount);
179     indices
180 }
181 
182 
183 /// This method performs a partial fisher-yates on a range of indices using a
184 /// `HashMap` as a cache to record potential collisions.
185 ///
186 /// The cache avoids allocating the entire `length` of values. This is especially useful when
187 /// `amount <<< length`, i.e. select 3 non-repeating from `1_000_000`
sample_indices_cache<R>( rng: &mut R, length: usize, amount: usize, ) -> Vec<usize> where R: Rng + ?Sized,188 fn sample_indices_cache<R>(
189     rng: &mut R,
190     length: usize,
191     amount: usize,
192 ) -> Vec<usize>
193     where R: Rng + ?Sized,
194 {
195     debug_assert!(amount <= length);
196     #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount);
197     #[cfg(not(feature="std"))] let mut cache = BTreeMap::new();
198     let mut out = Vec::with_capacity(amount);
199     for i in 0..amount {
200         let j: usize = rng.gen_range(i, length);
201 
202         // equiv: let tmp = slice[i];
203         let tmp = match cache.get(&i) {
204             Some(e) => *e,
205             None => i,
206         };
207 
208         // equiv: slice[i] = slice[j];
209         let x = match cache.get(&j) {
210             Some(x) => *x,
211             None => j,
212         };
213 
214         // equiv: slice[j] = tmp;
215         cache.insert(j, tmp);
216 
217         // note that in the inplace version, slice[i] is automatically "returned" value
218         out.push(x);
219     }
220     debug_assert_eq!(out.len(), amount);
221     out
222 }
223 
#[cfg(test)]
mod test {
    use super::*;
    use {XorShiftRng, Rng, SeedableRng};
    #[cfg(not(feature="std"))]
    use alloc::vec::Vec;

    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = ::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `min_val..max_val` is half-open, so every sampled value is strictly below `max_val`.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }

    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = ::test::rng(402);

        // sample 0 items
        assert_eq!(&sample_slice(&mut r, empty, 0)[..], [0u8; 0]);
        assert_eq!(&sample_slice(&mut r, &[42, 2, 42], 0)[..], [0u8; 0]);

        // sample 1 item
        assert_eq!(&sample_slice(&mut r, &[42], 1)[..], [42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(&v[..] == [42, 133] || v[..] == [133, 42]);

        assert_eq!(&sample_indices_inplace(&mut r, 0, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_inplace(&mut r, 1, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_inplace(&mut r, 1, 1)[..], [0]);

        assert_eq!(&sample_indices_cache(&mut r, 0, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_cache(&mut r, 1, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_cache(&mut r, 1, 1)[..], [0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        let ratio_42 = num_42 as f64 / total as f64;
        // Both bounds must hold; joining them with `||` would make this check always pass.
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = ::test::rng(403);

        for length in 1usize..max_range {
            let amount = r.gen_range(0, length);
            let mut seed = [0u8; 16];
            r.fill(&mut seed);

            // assert that the two index methods give exactly the same result
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            // assert the basics work
            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}
335