1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 //! Low-level API for sampling indices
10
11 #[cfg(feature = "alloc")] use core::slice;
12
13 #[cfg(all(feature = "alloc", not(feature = "std")))]
14 use crate::alloc::vec::{self, Vec};
15 #[cfg(feature = "std")] use std::vec;
// BTreeSet is not as fast as HashSet in tests, but better than nothing.
17 #[cfg(all(feature = "alloc", not(feature = "std")))]
18 use crate::alloc::collections::BTreeSet;
19 #[cfg(feature = "std")] use std::collections::HashSet;
20
21 #[cfg(feature = "alloc")]
22 use crate::distributions::{uniform::SampleUniform, Distribution, Uniform};
23 use crate::Rng;
24
/// A vector of indices.
///
/// Multiple internal representations are possible.
#[derive(Clone, Debug)]
pub enum IndexVec {
    // Compact storage; `sample` produces this variant whenever `length`
    // fits in `u32`.
    #[doc(hidden)]
    U32(Vec<u32>),
    // Full-width storage; `sample` produces this variant when
    // `length > u32::MAX`.
    #[doc(hidden)]
    USize(Vec<usize>),
}
35
36 impl IndexVec {
37 /// Returns the number of indices
38 #[inline]
len(&self) -> usize39 pub fn len(&self) -> usize {
40 match *self {
41 IndexVec::U32(ref v) => v.len(),
42 IndexVec::USize(ref v) => v.len(),
43 }
44 }
45
46 /// Returns `true` if the length is 0.
47 #[inline]
is_empty(&self) -> bool48 pub fn is_empty(&self) -> bool {
49 match *self {
50 IndexVec::U32(ref v) => v.is_empty(),
51 IndexVec::USize(ref v) => v.is_empty(),
52 }
53 }
54
55 /// Return the value at the given `index`.
56 ///
57 /// (Note: we cannot implement [`std::ops::Index`] because of lifetime
58 /// restrictions.)
59 #[inline]
index(&self, index: usize) -> usize60 pub fn index(&self, index: usize) -> usize {
61 match *self {
62 IndexVec::U32(ref v) => v[index] as usize,
63 IndexVec::USize(ref v) => v[index],
64 }
65 }
66
67 /// Return result as a `Vec<usize>`. Conversion may or may not be trivial.
68 #[inline]
into_vec(self) -> Vec<usize>69 pub fn into_vec(self) -> Vec<usize> {
70 match self {
71 IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(),
72 IndexVec::USize(v) => v,
73 }
74 }
75
76 /// Iterate over the indices as a sequence of `usize` values
77 #[inline]
iter(&self) -> IndexVecIter<'_>78 pub fn iter(&self) -> IndexVecIter<'_> {
79 match *self {
80 IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()),
81 IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()),
82 }
83 }
84
85 /// Convert into an iterator over the indices as a sequence of `usize` values
86 #[inline]
into_iter(self) -> IndexVecIntoIter87 pub fn into_iter(self) -> IndexVecIntoIter {
88 match self {
89 IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()),
90 IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()),
91 }
92 }
93 }
94
95 impl PartialEq for IndexVec {
eq(&self, other: &IndexVec) -> bool96 fn eq(&self, other: &IndexVec) -> bool {
97 use self::IndexVec::*;
98 match (self, other) {
99 (&U32(ref v1), &U32(ref v2)) => v1 == v2,
100 (&USize(ref v1), &USize(ref v2)) => v1 == v2,
101 (&U32(ref v1), &USize(ref v2)) => {
102 (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y))
103 }
104 (&USize(ref v1), &U32(ref v2)) => {
105 (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize))
106 }
107 }
108 }
109 }
110
111 impl From<Vec<u32>> for IndexVec {
112 #[inline]
from(v: Vec<u32>) -> Self113 fn from(v: Vec<u32>) -> Self {
114 IndexVec::U32(v)
115 }
116 }
117
118 impl From<Vec<usize>> for IndexVec {
119 #[inline]
from(v: Vec<usize>) -> Self120 fn from(v: Vec<usize>) -> Self {
121 IndexVec::USize(v)
122 }
123 }
124
/// Return type of `IndexVec::iter`.
///
/// A borrowing iterator over the stored indices. It is `Clone`, like the
/// slice iterators it wraps and like the owning `IndexVecIntoIter`.
#[derive(Clone, Debug)]
pub enum IndexVecIter<'a> {
    #[doc(hidden)]
    U32(slice::Iter<'a, u32>),
    #[doc(hidden)]
    USize(slice::Iter<'a, usize>),
}
133
134 impl<'a> Iterator for IndexVecIter<'a> {
135 type Item = usize;
136
137 #[inline]
next(&mut self) -> Option<usize>138 fn next(&mut self) -> Option<usize> {
139 use self::IndexVecIter::*;
140 match *self {
141 U32(ref mut iter) => iter.next().map(|i| *i as usize),
142 USize(ref mut iter) => iter.next().cloned(),
143 }
144 }
145
146 #[inline]
size_hint(&self) -> (usize, Option<usize>)147 fn size_hint(&self) -> (usize, Option<usize>) {
148 match *self {
149 IndexVecIter::U32(ref v) => v.size_hint(),
150 IndexVecIter::USize(ref v) => v.size_hint(),
151 }
152 }
153 }
154
155 impl<'a> ExactSizeIterator for IndexVecIter<'a> {}
156
/// Return type of `IndexVec::into_iter`.
///
/// An owning iterator that yields each stored index as a `usize`.
#[derive(Clone, Debug)]
pub enum IndexVecIntoIter {
    // Drains an `IndexVec::U32`; `next` widens each item to `usize`.
    #[doc(hidden)]
    U32(vec::IntoIter<u32>),
    // Drains an `IndexVec::USize`; items are yielded unchanged.
    #[doc(hidden)]
    USize(vec::IntoIter<usize>),
}
165
166 impl Iterator for IndexVecIntoIter {
167 type Item = usize;
168
169 #[inline]
next(&mut self) -> Option<Self::Item>170 fn next(&mut self) -> Option<Self::Item> {
171 use self::IndexVecIntoIter::*;
172 match *self {
173 U32(ref mut v) => v.next().map(|i| i as usize),
174 USize(ref mut v) => v.next(),
175 }
176 }
177
178 #[inline]
size_hint(&self) -> (usize, Option<usize>)179 fn size_hint(&self) -> (usize, Option<usize>) {
180 use self::IndexVecIntoIter::*;
181 match *self {
182 U32(ref v) => v.size_hint(),
183 USize(ref v) => v.size_hint(),
184 }
185 }
186 }
187
// `size_hint` forwards to the wrapped `vec::IntoIter`, which reports an exact
// length, so this marker impl is sound.
impl ExactSizeIterator for IndexVecIntoIter {}
189
190
191 /// Randomly sample exactly `amount` distinct indices from `0..length`, and
192 /// return them in random order (fully shuffled).
193 ///
194 /// This method is used internally by the slice sampling methods, but it can
195 /// sometimes be useful to have the indices themselves so this is provided as
196 /// an alternative.
197 ///
198 /// The implementation used is not specified; we automatically select the
199 /// fastest available algorithm for the `length` and `amount` parameters
200 /// (based on detailed profiling on an Intel Haswell CPU). Roughly speaking,
201 /// complexity is `O(amount)`, except that when `amount` is small, performance
202 /// is closer to `O(amount^2)`, and when `length` is close to `amount` then
203 /// `O(length)`.
204 ///
205 /// Note that performance is significantly better over `u32` indices than over
206 /// `u64` indices. Because of this we hide the underlying type behind an
207 /// abstraction, `IndexVec`.
208 ///
209 /// If an allocation-free `no_std` function is required, it is suggested
210 /// to adapt the internal `sample_floyd` implementation.
211 ///
212 /// Panics if `amount > length`.
sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec where R: Rng + ?Sized213 pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
214 where R: Rng + ?Sized {
215 if amount > length {
216 panic!("`amount` of samples must be less than or equal to `length`");
217 }
218 if length > (::core::u32::MAX as usize) {
219 // We never want to use inplace here, but could use floyd's alg
220 // Lazy version: always use the cache alg.
221 return sample_rejection(rng, length, amount);
222 }
223 let amount = amount as u32;
224 let length = length as u32;
225
226 // Choice of algorithm here depends on both length and amount. See:
227 // https://github.com/rust-random/rand/pull/479
228 // We do some calculations with f32. Accuracy is not very important.
229
230 if amount < 163 {
231 const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]];
232 let j = if length < 500_000 { 0 } else { 1 };
233 let amount_fp = amount as f32;
234 let m4 = C[0][j] * amount_fp;
235 // Short-cut: when amount < 12, floyd's is always faster
236 if amount > 11 && (length as f32) < (C[1][j] + m4) * amount_fp {
237 sample_inplace(rng, length, amount)
238 } else {
239 sample_floyd(rng, length, amount)
240 }
241 } else {
242 const C: [f32; 2] = [270.0, 330.0 / 9.0];
243 let j = if length < 500_000 { 0 } else { 1 };
244 if (length as f32) < C[j] * (amount as f32) {
245 sample_inplace(rng, length, amount)
246 } else {
247 sample_rejection(rng, length, amount)
248 }
249 }
250 }
251
252 /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's
253 /// combination algorithm.
254 ///
255 /// The output values are fully shuffled. (Overhead is under 50%.)
256 ///
257 /// This implementation uses `O(amount)` memory and `O(amount^2)` time.
sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec where R: Rng + ?Sized258 fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
259 where R: Rng + ?Sized {
260 // For small amount we use Floyd's fully-shuffled variant. For larger
261 // amounts this is slow due to Vec::insert performance, so we shuffle
262 // afterwards. Benchmarks show little overhead from extra logic.
263 let floyd_shuffle = amount < 50;
264
265 debug_assert!(amount <= length);
266 let mut indices = Vec::with_capacity(amount as usize);
267 for j in length - amount..length {
268 let t = rng.gen_range(0, j + 1);
269 if floyd_shuffle {
270 if let Some(pos) = indices.iter().position(|&x| x == t) {
271 indices.insert(pos, j);
272 continue;
273 }
274 } else if indices.contains(&t) {
275 indices.push(j);
276 continue;
277 }
278 indices.push(t);
279 }
280 if !floyd_shuffle {
281 // Reimplement SliceRandom::shuffle with smaller indices
282 for i in (1..amount).rev() {
283 // invariant: elements with index > i have been locked in place.
284 indices.swap(i as usize, rng.gen_range(0, i + 1) as usize);
285 }
286 }
287 IndexVec::from(indices)
288 }
289
290 /// Randomly sample exactly `amount` indices from `0..length`, using an inplace
291 /// partial Fisher-Yates method.
292 /// Sample an amount of indices using an inplace partial fisher yates method.
293 ///
294 /// This allocates the entire `length` of indices and randomizes only the first `amount`.
295 /// It then truncates to `amount` and returns.
296 ///
297 /// This method is not appropriate for large `length` and potentially uses a lot
298 /// of memory; because of this we only implement for `u32` index (which improves
299 /// performance in all cases).
300 ///
301 /// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec where R: Rng + ?Sized302 fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
303 where R: Rng + ?Sized {
304 debug_assert!(amount <= length);
305 let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
306 indices.extend(0..length);
307 for i in 0..amount {
308 let j: u32 = rng.gen_range(i, length);
309 indices.swap(i as usize, j as usize);
310 }
311 indices.truncate(amount as usize);
312 debug_assert_eq!(indices.len(), amount as usize);
313 IndexVec::from(indices)
314 }
315
316 trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash {
zero() -> Self317 fn zero() -> Self;
as_usize(self) -> usize318 fn as_usize(self) -> usize;
319 }
320 impl UInt for u32 {
321 #[inline]
zero() -> Self322 fn zero() -> Self {
323 0
324 }
325
326 #[inline]
as_usize(self) -> usize327 fn as_usize(self) -> usize {
328 self as usize
329 }
330 }
331 impl UInt for usize {
332 #[inline]
zero() -> Self333 fn zero() -> Self {
334 0
335 }
336
337 #[inline]
as_usize(self) -> usize338 fn as_usize(self) -> usize {
339 self
340 }
341 }
342
343 /// Randomly sample exactly `amount` indices from `0..length`, using rejection
344 /// sampling.
345 ///
346 /// Since `amount <<< length` there is a low chance of a random sample in
347 /// `0..length` being a duplicate. We test for duplicates and resample where
348 /// necessary. The algorithm is `O(amount)` time and memory.
349 ///
350 /// This function is generic over X primarily so that results are value-stable
351 /// over 32-bit and 64-bit platforms.
sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec where R: Rng + ?Sized, IndexVec: From<Vec<X>>,352 fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec
353 where
354 R: Rng + ?Sized,
355 IndexVec: From<Vec<X>>,
356 {
357 debug_assert!(amount < length);
358 #[cfg(feature = "std")]
359 let mut cache = HashSet::with_capacity(amount.as_usize());
360 #[cfg(not(feature = "std"))]
361 let mut cache = BTreeSet::new();
362 let distr = Uniform::new(X::zero(), length);
363 let mut indices = Vec::with_capacity(amount.as_usize());
364 for _ in 0..amount.as_usize() {
365 let mut pos = distr.sample(rng);
366 while !cache.insert(pos) {
367 pos = distr.sample(rng);
368 }
369 indices.push(pos);
370 }
371
372 debug_assert_eq!(indices.len(), amount.as_usize());
373 IndexVec::from(indices)
374 }
375
#[cfg(test)]
mod test {
    use super::*;
    // `vec` imports for both build configurations; the tests below use `vec![...]`.
    #[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::vec;
    #[cfg(feature = "std")] use std::vec;

