1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 //! Sequence-related functionality
10 //!
11 //! This module provides:
12 //!
13 //! * [`SliceRandom`] slice sampling and mutation
14 //! * [`IteratorRandom`] iterator sampling
15 //! * [`index::sample`] low-level API to choose multiple indices from
16 //! `0..length`
17 //!
18 //! Also see:
19 //!
20 //! * [`crate::distributions::WeightedIndex`] distribution which provides
21 //! weighted index sampling.
22 //!
//! In order to make results reproducible across 32- and 64-bit architectures,
//! all `usize` indices are sampled as a `u32` where possible (also providing a
//! small performance boost in some cases).
26
27
28 #[cfg(feature = "alloc")]
29 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
30 pub mod index;
31
32 #[cfg(feature = "alloc")] use core::ops::Index;
33
34 #[cfg(feature = "alloc")] use alloc::vec::Vec;
35
36 #[cfg(feature = "alloc")]
37 use crate::distributions::uniform::{SampleBorrow, SampleUniform};
38 #[cfg(feature = "alloc")] use crate::distributions::WeightedError;
39 use crate::Rng;
40
41 /// Extension trait on slices, providing random mutation and sampling methods.
42 ///
43 /// This trait is implemented on all `[T]` slice types, providing several
44 /// methods for choosing and shuffling elements. You must `use` this trait:
45 ///
46 /// ```
47 /// use rand::seq::SliceRandom;
48 ///
49 /// let mut rng = rand::thread_rng();
50 /// let mut bytes = "Hello, random!".to_string().into_bytes();
51 /// bytes.shuffle(&mut rng);
52 /// let str = String::from_utf8(bytes).unwrap();
53 /// println!("{}", str);
54 /// ```
55 /// Example output (non-deterministic):
56 /// ```none
57 /// l,nmroHado !le
58 /// ```
59 pub trait SliceRandom {
60 /// The element type.
61 type Item;
62
63 /// Returns a reference to one random element of the slice, or `None` if the
64 /// slice is empty.
65 ///
66 /// For slices, complexity is `O(1)`.
67 ///
68 /// # Example
69 ///
70 /// ```
71 /// use rand::thread_rng;
72 /// use rand::seq::SliceRandom;
73 ///
74 /// let choices = [1, 2, 4, 8, 16, 32];
75 /// let mut rng = thread_rng();
76 /// println!("{:?}", choices.choose(&mut rng));
77 /// assert_eq!(choices[..0].choose(&mut rng), None);
78 /// ```
choose<R>(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized79 fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item>
80 where R: Rng + ?Sized;
81
82 /// Returns a mutable reference to one random element of the slice, or
83 /// `None` if the slice is empty.
84 ///
85 /// For slices, complexity is `O(1)`.
choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized86 fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item>
87 where R: Rng + ?Sized;
88
89 /// Chooses `amount` elements from the slice at random, without repetition,
90 /// and in random order. The returned iterator is appropriate both for
91 /// collection into a `Vec` and filling an existing buffer (see example).
92 ///
93 /// In case this API is not sufficiently flexible, use [`index::sample`].
94 ///
95 /// For slices, complexity is the same as [`index::sample`].
96 ///
97 /// # Example
98 /// ```
99 /// use rand::seq::SliceRandom;
100 ///
101 /// let mut rng = &mut rand::thread_rng();
102 /// let sample = "Hello, audience!".as_bytes();
103 ///
104 /// // collect the results into a vector:
105 /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect();
106 ///
107 /// // store in a buffer:
108 /// let mut buf = [0u8; 5];
109 /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) {
110 /// *slot = *b;
111 /// }
112 /// ```
113 #[cfg(feature = "alloc")]
114 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> where R: Rng + ?Sized115 fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
116 where R: Rng + ?Sized;
117
118 /// Similar to [`choose`], but where the likelihood of each outcome may be
119 /// specified.
120 ///
121 /// The specified function `weight` maps each item `x` to a relative
122 /// likelihood `weight(x)`. The probability of each item being selected is
123 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
124 ///
125 /// For slices of length `n`, complexity is `O(n)`.
126 /// See also [`choose_weighted_mut`], [`distributions::weighted`].
127 ///
128 /// # Example
129 ///
130 /// ```
131 /// use rand::prelude::*;
132 ///
133 /// let choices = [('a', 2), ('b', 1), ('c', 1)];
134 /// let mut rng = thread_rng();
135 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
136 /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0);
137 /// ```
138 /// [`choose`]: SliceRandom::choose
139 /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut
140 /// [`distributions::weighted`]: crate::distributions::weighted
141 #[cfg(feature = "alloc")]
142 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_weighted<R, F, B, X>( &self, rng: &mut R, weight: F, ) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default143 fn choose_weighted<R, F, B, X>(
144 &self, rng: &mut R, weight: F,
145 ) -> Result<&Self::Item, WeightedError>
146 where
147 R: Rng + ?Sized,
148 F: Fn(&Self::Item) -> B,
149 B: SampleBorrow<X>,
150 X: SampleUniform
151 + for<'a> ::core::ops::AddAssign<&'a X>
152 + ::core::cmp::PartialOrd<X>
153 + Clone
154 + Default;
155
156 /// Similar to [`choose_mut`], but where the likelihood of each outcome may
157 /// be specified.
158 ///
159 /// The specified function `weight` maps each item `x` to a relative
160 /// likelihood `weight(x)`. The probability of each item being selected is
161 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
162 ///
163 /// For slices of length `n`, complexity is `O(n)`.
164 /// See also [`choose_weighted`], [`distributions::weighted`].
165 ///
166 /// [`choose_mut`]: SliceRandom::choose_mut
167 /// [`choose_weighted`]: SliceRandom::choose_weighted
168 /// [`distributions::weighted`]: crate::distributions::weighted
169 #[cfg(feature = "alloc")]
170 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_weighted_mut<R, F, B, X>( &mut self, rng: &mut R, weight: F, ) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default171 fn choose_weighted_mut<R, F, B, X>(
172 &mut self, rng: &mut R, weight: F,
173 ) -> Result<&mut Self::Item, WeightedError>
174 where
175 R: Rng + ?Sized,
176 F: Fn(&Self::Item) -> B,
177 B: SampleBorrow<X>,
178 X: SampleUniform
179 + for<'a> ::core::ops::AddAssign<&'a X>
180 + ::core::cmp::PartialOrd<X>
181 + Clone
182 + Default;
183
184 /// Similar to [`choose_multiple`], but where the likelihood of each element's
185 /// inclusion in the output may be specified. The elements are returned in an
186 /// arbitrary, unspecified order.
187 ///
188 /// The specified function `weight` maps each item `x` to a relative
189 /// likelihood `weight(x)`. The probability of each item being selected is
190 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
191 ///
192 /// If all of the weights are equal, even if they are all zero, each element has
193 /// an equal likelihood of being selected.
194 ///
195 /// The complexity of this method depends on the feature `partition_at_index`.
196 /// If the feature is enabled, then for slices of length `n`, the complexity
197 /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and
198 /// `O(n * log amount)` time.
199 ///
200 /// # Example
201 ///
202 /// ```
203 /// use rand::prelude::*;
204 ///
205 /// let choices = [('a', 2), ('b', 1), ('c', 1)];
206 /// let mut rng = thread_rng();
207 /// // First Draw * Second Draw = total odds
208 /// // -----------------------
209 /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
210 /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
211 /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
212 /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
213 /// ```
214 /// [`choose_multiple`]: SliceRandom::choose_multiple
215 //
216 // Note: this is feature-gated on std due to usage of f64::powf.
217 // If necessary, we may use alloc+libm as an alternative (see PR #1089).
218 #[cfg(feature = "std")]
219 #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
choose_multiple_weighted<R, F, X>( &self, rng: &mut R, amount: usize, weight: F, ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> X, X: Into<f64>220 fn choose_multiple_weighted<R, F, X>(
221 &self, rng: &mut R, amount: usize, weight: F,
222 ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
223 where
224 R: Rng + ?Sized,
225 F: Fn(&Self::Item) -> X,
226 X: Into<f64>;
227
228 /// Shuffle a mutable slice in place.
229 ///
230 /// For slices of length `n`, complexity is `O(n)`.
231 ///
232 /// # Example
233 ///
234 /// ```
235 /// use rand::seq::SliceRandom;
236 /// use rand::thread_rng;
237 ///
238 /// let mut rng = thread_rng();
239 /// let mut y = [1, 2, 3, 4, 5];
240 /// println!("Unshuffled: {:?}", y);
241 /// y.shuffle(&mut rng);
242 /// println!("Shuffled: {:?}", y);
243 /// ```
shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized244 fn shuffle<R>(&mut self, rng: &mut R)
245 where R: Rng + ?Sized;
246
247 /// Shuffle a slice in place, but exit early.
248 ///
249 /// Returns two mutable slices from the source slice. The first contains
250 /// `amount` elements randomly permuted. The second has the remaining
251 /// elements that are not fully shuffled.
252 ///
253 /// This is an efficient method to select `amount` elements at random from
254 /// the slice, provided the slice may be mutated.
255 ///
256 /// If you only need to choose elements randomly and `amount > self.len()/2`
257 /// then you may improve performance by taking
258 /// `amount = values.len() - amount` and using only the second slice.
259 ///
260 /// If `amount` is greater than the number of elements in the slice, this
261 /// will perform a full shuffle.
262 ///
263 /// For slices, complexity is `O(m)` where `m = amount`.
partial_shuffle<R>( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized264 fn partial_shuffle<R>(
265 &mut self, rng: &mut R, amount: usize,
266 ) -> (&mut [Self::Item], &mut [Self::Item])
267 where R: Rng + ?Sized;
268 }
269
270 /// Extension trait on iterators, providing random sampling methods.
271 ///
272 /// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
273 /// and provides methods for
274 /// choosing one or more elements. You must `use` this trait:
275 ///
276 /// ```
277 /// use rand::seq::IteratorRandom;
278 ///
279 /// let mut rng = rand::thread_rng();
280 ///
281 /// let faces = "";
282 /// println!("I am {}!", faces.chars().choose(&mut rng).unwrap());
283 /// ```
284 /// Example output (non-deterministic):
285 /// ```none
286 /// I am !
287 /// ```
288 pub trait IteratorRandom: Iterator + Sized {
289 /// Choose one element at random from the iterator.
290 ///
291 /// Returns `None` if and only if the iterator is empty.
292 ///
293 /// This method uses [`Iterator::size_hint`] for optimisation. With an
294 /// accurate hint and where [`Iterator::nth`] is a constant-time operation
295 /// this method can offer `O(1)` performance. Where no size hint is
296 /// available, complexity is `O(n)` where `n` is the iterator length.
297 /// Partial hints (where `lower > 0`) also improve performance.
298 ///
299 /// Note that the output values and the number of RNG samples used
300 /// depends on size hints. In particular, `Iterator` combinators that don't
301 /// change the values yielded but change the size hints may result in
302 /// `choose` returning different elements. If you want consistent results
303 /// and RNG usage consider using [`IteratorRandom::choose_stable`].
choose<R>(mut self, rng: &mut R) -> Option<Self::Item> where R: Rng + ?Sized304 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
305 where R: Rng + ?Sized {
306 let (mut lower, mut upper) = self.size_hint();
307 let mut consumed = 0;
308 let mut result = None;
309
310 // Handling for this condition outside the loop allows the optimizer to eliminate the loop
311 // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
312 // seq_iter_choose_from_1000.
313 if upper == Some(lower) {
314 return if lower == 0 {
315 None
316 } else {
317 self.nth(gen_index(rng, lower))
318 };
319 }
320
321 // Continue until the iterator is exhausted
322 loop {
323 if lower > 1 {
324 let ix = gen_index(rng, lower + consumed);
325 let skip = if ix < lower {
326 result = self.nth(ix);
327 lower - (ix + 1)
328 } else {
329 lower
330 };
331 if upper == Some(lower) {
332 return result;
333 }
334 consumed += lower;
335 if skip > 0 {
336 self.nth(skip - 1);
337 }
338 } else {
339 let elem = self.next();
340 if elem.is_none() {
341 return result;
342 }
343 consumed += 1;
344 if gen_index(rng, consumed) == 0 {
345 result = elem;
346 }
347 }
348
349 let hint = self.size_hint();
350 lower = hint.0;
351 upper = hint.1;
352 }
353 }
354
355 /// Choose one element at random from the iterator.
356 ///
357 /// Returns `None` if and only if the iterator is empty.
358 ///
359 /// This method is very similar to [`choose`] except that the result
360 /// only depends on the length of the iterator and the values produced by
361 /// `rng`. Notably for any iterator of a given length this will make the
362 /// same requests to `rng` and if the same sequence of values are produced
363 /// the same index will be selected from `self`. This may be useful if you
364 /// need consistent results no matter what type of iterator you are working
365 /// with. If you do not need this stability prefer [`choose`].
366 ///
367 /// Note that this method still uses [`Iterator::size_hint`] to skip
368 /// constructing elements where possible, however the selection and `rng`
369 /// calls are the same in the face of this optimization. If you want to
370 /// force every element to be created regardless call `.inspect(|e| ())`.
371 ///
372 /// [`choose`]: IteratorRandom::choose
choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item> where R: Rng + ?Sized373 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
374 where R: Rng + ?Sized {
375 let mut consumed = 0;
376 let mut result = None;
377
378 loop {
379 // Currently the only way to skip elements is `nth()`. So we need to
380 // store what index to access next here.
381 // This should be replaced by `advance_by()` once it is stable:
382 // https://github.com/rust-lang/rust/issues/77404
383 let mut next = 0;
384
385 let (lower, _) = self.size_hint();
386 if lower >= 2 {
387 let highest_selected = (0..lower)
388 .filter(|ix| gen_index(rng, consumed+ix+1) == 0)
389 .last();
390
391 consumed += lower;
392 next = lower;
393
394 if let Some(ix) = highest_selected {
395 result = self.nth(ix);
396 next -= ix + 1;
397 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
398 }
399 }
400
401 let elem = self.nth(next);
402 if elem.is_none() {
403 return result
404 }
405
406 if gen_index(rng, consumed+1) == 0 {
407 result = elem;
408 }
409 consumed += 1;
410 }
411 }
412
413 /// Collects values at random from the iterator into a supplied buffer
414 /// until that buffer is filled.
415 ///
416 /// Although the elements are selected randomly, the order of elements in
417 /// the buffer is neither stable nor fully random. If random ordering is
418 /// desired, shuffle the result.
419 ///
420 /// Returns the number of elements added to the buffer. This equals the length
421 /// of the buffer unless the iterator contains insufficient elements, in which
422 /// case this equals the number of elements available.
423 ///
424 /// Complexity is `O(n)` where `n` is the length of the iterator.
425 /// For slices, prefer [`SliceRandom::choose_multiple`].
choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize where R: Rng + ?Sized426 fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
427 where R: Rng + ?Sized {
428 let amount = buf.len();
429 let mut len = 0;
430 while len < amount {
431 if let Some(elem) = self.next() {
432 buf[len] = elem;
433 len += 1;
434 } else {
435 // Iterator exhausted; stop early
436 return len;
437 }
438 }
439
440 // Continue, since the iterator was not exhausted
441 for (i, elem) in self.enumerate() {
442 let k = gen_index(rng, i + 1 + amount);
443 if let Some(slot) = buf.get_mut(k) {
444 *slot = elem;
445 }
446 }
447 len
448 }
449
450 /// Collects `amount` values at random from the iterator into a vector.
451 ///
452 /// This is equivalent to `choose_multiple_fill` except for the result type.
453 ///
454 /// Although the elements are selected randomly, the order of elements in
455 /// the buffer is neither stable nor fully random. If random ordering is
456 /// desired, shuffle the result.
457 ///
458 /// The length of the returned vector equals `amount` unless the iterator
459 /// contains insufficient elements, in which case it equals the number of
460 /// elements available.
461 ///
462 /// Complexity is `O(n)` where `n` is the length of the iterator.
463 /// For slices, prefer [`SliceRandom::choose_multiple`].
464 #[cfg(feature = "alloc")]
465 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item> where R: Rng + ?Sized466 fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
467 where R: Rng + ?Sized {
468 let mut reservoir = Vec::with_capacity(amount);
469 reservoir.extend(self.by_ref().take(amount));
470
471 // Continue unless the iterator was exhausted
472 //
473 // note: this prevents iterators that "restart" from causing problems.
474 // If the iterator stops once, then so do we.
475 if reservoir.len() == amount {
476 for (i, elem) in self.enumerate() {
477 let k = gen_index(rng, i + 1 + amount);
478 if let Some(slot) = reservoir.get_mut(k) {
479 *slot = elem;
480 }
481 }
482 } else {
483 // Don't hang onto extra memory. There is a corner case where
484 // `amount` was much less than `self.len()`.
485 reservoir.shrink_to_fit();
486 }
487 reservoir
488 }
489 }
490
491
492 impl<T> SliceRandom for [T] {
493 type Item = T;
494
choose<R>(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized495 fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item>
496 where R: Rng + ?Sized {
497 if self.is_empty() {
498 None
499 } else {
500 Some(&self[gen_index(rng, self.len())])
501 }
502 }
503
choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized504 fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item>
505 where R: Rng + ?Sized {
506 if self.is_empty() {
507 None
508 } else {
509 let len = self.len();
510 Some(&mut self[gen_index(rng, len)])
511 }
512 }
513
514 #[cfg(feature = "alloc")]
choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> where R: Rng + ?Sized515 fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
516 where R: Rng + ?Sized {
517 let amount = ::core::cmp::min(amount, self.len());
518 SliceChooseIter {
519 slice: self,
520 _phantom: Default::default(),
521 indices: index::sample(rng, self.len(), amount).into_iter(),
522 }
523 }
524
525 #[cfg(feature = "alloc")]
choose_weighted<R, F, B, X>( &self, rng: &mut R, weight: F, ) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default,526 fn choose_weighted<R, F, B, X>(
527 &self, rng: &mut R, weight: F,
528 ) -> Result<&Self::Item, WeightedError>
529 where
530 R: Rng + ?Sized,
531 F: Fn(&Self::Item) -> B,
532 B: SampleBorrow<X>,
533 X: SampleUniform
534 + for<'a> ::core::ops::AddAssign<&'a X>
535 + ::core::cmp::PartialOrd<X>
536 + Clone
537 + Default,
538 {
539 use crate::distributions::{Distribution, WeightedIndex};
540 let distr = WeightedIndex::new(self.iter().map(weight))?;
541 Ok(&self[distr.sample(rng)])
542 }
543
544 #[cfg(feature = "alloc")]
choose_weighted_mut<R, F, B, X>( &mut self, rng: &mut R, weight: F, ) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default,545 fn choose_weighted_mut<R, F, B, X>(
546 &mut self, rng: &mut R, weight: F,
547 ) -> Result<&mut Self::Item, WeightedError>
548 where
549 R: Rng + ?Sized,
550 F: Fn(&Self::Item) -> B,
551 B: SampleBorrow<X>,
552 X: SampleUniform
553 + for<'a> ::core::ops::AddAssign<&'a X>
554 + ::core::cmp::PartialOrd<X>
555 + Clone
556 + Default,
557 {
558 use crate::distributions::{Distribution, WeightedIndex};
559 let distr = WeightedIndex::new(self.iter().map(weight))?;
560 Ok(&mut self[distr.sample(rng)])
561 }
562
563 #[cfg(feature = "std")]
choose_multiple_weighted<R, F, X>( &self, rng: &mut R, amount: usize, weight: F, ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> X, X: Into<f64>,564 fn choose_multiple_weighted<R, F, X>(
565 &self, rng: &mut R, amount: usize, weight: F,
566 ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
567 where
568 R: Rng + ?Sized,
569 F: Fn(&Self::Item) -> X,
570 X: Into<f64>,
571 {
572 let amount = ::core::cmp::min(amount, self.len());
573 Ok(SliceChooseIter {
574 slice: self,
575 _phantom: Default::default(),
576 indices: index::sample_weighted(
577 rng,
578 self.len(),
579 |idx| weight(&self[idx]).into(),
580 amount,
581 )?
582 .into_iter(),
583 })
584 }
585
shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized586 fn shuffle<R>(&mut self, rng: &mut R)
587 where R: Rng + ?Sized {
588 for i in (1..self.len()).rev() {
589 // invariant: elements with index > i have been locked in place.
590 self.swap(i, gen_index(rng, i + 1));
591 }
592 }
593
partial_shuffle<R>( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized594 fn partial_shuffle<R>(
595 &mut self, rng: &mut R, amount: usize,
596 ) -> (&mut [Self::Item], &mut [Self::Item])
597 where R: Rng + ?Sized {
598 // This applies Durstenfeld's algorithm for the
599 // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
600 // for an unbiased permutation, but exits early after choosing `amount`
601 // elements.
602
603 let len = self.len();
604 let end = if amount >= len { 0 } else { len - amount };
605
606 for i in (end..len).rev() {
607 // invariant: elements with index > i have been locked in place.
608 self.swap(i, gen_index(rng, i + 1));
609 }
610 let r = self.split_at_mut(end);
611 (r.1, r.0)
612 }
613 }
614
615 impl<I> IteratorRandom for I where I: Iterator + Sized {}
616
617
/// An iterator over multiple slice elements.
///
/// This struct is created by [`SliceRandom::choose_multiple`] (and, with the
/// `std` feature, by `SliceRandom::choose_multiple_weighted`). It yields a
/// shared reference to the slice element at each sampled index.
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S, T>
where
    S: ?Sized + 'a,
    T: 'a,
{
    slice: &'a S,
    _phantom: ::core::marker::PhantomData<T>,
    indices: index::IndexVecIntoIter,
}
630
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    /// Yields the element at the next sampled index.
    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        // Copy the `&'a S` out so the returned reference borrows the slice
        // for `'a` rather than borrowing `self`.
        let slice = self.slice;
        let position = self.indices.next()?;
        Some(&slice[position])
    }

    /// Exact: one item remains per unconsumed index.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
644
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    /// Number of elements not yet yielded, taken from the index iterator.
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.indices)
    }
}
653
654
655 // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where
656 // possible, primarily in order to produce the same output on 32-bit and 64-bit
657 // platforms.
658 #[inline]
gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize659 fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize {
660 if ubound <= (core::u32::MAX as usize) {
661 rng.gen_range(0..ubound as u32) as usize
662 } else {
663 rng.gen_range(0..ubound)
664 }
665 }
666
667
668 #[cfg(test)]
669 mod test {
670 use super::*;
671 #[cfg(feature = "alloc")] use crate::Rng;
672 #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec;
673
674 #[test]
test_slice_choose()675 fn test_slice_choose() {
676 let mut r = crate::test::rng(107);
677 let chars = [
678 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
679 ];
680 let mut chosen = [0i32; 14];
681 // The below all use a binomial distribution with n=1000, p=1/14.
682 // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
683 for _ in 0..1000 {
684 let picked = *chars.choose(&mut r).unwrap();
685 chosen[(picked as usize) - ('a' as usize)] += 1;
686 }
687 for count in chosen.iter() {
688 assert!(40 < *count && *count < 106);
689 }
690
691 chosen.iter_mut().for_each(|x| *x = 0);
692 for _ in 0..1000 {
693 *chosen.choose_mut(&mut r).unwrap() += 1;
694 }
695 for count in chosen.iter() {
696 assert!(40 < *count && *count < 106);
697 }
698
699 let mut v: [isize; 0] = [];
700 assert_eq!(v.choose(&mut r), None);
701 assert_eq!(v.choose_mut(&mut r), None);
702 }
703
704 #[test]
value_stability_slice()705 fn value_stability_slice() {
706 let mut r = crate::test::rng(413);
707 let chars = [
708 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
709 ];
710 let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
711
712 assert_eq!(chars.choose(&mut r), Some(&'l'));
713 assert_eq!(nums.choose_mut(&mut r), Some(&mut 10));
714
715 #[cfg(feature = "alloc")]
716 assert_eq!(
717 &chars
718 .choose_multiple(&mut r, 8)
719 .cloned()
720 .collect::<Vec<char>>(),
721 &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e']
722 );
723
724 #[cfg(feature = "alloc")]
725 assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f'));
726 #[cfg(feature = "alloc")]
727 assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5));
728
729 let mut r = crate::test::rng(414);
730 nums.shuffle(&mut r);
731 assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);
732 nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
733 let res = nums.partial_shuffle(&mut r, 6);
734 assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]);
735 assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]);
736 }
737
738 #[derive(Clone)]
739 struct UnhintedIterator<I: Iterator + Clone> {
740 iter: I,
741 }
742 impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
743 type Item = I::Item;
744
next(&mut self) -> Option<Self::Item>745 fn next(&mut self) -> Option<Self::Item> {
746 self.iter.next()
747 }
748 }
749
750 #[derive(Clone)]
751 struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
752 iter: I,
753 chunk_remaining: usize,
754 chunk_size: usize,
755 hint_total_size: bool,
756 }
757 impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
758 type Item = I::Item;
759
next(&mut self) -> Option<Self::Item>760 fn next(&mut self) -> Option<Self::Item> {
761 if self.chunk_remaining == 0 {
762 self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len());
763 }
764 self.chunk_remaining = self.chunk_remaining.saturating_sub(1);
765
766 self.iter.next()
767 }
768
size_hint(&self) -> (usize, Option<usize>)769 fn size_hint(&self) -> (usize, Option<usize>) {
770 (
771 self.chunk_remaining,
772 if self.hint_total_size {
773 Some(self.iter.len())
774 } else {
775 None
776 },
777 )
778 }
779 }
780
781 #[derive(Clone)]
782 struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
783 iter: I,
784 window_size: usize,
785 hint_total_size: bool,
786 }
787 impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
788 type Item = I::Item;
789
next(&mut self) -> Option<Self::Item>790 fn next(&mut self) -> Option<Self::Item> {
791 self.iter.next()
792 }
793
size_hint(&self) -> (usize, Option<usize>)794 fn size_hint(&self) -> (usize, Option<usize>) {
795 (
796 ::core::cmp::min(self.iter.len(), self.window_size),
797 if self.hint_total_size {
798 Some(self.iter.len())
799 } else {
800 None
801 },
802 )
803 }
804 }
805
806 #[test]
807 #[cfg_attr(miri, ignore)] // Miri is too slow
test_iterator_choose()808 fn test_iterator_choose() {
809 let r = &mut crate::test::rng(109);
810 fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
811 let mut chosen = [0i32; 9];
812 for _ in 0..1000 {
813 let picked = iter.clone().choose(r).unwrap();
814 chosen[picked] += 1;
815 }
816 for count in chosen.iter() {
817 // Samples should follow Binomial(1000, 1/9)
818 // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
819 // Note: have seen 153, which is unlikely but not impossible.
820 assert!(
821 72 < *count && *count < 154,
822 "count not close to 1000/9: {}",
823 count
824 );
825 }
826 }
827
828 test_iter(r, 0..9);
829 test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
830 #[cfg(feature = "alloc")]
831 test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
832 test_iter(r, UnhintedIterator { iter: 0..9 });
833 test_iter(r, ChunkHintedIterator {
834 iter: 0..9,
835 chunk_size: 4,
836 chunk_remaining: 4,
837 hint_total_size: false,
838 });
839 test_iter(r, ChunkHintedIterator {
840 iter: 0..9,
841 chunk_size: 4,
842 chunk_remaining: 4,
843 hint_total_size: true,
844 });
845 test_iter(r, WindowHintedIterator {
846 iter: 0..9,
847 window_size: 2,
848 hint_total_size: false,
849 });
850 test_iter(r, WindowHintedIterator {
851 iter: 0..9,
852 window_size: 2,
853 hint_total_size: true,
854 });
855
856 assert_eq!((0..0).choose(r), None);
857 assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
858 }
859
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable() {
    let r = &mut crate::test::rng(109);

    // Draws 1000 samples from `iter` (which must yield each of 0..9 exactly
    // once) via `choose_stable` and checks every value is picked a plausible
    // number of times.
    fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose_stable(r).unwrap();
            chosen[picked] += 1;
        }
        for count in chosen.iter() {
            // Samples should follow Binomial(1000, 1/9)
            // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
            // Note: have seen 153, which is unlikely but not impossible.
            assert!(
                72 < *count && *count < 154,
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
    test_iter(r, UnhintedIterator { iter: 0..9 });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: false,
    });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: true,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: false,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: true,
    });

    // Empty iterators must yield `None`. Exercise `choose_stable` (the method
    // under test here), not `choose`.
    assert_eq!((0..0).choose_stable(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
}
913
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable_stability() {
    // Builds a histogram of 1000 `choose_stable` draws from `iter`, starting
    // from an identically seeded RNG on every call. Iterators over the same
    // values must therefore produce identical histograms regardless of their
    // size hints.
    fn histogram(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
        let rng = &mut crate::test::rng(109);
        let mut hist = [0i32; 9];
        for _ in 0..1000 {
            let idx = iter.clone().choose_stable(rng).unwrap();
            hist[idx] += 1;
        }
        hist
    }

    let expected = histogram(0..9);
    assert_eq!(histogram([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), expected);

    #[cfg(feature = "alloc")]
    assert_eq!(histogram((0..9).collect::<Vec<_>>().into_iter()), expected);
    assert_eq!(histogram(UnhintedIterator { iter: 0..9 }), expected);

    // Chunked and windowed hints, each without and then with a total size.
    for &hint_total_size in [false, true].iter() {
        let chunked = ChunkHintedIterator {
            iter: 0..9,
            chunk_size: 4,
            chunk_remaining: 4,
            hint_total_size,
        };
        assert_eq!(histogram(chunked), expected);
    }
    for &hint_total_size in [false, true].iter() {
        let windowed = WindowHintedIterator {
            iter: 0..9,
            window_size: 2,
            hint_total_size,
        };
        assert_eq!(histogram(windowed), expected);
    }
}
956
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_shuffle() {
    let mut r = crate::test::rng(108);

    // Degenerate sizes: 0 and 1 elements can only stay as they are.
    let empty: &mut [isize] = &mut [];
    empty.shuffle(&mut r);
    let mut one = [1];
    one.shuffle(&mut r);
    let b: &[_] = &[1];
    assert_eq!(one, b);

    let mut two = [1, 2];
    two.shuffle(&mut r);
    assert!(two == [1, 2] || two == [2, 1]);

    // Shuffle [0, 1, 2, 3] repeatedly and count how often each of the
    // 4! = 24 permutations occurs. Each permutation is encoded as a unique
    // index in 0..24 with a mixed-radix (factorial-base) number.
    // For i = 0..4, the position of `i` among the not-yet-processed elements
    // is a digit in 0..(4 - i). Those elements are kept at the front of the
    // array. The digit weights are 6, 2, 1, 1.
    let mut counts = [0i32; 24];
    for _ in 0..10000 {
        let mut arr: [usize; 4] = [0, 1, 2, 3];
        arr.shuffle(&mut r);
        let mut permutation = 0usize;
        let mut pos_value = counts.len();
        for i in 0..4 {
            pos_value /= 4 - i;
            let pos = arr.iter().position(|&x| x == i).unwrap();
            assert!(pos < (4 - i));
            permutation += pos * pos_value;
            // Move `i` to the end and shift everything after it one step left.
            // Processed values thus collect, in order, at the back.
            arr[pos..].rotate_left(1);
            assert_eq!(arr[3], i);
        }
        // After all four moves the array is back in sorted order.
        for (i, &a) in arr.iter().enumerate() {
            assert_eq!(a, i);
        }
        counts[permutation] += 1;
    }
    for count in counts.iter() {
        // Binomial(10000, 1/24) with average 416.667
        // Octave: binocdf(n, 10000, 1/24)
        // 99.9% chance samples lie within this range:
        assert!(352 <= *count && *count <= 483, "count: {}", count);
    }
}
1006
#[test]
fn test_partial_shuffle() {
    let mut rng = crate::test::rng(118);

    // Asking for more elements than exist on an empty slice gives two empty halves.
    let mut empty: [u32; 0] = [];
    let (picked, rest) = empty.partial_shuffle(&mut rng, 10);
    assert_eq!((picked.len(), rest.len()), (0, 0));

    // Selecting 2 of 5 splits the slice into 2 chosen and 3 remaining elements.
    let mut v = [1, 2, 3, 4, 5];
    let (picked, rest) = v.partial_shuffle(&mut rng, 2);
    assert_eq!((picked.len(), rest.len()), (2, 3));
    assert!(picked[0] != picked[1]);
    // Leading elements only change when they get selected. Only two are
    // selected, so at least one of the first three still holds its original value.
    assert!(rest[0] == 1 || rest[1] == 2 || rest[2] == 3);
}
1022
#[test]
#[cfg(feature = "alloc")]
fn test_sample_iter() {
    let min_val = 1;
    let max_val = 100;

    let mut r = crate::test::rng(401);
    let vals = (min_val..max_val).collect::<Vec<i32>>();
    let small_sample = vals.iter().choose_multiple(&mut r, 5);
    let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

    assert_eq!(small_sample.len(), 5);
    assert_eq!(large_sample.len(), vals.len());
    // no randomization happens when amount >= len
    assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

    // `vals` comes from the half-open range `min_val..max_val`, so every sampled
    // element must lie in that same range. `max_val` itself is excluded, which
    // the previous `<= max_val` check did not enforce.
    assert!(small_sample
        .iter()
        .all(|e| (min_val..max_val).contains(*e)));
}
1043
#[test]
#[cfg(feature = "alloc")]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weighted() {
    let mut rng = crate::test::rng(406);
    const N_REPS: u32 = 3000;
    let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
    let total_weight = weights.iter().sum::<u32>() as f32;

    // Compares each observed count against `weight * N_REPS / total_weight`.
    // A non-zero deviation is measured relative to the expectation and must
    // stay within 25%. For a zero-weight entry any hit gives an infinite
    // relative error, so such an entry must never be chosen.
    let verify = |observed: [i32; 14]| {
        for (&weight, &count) in weights.iter().zip(observed.iter()) {
            let expected = (weight * N_REPS) as f32 / total_weight;
            let deviation = (count as f32 - expected).abs();
            let rel_err = if deviation == 0.0 {
                deviation
            } else {
                deviation / expected
            };
            assert!(rel_err <= 0.25);
        }
    };

    // Weight accessor for `(weight, payload)` pairs.
    fn weight_of<T>(entry: &(u32, T)) -> u32 {
        entry.0
    }

    // choose_weighted: entries are (weight, index); tally the drawn index.
    let mut items = [(0u32, 0usize); 14];
    for (idx, (slot, &w)) in items.iter_mut().zip(weights.iter()).enumerate() {
        *slot = (w, idx);
    }
    let mut tally = [0i32; 14];
    for _ in 0..N_REPS {
        let &(_, idx) = items.choose_weighted(&mut rng, weight_of).unwrap();
        tally[idx] += 1;
    }
    verify(tally);

    // choose_weighted_mut: entries are (weight, count); bump the count in place.
    let mut items = [(0u32, 0i32); 14];
    for (slot, &w) in items.iter_mut().zip(weights.iter()) {
        *slot = (w, 0);
    }
    for _ in 0..N_REPS {
        items.choose_weighted_mut(&mut rng, weight_of).unwrap().1 += 1;
    }
    for (t, &(_, count)) in tally.iter_mut().zip(items.iter()) {
        *t = count;
    }
    verify(tally);

    // Error cases: empty input, all-zero weights, and negative weights.
    let empty_slice = &mut [10][0..0];
    assert_eq!(
        empty_slice.choose_weighted(&mut rng, |_| 1),
        Err(WeightedError::NoItem)
    );
    assert_eq!(
        empty_slice.choose_weighted_mut(&mut rng, |_| 1),
        Err(WeightedError::NoItem)
    );
    assert_eq!(
        ['x'].choose_weighted_mut(&mut rng, |_| 0),
        Err(WeightedError::AllWeightsZero)
    );
    assert_eq!(
        [0, -1].choose_weighted_mut(&mut rng, |x| *x),
        Err(WeightedError::InvalidWeight)
    );
    assert_eq!(
        [-1, 0].choose_weighted_mut(&mut rng, |x| *x),
        Err(WeightedError::InvalidWeight)
    );
}
1115
#[test]
fn value_stability_choose() {
    // Asserts that `choose` on `iter` yields `expected`. Every call starts
    // from the same freshly seeded RNG, so each expectation below is pinned
    // independently of the others.
    fn check<I: Iterator<Item = u32>>(iter: I, expected: Option<u32>) {
        let mut rng = crate::test::rng(411);
        assert_eq!(iter.choose(&mut rng), expected);
    }

    check([].iter().cloned(), None);
    check(0..100, Some(33));
    check(UnhintedIterator { iter: 0..100 }, Some(40));

    // Hinted iterators give the same pick with or without a total-size hint.
    for &hint_total_size in [false, true].iter() {
        check(
            ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size,
            },
            Some(39),
        );
    }
    for &hint_total_size in [false, true].iter() {
        check(
            WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size,
            },
            Some(90),
        );
    }
}
1161
#[test]
fn value_stability_choose_stable() {
    // Asserts that `choose_stable` on `iter` yields `expected`, starting from
    // the same freshly seeded RNG on every call.
    fn check<I: Iterator<Item = u32>>(iter: I, expected: Option<u32>) {
        let mut rng = crate::test::rng(411);
        assert_eq!(iter.choose_stable(&mut rng), expected);
    }

    check([].iter().cloned(), None);
    check(0..100, Some(40));
    check(UnhintedIterator { iter: 0..100 }, Some(40));

    // Every iterator shape over 0..100 yields the same value.
    for &hint_total_size in [false, true].iter() {
        check(
            ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size,
            },
            Some(40),
        );
    }
    for &hint_total_size in [false, true].iter() {
        check(
            WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size,
            },
            Some(40),
        );
    }
}
1207
#[test]
fn value_stability_choose_multiple() {
    // (exclusive end of the input range `0..end`, expected sample for seed 412)
    const CASES: [(u32, &[u32]); 3] = [
        (4, &[0, 1, 2, 3]),
        (8, &[0, 1, 2, 3, 4, 5, 6, 7]),
        (100, &[58, 78, 80, 92, 43, 8, 96, 7]),
    ];

    // `choose_multiple_fill` into a fixed 8-slot buffer: the reported fill
    // count and the filled prefix must both match.
    for &(end, expected) in CASES.iter() {
        let mut rng = crate::test::rng(412);
        let mut buf = [0u32; 8];
        let filled = (0..end).choose_multiple_fill(&mut rng, &mut buf);
        assert_eq!(filled, expected.len());
        assert_eq!(&buf[..filled], expected);
    }

    // The allocating `choose_multiple` must agree with the buffer variant.
    #[cfg(feature = "alloc")]
    {
        for &(end, expected) in CASES.iter() {
            let mut rng = crate::test::rng(412);
            assert_eq!((0..end).choose_multiple(&mut rng, expected.len()), expected);
        }
    }
}
1233
#[test]
#[cfg(feature = "std")]
fn test_multiple_weighted_edge_cases() {
    use super::*;

    let mut rng = crate::test::rng(413);

    // Case 1: an entry with weight 0 is never selected
    let choices = [('a', 2), ('b', 1), ('c', 0)];
    for _ in 0..100 {
        let picked: Vec<_> = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect();

        assert_eq!(picked.len(), 2);
        assert!(picked.iter().all(|entry| entry.0 != 'c'));
    }

    // Case 2: all weights are 0, yet the requested amount is still returned
    let choices = [('a', 0), ('b', 0), ('c', 0)];
    let picked: Vec<_> = choices
        .choose_multiple_weighted(&mut rng, 2, |item| item.1)
        .unwrap()
        .collect();
    assert_eq!(picked.len(), 2);

    // Case 3: negative weights are rejected
    let choices = [('a', -1), ('b', 1), ('c', 1)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(outcome.err(), Some(WeightedError::InvalidWeight));

    // Case 4: an empty list with amount 0 yields nothing
    let choices: [(); 0] = [];
    let picked: Vec<_> = choices
        .choose_multiple_weighted(&mut rng, 0, |_: &()| 0)
        .unwrap()
        .collect();
    assert!(picked.is_empty());

    // Case 5: NaN weights are rejected
    let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(outcome.err(), Some(WeightedError::InvalidWeight));

    // Case 6: a +infinity weight is always selected
    let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)];
    for _ in 0..100 {
        let picked: Vec<_> = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect();
        assert_eq!(picked.len(), 2);
        assert!(picked.iter().any(|entry| entry.0 == 'a'));
    }

    // Case 7: -infinity weights are rejected
    let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(outcome.err(), Some(WeightedError::InvalidWeight));

    // Case 8: a -0.0 weight is accepted
    let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert!(outcome.is_ok());
}
1313
#[test]
#[cfg(feature = "std")]
fn test_multiple_weighted_distributions() {
    use super::*;

    // Weights are a=2, b=1, c=1 (total 4) and two entries are drawn without
    // replacement. The theoretical probabilities of the ordered outcomes are:
    // AB: 2/4 * 1/2 = 0.250
    // AC: 2/4 * 1/2 = 0.250
    // BA: 1/4 * 2/3 = 0.167
    // BC: 1/4 * 1/3 = 0.083
    // CA: 1/4 * 2/3 = 0.167
    // CB: 1/4 * 1/3 = 0.083
    // The unordered pairs {a,b}, {a,c} and {b,c} therefore have probabilities
    // 5/12, 5/12 and 1/6. Over 10000 trials that is the `expected_results` below.
    let choices = [('a', 2), ('b', 1), ('c', 1)];
    let mut rng = crate::test::rng(414);

    // Tally of unordered pairs: [{a,b}, {a,c}, {b,c}].
    let mut results = [0i32; 3];
    let expected_results = [4167, 4167, 1666];
    for _ in 0..10000 {
        let result = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);

        match (result[0].0, result[1].0) {
            ('a', 'b') | ('b', 'a') => {
                results[0] += 1;
            }
            ('a', 'c') | ('c', 'a') => {
                results[1] += 1;
            }
            ('b', 'c') | ('c', 'b') => {
                results[2] += 1;
            }
            // Includes repeated picks such as ('a', 'a'), which sampling
            // without replacement must never produce.
            (_, _) => panic!("unexpected result"),
        }
    }

    // Every pair count must be within 100 of its expectation.
    for (&actual, &expected) in results.iter().zip(expected_results.iter()) {
        assert!(
            (actual - expected).abs() <= 100,
            "pair count {} deviates from expected {} by more than 100",
            actual,
            expected
        );
    }
}
1359 }
1360