1 // Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // https://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Functions for randomly accessing and sampling sequences.
12
13 use super::Rng;
14
// This module is only enabled when either std or alloc is available.
16 #[cfg(all(feature="alloc", not(feature="std")))] use alloc::vec::Vec;
17 // BTreeMap is not as fast in tests, but better than nothing.
18 #[cfg(feature="std")] use std::collections::HashMap;
19 #[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeMap;
20
21 /// Randomly sample `amount` elements from a finite iterator.
22 ///
23 /// The following can be returned:
24 ///
25 /// - `Ok`: `Vec` of `amount` non-repeating randomly sampled elements. The order is not random.
26 /// - `Err`: `Vec` of all the elements from `iterable` in sequential order. This happens when the
27 /// length of `iterable` was less than `amount`. This is considered an error since exactly
28 /// `amount` elements is typically expected.
29 ///
30 /// This implementation uses `O(len(iterable))` time and `O(amount)` memory.
31 ///
32 /// # Example
33 ///
34 /// ```
35 /// use rand::{thread_rng, seq};
36 ///
37 /// let mut rng = thread_rng();
38 /// let sample = seq::sample_iter(&mut rng, 1..100, 5).unwrap();
39 /// println!("{:?}", sample);
40 /// ```
sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>> where I: IntoIterator<Item=T>, R: Rng + ?Sized,41 pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
42 where I: IntoIterator<Item=T>,
43 R: Rng + ?Sized,
44 {
45 let mut iter = iterable.into_iter();
46 let mut reservoir = Vec::with_capacity(amount);
47 reservoir.extend(iter.by_ref().take(amount));
48
49 // Continue unless the iterator was exhausted
50 //
51 // note: this prevents iterators that "restart" from causing problems.
52 // If the iterator stops once, then so do we.
53 if reservoir.len() == amount {
54 for (i, elem) in iter.enumerate() {
55 let k = rng.gen_range(0, i + 1 + amount);
56 if let Some(spot) = reservoir.get_mut(k) {
57 *spot = elem;
58 }
59 }
60 Ok(reservoir)
61 } else {
62 // Don't hang onto extra memory. There is a corner case where
63 // `amount` was much less than `len(iterable)`.
64 reservoir.shrink_to_fit();
65 Err(reservoir)
66 }
67 }
68
69 /// Randomly sample exactly `amount` values from `slice`.
70 ///
71 /// The values are non-repeating and in random order.
72 ///
73 /// This implementation uses `O(amount)` time and memory.
74 ///
75 /// Panics if `amount > slice.len()`
76 ///
77 /// # Example
78 ///
79 /// ```
80 /// use rand::{thread_rng, seq};
81 ///
82 /// let mut rng = thread_rng();
83 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
84 /// println!("{:?}", seq::sample_slice(&mut rng, &values, 3));
85 /// ```
sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T> where R: Rng + ?Sized, T: Clone86 pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T>
87 where R: Rng + ?Sized,
88 T: Clone
89 {
90 let indices = sample_indices(rng, slice.len(), amount);
91
92 let mut out = Vec::with_capacity(amount);
93 out.extend(indices.iter().map(|i| slice[*i].clone()));
94 out
95 }
96
97 /// Randomly sample exactly `amount` references from `slice`.
98 ///
99 /// The references are non-repeating and in random order.
100 ///
101 /// This implementation uses `O(amount)` time and memory.
102 ///
103 /// Panics if `amount > slice.len()`
104 ///
105 /// # Example
106 ///
107 /// ```
108 /// use rand::{thread_rng, seq};
109 ///
110 /// let mut rng = thread_rng();
111 /// let values = vec![5, 6, 1, 3, 4, 6, 7];
112 /// println!("{:?}", seq::sample_slice_ref(&mut rng, &values, 3));
113 /// ```
sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T> where R: Rng + ?Sized114 pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T>
115 where R: Rng + ?Sized
116 {
117 let indices = sample_indices(rng, slice.len(), amount);
118
119 let mut out = Vec::with_capacity(amount);
120 out.extend(indices.iter().map(|i| &slice[*i]));
121 out
122 }
123
124 /// Randomly sample exactly `amount` indices from `0..length`.
125 ///
126 /// The values are non-repeating and in random order.
127 ///
128 /// This implementation uses `O(amount)` time and memory.
129 ///
130 /// This method is used internally by the slice sampling methods, but it can sometimes be useful to
131 /// have the indices themselves so this is provided as an alternative.
132 ///
133 /// Panics if `amount > length`
sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng + ?Sized,134 pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
135 where R: Rng + ?Sized,
136 {
137 if amount > length {
138 panic!("`amount` must be less than or equal to `slice.len()`");
139 }
140
141 // We are going to have to allocate at least `amount` for the output no matter what. However,
142 // if we use the `cached` version we will have to allocate `amount` as a HashMap as well since
143 // it inserts an element for every loop.
144 //
145 // Therefore, if `amount >= length / 2` then inplace will be both faster and use less memory.
146 // In fact, benchmarks show the inplace version is faster for length up to about 20 times
147 // faster than amount.
148 //
149 // TODO: there is probably even more fine-tuning that can be done here since
150 // `HashMap::with_capacity(amount)` probably allocates more than `amount` in practice,
151 // and a trade off could probably be made between memory/cpu, since hashmap operations
152 // are slower than array index swapping.
153 if amount >= length / 20 {
154 sample_indices_inplace(rng, length, amount)
155 } else {
156 sample_indices_cache(rng, length, amount)
157 }
158 }
159
160 /// Sample an amount of indices using an inplace partial fisher yates method.
161 ///
162 /// This allocates the entire `length` of indices and randomizes only the first `amount`.
163 /// It then truncates to `amount` and returns.
164 ///
165 /// This is better than using a `HashMap` "cache" when `amount >= length / 2`
166 /// since it does not require allocating an extra cache and is much faster.
sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize> where R: Rng + ?Sized,167 fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
168 where R: Rng + ?Sized,
169 {
170 debug_assert!(amount <= length);
171 let mut indices: Vec<usize> = Vec::with_capacity(length);
172 indices.extend(0..length);
173 for i in 0..amount {
174 let j: usize = rng.gen_range(i, length);
175 indices.swap(i, j);
176 }
177 indices.truncate(amount);
178 debug_assert_eq!(indices.len(), amount);
179 indices
180 }
181
182
183 /// This method performs a partial fisher-yates on a range of indices using a
184 /// `HashMap` as a cache to record potential collisions.
185 ///
186 /// The cache avoids allocating the entire `length` of values. This is especially useful when
187 /// `amount <<< length`, i.e. select 3 non-repeating from `1_000_000`
sample_indices_cache<R>( rng: &mut R, length: usize, amount: usize, ) -> Vec<usize> where R: Rng + ?Sized,188 fn sample_indices_cache<R>(
189 rng: &mut R,
190 length: usize,
191 amount: usize,
192 ) -> Vec<usize>
193 where R: Rng + ?Sized,
194 {
195 debug_assert!(amount <= length);
196 #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount);
197 #[cfg(not(feature="std"))] let mut cache = BTreeMap::new();
198 let mut out = Vec::with_capacity(amount);
199 for i in 0..amount {
200 let j: usize = rng.gen_range(i, length);
201
202 // equiv: let tmp = slice[i];
203 let tmp = match cache.get(&i) {
204 Some(e) => *e,
205 None => i,
206 };
207
208 // equiv: slice[i] = slice[j];
209 let x = match cache.get(&j) {
210 Some(x) => *x,
211 None => j,
212 };
213
214 // equiv: slice[j] = tmp;
215 cache.insert(j, tmp);
216
217 // note that in the inplace version, slice[i] is automatically "returned" value
218 out.push(x);
219 }
220 debug_assert_eq!(out.len(), amount);
221 out
222 }
223
#[cfg(test)]
mod test {
    use super::*;
    use {XorShiftRng, Rng, SeedableRng};
    #[cfg(not(feature="std"))]
    use alloc::vec::Vec;

    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = ::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` is the half-open range `min_val..max_val`, so `max_val` itself never occurs.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }

    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = ::test::rng(402);

        // sample 0 items
        assert_eq!(&sample_slice(&mut r, empty, 0)[..], [0u8; 0]);
        assert_eq!(&sample_slice(&mut r, &[42, 2, 42], 0)[..], [0u8; 0]);

        // sample 1 item
        assert_eq!(&sample_slice(&mut r, &[42], 1)[..], [42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(&v[..] == [42, 133] || v[..] == [133, 42]);

        assert_eq!(&sample_indices_inplace(&mut r, 0, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_inplace(&mut r, 1, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_inplace(&mut r, 1, 1)[..], [0]);

        assert_eq!(&sample_indices_cache(&mut r, 0, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_cache(&mut r, 1, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_cache(&mut r, 1, 1)[..], [0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        let ratio_42 = num_42 as f64 / total as f64;
        // Both bounds must hold; joining them with `||` would make this check always pass.
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = ::test::rng(403);

        for length in 1usize..max_range {
            // Upper bound `length + 1` so that sampling every element (`amount == length`)
            // is exercised as well.
            let amount = r.gen_range(0, length + 1);
            let mut seed = [0u8; 16];
            r.fill(&mut seed);

            // assert that the two index methods give exactly the same result
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            // assert the basics work
            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}
335