    /// Degenerate lengths and amounts for each algorithm, plus a rough
    /// sanity check on the magnitude of samples drawn from a large range.
    #[test]
    fn test_sample_boundaries() {
        let mut r = crate::test::rng(404);

        assert_eq!(sample_inplace(&mut r, 0, 0).len(), 0);
        assert_eq!(sample_inplace(&mut r, 1, 0).len(), 0);
        assert_eq!(sample_inplace(&mut r, 1, 1).into_vec(), vec![0]);

        assert_eq!(sample_rejection(&mut r, 1u32, 0).len(), 0);

        assert_eq!(sample_floyd(&mut r, 0, 0).len(), 0);
        assert_eq!(sample_floyd(&mut r, 1, 0).len(), 0);
        assert_eq!(sample_floyd(&mut r, 1, 1).into_vec(), vec![0]);

        // These algorithms should be fast with big numbers. Test average.
        // Ten values from `0..2^25` are expected to sum to about 5 * 2^25;
        // the bounds are deliberately loose.
        let sum: usize = sample_rejection(&mut r, 1 << 25, 10u32).into_iter().sum();
        assert!(1 << 25 < sum && sum < (1 << 25) * 25);

        let sum: usize = sample_floyd(&mut r, 1 << 25, 10).into_iter().sum();
        assert!(1 << 25 < sum && sum < (1 << 25) * 25);
    }

    /// Checks that `sample` dispatches to the expected algorithm by comparing
    /// its output with each algorithm run from the same seed.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_sample_alg() {
        let seed_rng = crate::test::rng;

        // We can't test which algorithm is used directly, but Floyd's alg
        // should produce different results from the others. (Also, `inplace`
        // and `cached` currently use different sizes thus produce different results.)

        // A small length and relatively large amount should use inplace
        // (threshold for length < 500_000: (10 + 1.6 * 50) * 50 = 4500 > 100).
        let (length, amount): (usize, usize) = (100, 50);
        let v1 = sample(&mut seed_rng(420), length, amount);
        let v2 = sample_inplace(&mut seed_rng(420), length as u32, amount as u32);
        assert!(v1.iter().all(|e| e < length));
        assert_eq!(v1, v2);

        // Test Floyd's alg does produce different results
        let v3 = sample_floyd(&mut seed_rng(420), length as u32, amount as u32);
        assert!(v1 != v3);

        // A large length and small amount should use Floyd
        // (threshold for long ranges: (70/9 + 8/45 * 50) * 50 ~= 833 < 2^20).
        let (length, amount): (usize, usize) = (1 << 20, 50);
        let v1 = sample(&mut seed_rng(421), length, amount);
        let v2 = sample_floyd(&mut seed_rng(421), length as u32, amount as u32);
        assert!(v1.iter().all(|e| e < length));
        assert_eq!(v1, v2);

        // A large length and larger amount should use cache
        // (amount >= 163 and 330/9 * 600 = 22000 < 2^20, so rejection sampling).
        let (length, amount): (usize, usize) = (1 << 20, 600);
        let v1 = sample(&mut seed_rng(422), length, amount);
        let v2 = sample_rejection(&mut seed_rng(422), length as u32, amount as u32);
        assert!(v1.iter().all(|e| e < length));
        assert_eq!(v1, v2);
    }
}
439