1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 //! Sequence-related functionality
10 //!
11 //! This module provides:
12 //!
13 //! *   [`SliceRandom`] slice sampling and mutation
14 //! *   [`IteratorRandom`] iterator sampling
15 //! *   [`index::sample`] low-level API to choose multiple indices from
16 //!     `0..length`
17 //!
18 //! Also see:
19 //!
20 //! *   [`crate::distributions::WeightedIndex`] distribution which provides
21 //!     weighted index sampling.
22 //!
23 //! In order to make results reproducible across 32-64 bit architectures, all
24 //! `usize` indices are sampled as a `u32` where possible (also providing a
25 //! small performance boost in some cases).
26 
27 
28 #[cfg(feature = "alloc")]
29 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
30 pub mod index;
31 
32 #[cfg(feature = "alloc")] use core::ops::Index;
33 
34 #[cfg(feature = "alloc")] use alloc::vec::Vec;
35 
36 #[cfg(feature = "alloc")]
37 use crate::distributions::uniform::{SampleBorrow, SampleUniform};
38 #[cfg(feature = "alloc")] use crate::distributions::WeightedError;
39 use crate::Rng;
40 
41 /// Extension trait on slices, providing random mutation and sampling methods.
42 ///
43 /// This trait is implemented on all `[T]` slice types, providing several
44 /// methods for choosing and shuffling elements. You must `use` this trait:
45 ///
46 /// ```
47 /// use rand::seq::SliceRandom;
48 ///
49 /// let mut rng = rand::thread_rng();
50 /// let mut bytes = "Hello, random!".to_string().into_bytes();
51 /// bytes.shuffle(&mut rng);
52 /// let str = String::from_utf8(bytes).unwrap();
53 /// println!("{}", str);
54 /// ```
55 /// Example output (non-deterministic):
56 /// ```none
57 /// l,nmroHado !le
58 /// ```
59 pub trait SliceRandom {
60     /// The element type.
61     type Item;
62 
63     /// Returns a reference to one random element of the slice, or `None` if the
64     /// slice is empty.
65     ///
66     /// For slices, complexity is `O(1)`.
67     ///
68     /// # Example
69     ///
70     /// ```
71     /// use rand::thread_rng;
72     /// use rand::seq::SliceRandom;
73     ///
74     /// let choices = [1, 2, 4, 8, 16, 32];
75     /// let mut rng = thread_rng();
76     /// println!("{:?}", choices.choose(&mut rng));
77     /// assert_eq!(choices[..0].choose(&mut rng), None);
78     /// ```
choose<R>(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized79     fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item>
80     where R: Rng + ?Sized;
81 
82     /// Returns a mutable reference to one random element of the slice, or
83     /// `None` if the slice is empty.
84     ///
85     /// For slices, complexity is `O(1)`.
choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized86     fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item>
87     where R: Rng + ?Sized;
88 
89     /// Chooses `amount` elements from the slice at random, without repetition,
90     /// and in random order. The returned iterator is appropriate both for
91     /// collection into a `Vec` and filling an existing buffer (see example).
92     ///
93     /// In case this API is not sufficiently flexible, use [`index::sample`].
94     ///
95     /// For slices, complexity is the same as [`index::sample`].
96     ///
97     /// # Example
98     /// ```
99     /// use rand::seq::SliceRandom;
100     ///
101     /// let mut rng = &mut rand::thread_rng();
102     /// let sample = "Hello, audience!".as_bytes();
103     ///
104     /// // collect the results into a vector:
105     /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect();
106     ///
107     /// // store in a buffer:
108     /// let mut buf = [0u8; 5];
109     /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) {
110     ///     *slot = *b;
111     /// }
112     /// ```
113     #[cfg(feature = "alloc")]
114     #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> where R: Rng + ?Sized115     fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
116     where R: Rng + ?Sized;
117 
118     /// Similar to [`choose`], but where the likelihood of each outcome may be
119     /// specified.
120     ///
121     /// The specified function `weight` maps each item `x` to a relative
122     /// likelihood `weight(x)`. The probability of each item being selected is
123     /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
124     ///
125     /// For slices of length `n`, complexity is `O(n)`.
126     /// See also [`choose_weighted_mut`], [`distributions::weighted`].
127     ///
128     /// # Example
129     ///
130     /// ```
131     /// use rand::prelude::*;
132     ///
133     /// let choices = [('a', 2), ('b', 1), ('c', 1)];
134     /// let mut rng = thread_rng();
135     /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
136     /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0);
137     /// ```
138     /// [`choose`]: SliceRandom::choose
139     /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut
140     /// [`distributions::weighted`]: crate::distributions::weighted
141     #[cfg(feature = "alloc")]
142     #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_weighted<R, F, B, X>( &self, rng: &mut R, weight: F, ) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default143     fn choose_weighted<R, F, B, X>(
144         &self, rng: &mut R, weight: F,
145     ) -> Result<&Self::Item, WeightedError>
146     where
147         R: Rng + ?Sized,
148         F: Fn(&Self::Item) -> B,
149         B: SampleBorrow<X>,
150         X: SampleUniform
151             + for<'a> ::core::ops::AddAssign<&'a X>
152             + ::core::cmp::PartialOrd<X>
153             + Clone
154             + Default;
155 
156     /// Similar to [`choose_mut`], but where the likelihood of each outcome may
157     /// be specified.
158     ///
159     /// The specified function `weight` maps each item `x` to a relative
160     /// likelihood `weight(x)`. The probability of each item being selected is
161     /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
162     ///
163     /// For slices of length `n`, complexity is `O(n)`.
164     /// See also [`choose_weighted`], [`distributions::weighted`].
165     ///
166     /// [`choose_mut`]: SliceRandom::choose_mut
167     /// [`choose_weighted`]: SliceRandom::choose_weighted
168     /// [`distributions::weighted`]: crate::distributions::weighted
169     #[cfg(feature = "alloc")]
170     #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_weighted_mut<R, F, B, X>( &mut self, rng: &mut R, weight: F, ) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default171     fn choose_weighted_mut<R, F, B, X>(
172         &mut self, rng: &mut R, weight: F,
173     ) -> Result<&mut Self::Item, WeightedError>
174     where
175         R: Rng + ?Sized,
176         F: Fn(&Self::Item) -> B,
177         B: SampleBorrow<X>,
178         X: SampleUniform
179             + for<'a> ::core::ops::AddAssign<&'a X>
180             + ::core::cmp::PartialOrd<X>
181             + Clone
182             + Default;
183 
184     /// Similar to [`choose_multiple`], but where the likelihood of each element's
185     /// inclusion in the output may be specified. The elements are returned in an
186     /// arbitrary, unspecified order.
187     ///
188     /// The specified function `weight` maps each item `x` to a relative
189     /// likelihood `weight(x)`. The probability of each item being selected is
190     /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
191     ///
192     /// If all of the weights are equal, even if they are all zero, each element has
193     /// an equal likelihood of being selected.
194     ///
195     /// The complexity of this method depends on the feature `partition_at_index`.
196     /// If the feature is enabled, then for slices of length `n`, the complexity
197     /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and
198     /// `O(n * log amount)` time.
199     ///
200     /// # Example
201     ///
202     /// ```
203     /// use rand::prelude::*;
204     ///
205     /// let choices = [('a', 2), ('b', 1), ('c', 1)];
206     /// let mut rng = thread_rng();
207     /// // First Draw * Second Draw = total odds
208     /// // -----------------------
209     /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
210     /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
211     /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
212     /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
213     /// ```
214     /// [`choose_multiple`]: SliceRandom::choose_multiple
215     //
216     // Note: this is feature-gated on std due to usage of f64::powf.
217     // If necessary, we may use alloc+libm as an alternative (see PR #1089).
218     #[cfg(feature = "std")]
219     #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
choose_multiple_weighted<R, F, X>( &self, rng: &mut R, amount: usize, weight: F, ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> X, X: Into<f64>220     fn choose_multiple_weighted<R, F, X>(
221         &self, rng: &mut R, amount: usize, weight: F,
222     ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
223     where
224         R: Rng + ?Sized,
225         F: Fn(&Self::Item) -> X,
226         X: Into<f64>;
227 
228     /// Shuffle a mutable slice in place.
229     ///
230     /// For slices of length `n`, complexity is `O(n)`.
231     ///
232     /// # Example
233     ///
234     /// ```
235     /// use rand::seq::SliceRandom;
236     /// use rand::thread_rng;
237     ///
238     /// let mut rng = thread_rng();
239     /// let mut y = [1, 2, 3, 4, 5];
240     /// println!("Unshuffled: {:?}", y);
241     /// y.shuffle(&mut rng);
242     /// println!("Shuffled:   {:?}", y);
243     /// ```
shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized244     fn shuffle<R>(&mut self, rng: &mut R)
245     where R: Rng + ?Sized;
246 
247     /// Shuffle a slice in place, but exit early.
248     ///
249     /// Returns two mutable slices from the source slice. The first contains
250     /// `amount` elements randomly permuted. The second has the remaining
251     /// elements that are not fully shuffled.
252     ///
253     /// This is an efficient method to select `amount` elements at random from
254     /// the slice, provided the slice may be mutated.
255     ///
256     /// If you only need to choose elements randomly and `amount > self.len()/2`
257     /// then you may improve performance by taking
258     /// `amount = values.len() - amount` and using only the second slice.
259     ///
260     /// If `amount` is greater than the number of elements in the slice, this
261     /// will perform a full shuffle.
262     ///
263     /// For slices, complexity is `O(m)` where `m = amount`.
partial_shuffle<R>( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized264     fn partial_shuffle<R>(
265         &mut self, rng: &mut R, amount: usize,
266     ) -> (&mut [Self::Item], &mut [Self::Item])
267     where R: Rng + ?Sized;
268 }
269 
270 /// Extension trait on iterators, providing random sampling methods.
271 ///
272 /// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
273 /// and provides methods for
274 /// choosing one or more elements. You must `use` this trait:
275 ///
276 /// ```
277 /// use rand::seq::IteratorRandom;
278 ///
279 /// let mut rng = rand::thread_rng();
280 ///
281 /// let faces = "������������";
282 /// println!("I am {}!", faces.chars().choose(&mut rng).unwrap());
283 /// ```
284 /// Example output (non-deterministic):
285 /// ```none
286 /// I am ��!
287 /// ```
288 pub trait IteratorRandom: Iterator + Sized {
289     /// Choose one element at random from the iterator.
290     ///
291     /// Returns `None` if and only if the iterator is empty.
292     ///
293     /// This method uses [`Iterator::size_hint`] for optimisation. With an
294     /// accurate hint and where [`Iterator::nth`] is a constant-time operation
295     /// this method can offer `O(1)` performance. Where no size hint is
296     /// available, complexity is `O(n)` where `n` is the iterator length.
297     /// Partial hints (where `lower > 0`) also improve performance.
298     ///
299     /// Note that the output values and the number of RNG samples used
300     /// depends on size hints. In particular, `Iterator` combinators that don't
301     /// change the values yielded but change the size hints may result in
302     /// `choose` returning different elements. If you want consistent results
303     /// and RNG usage consider using [`IteratorRandom::choose_stable`].
choose<R>(mut self, rng: &mut R) -> Option<Self::Item> where R: Rng + ?Sized304     fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
305     where R: Rng + ?Sized {
306         let (mut lower, mut upper) = self.size_hint();
307         let mut consumed = 0;
308         let mut result = None;
309 
310         // Handling for this condition outside the loop allows the optimizer to eliminate the loop
311         // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
312         // seq_iter_choose_from_1000.
313         if upper == Some(lower) {
314             return if lower == 0 {
315                 None
316             } else {
317                 self.nth(gen_index(rng, lower))
318             };
319         }
320 
321         // Continue until the iterator is exhausted
322         loop {
323             if lower > 1 {
324                 let ix = gen_index(rng, lower + consumed);
325                 let skip = if ix < lower {
326                     result = self.nth(ix);
327                     lower - (ix + 1)
328                 } else {
329                     lower
330                 };
331                 if upper == Some(lower) {
332                     return result;
333                 }
334                 consumed += lower;
335                 if skip > 0 {
336                     self.nth(skip - 1);
337                 }
338             } else {
339                 let elem = self.next();
340                 if elem.is_none() {
341                     return result;
342                 }
343                 consumed += 1;
344                 if gen_index(rng, consumed) == 0 {
345                     result = elem;
346                 }
347             }
348 
349             let hint = self.size_hint();
350             lower = hint.0;
351             upper = hint.1;
352         }
353     }
354 
355     /// Choose one element at random from the iterator.
356     ///
357     /// Returns `None` if and only if the iterator is empty.
358     ///
359     /// This method is very similar to [`choose`] except that the result
360     /// only depends on the length of the iterator and the values produced by
361     /// `rng`. Notably for any iterator of a given length this will make the
362     /// same requests to `rng` and if the same sequence of values are produced
363     /// the same index will be selected from `self`. This may be useful if you
364     /// need consistent results no matter what type of iterator you are working
365     /// with. If you do not need this stability prefer [`choose`].
366     ///
367     /// Note that this method still uses [`Iterator::size_hint`] to skip
368     /// constructing elements where possible, however the selection and `rng`
369     /// calls are the same in the face of this optimization. If you want to
370     /// force every element to be created regardless call `.inspect(|e| ())`.
371     ///
372     /// [`choose`]: IteratorRandom::choose
choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item> where R: Rng + ?Sized373     fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
374     where R: Rng + ?Sized {
375         let mut consumed = 0;
376         let mut result = None;
377 
378         loop {
379             // Currently the only way to skip elements is `nth()`. So we need to
380             // store what index to access next here.
381             // This should be replaced by `advance_by()` once it is stable:
382             // https://github.com/rust-lang/rust/issues/77404
383             let mut next = 0;
384 
385             let (lower, _) = self.size_hint();
386             if lower >= 2 {
387                 let highest_selected = (0..lower)
388                     .filter(|ix| gen_index(rng, consumed+ix+1) == 0)
389                     .last();
390 
391                 consumed += lower;
392                 next = lower;
393 
394                 if let Some(ix) = highest_selected {
395                     result = self.nth(ix);
396                     next -= ix + 1;
397                     debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
398                 }
399             }
400 
401             let elem = self.nth(next);
402             if elem.is_none() {
403                 return result
404             }
405 
406             if gen_index(rng, consumed+1) == 0 {
407                 result = elem;
408             }
409             consumed += 1;
410         }
411     }
412 
413     /// Collects values at random from the iterator into a supplied buffer
414     /// until that buffer is filled.
415     ///
416     /// Although the elements are selected randomly, the order of elements in
417     /// the buffer is neither stable nor fully random. If random ordering is
418     /// desired, shuffle the result.
419     ///
420     /// Returns the number of elements added to the buffer. This equals the length
421     /// of the buffer unless the iterator contains insufficient elements, in which
422     /// case this equals the number of elements available.
423     ///
424     /// Complexity is `O(n)` where `n` is the length of the iterator.
425     /// For slices, prefer [`SliceRandom::choose_multiple`].
choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize where R: Rng + ?Sized426     fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
427     where R: Rng + ?Sized {
428         let amount = buf.len();
429         let mut len = 0;
430         while len < amount {
431             if let Some(elem) = self.next() {
432                 buf[len] = elem;
433                 len += 1;
434             } else {
435                 // Iterator exhausted; stop early
436                 return len;
437             }
438         }
439 
440         // Continue, since the iterator was not exhausted
441         for (i, elem) in self.enumerate() {
442             let k = gen_index(rng, i + 1 + amount);
443             if let Some(slot) = buf.get_mut(k) {
444                 *slot = elem;
445             }
446         }
447         len
448     }
449 
450     /// Collects `amount` values at random from the iterator into a vector.
451     ///
452     /// This is equivalent to `choose_multiple_fill` except for the result type.
453     ///
454     /// Although the elements are selected randomly, the order of elements in
455     /// the buffer is neither stable nor fully random. If random ordering is
456     /// desired, shuffle the result.
457     ///
458     /// The length of the returned vector equals `amount` unless the iterator
459     /// contains insufficient elements, in which case it equals the number of
460     /// elements available.
461     ///
462     /// Complexity is `O(n)` where `n` is the length of the iterator.
463     /// For slices, prefer [`SliceRandom::choose_multiple`].
464     #[cfg(feature = "alloc")]
465     #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item> where R: Rng + ?Sized466     fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
467     where R: Rng + ?Sized {
468         let mut reservoir = Vec::with_capacity(amount);
469         reservoir.extend(self.by_ref().take(amount));
470 
471         // Continue unless the iterator was exhausted
472         //
473         // note: this prevents iterators that "restart" from causing problems.
474         // If the iterator stops once, then so do we.
475         if reservoir.len() == amount {
476             for (i, elem) in self.enumerate() {
477                 let k = gen_index(rng, i + 1 + amount);
478                 if let Some(slot) = reservoir.get_mut(k) {
479                     *slot = elem;
480                 }
481             }
482         } else {
483             // Don't hang onto extra memory. There is a corner case where
484             // `amount` was much less than `self.len()`.
485             reservoir.shrink_to_fit();
486         }
487         reservoir
488     }
489 }
490 
491 
492 impl<T> SliceRandom for [T] {
493     type Item = T;
494 
choose<R>(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized495     fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item>
496     where R: Rng + ?Sized {
497         if self.is_empty() {
498             None
499         } else {
500             Some(&self[gen_index(rng, self.len())])
501         }
502     }
503 
choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized504     fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item>
505     where R: Rng + ?Sized {
506         if self.is_empty() {
507             None
508         } else {
509             let len = self.len();
510             Some(&mut self[gen_index(rng, len)])
511         }
512     }
513 
514     #[cfg(feature = "alloc")]
choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> where R: Rng + ?Sized515     fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
516     where R: Rng + ?Sized {
517         let amount = ::core::cmp::min(amount, self.len());
518         SliceChooseIter {
519             slice: self,
520             _phantom: Default::default(),
521             indices: index::sample(rng, self.len(), amount).into_iter(),
522         }
523     }
524 
525     #[cfg(feature = "alloc")]
choose_weighted<R, F, B, X>( &self, rng: &mut R, weight: F, ) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default,526     fn choose_weighted<R, F, B, X>(
527         &self, rng: &mut R, weight: F,
528     ) -> Result<&Self::Item, WeightedError>
529     where
530         R: Rng + ?Sized,
531         F: Fn(&Self::Item) -> B,
532         B: SampleBorrow<X>,
533         X: SampleUniform
534             + for<'a> ::core::ops::AddAssign<&'a X>
535             + ::core::cmp::PartialOrd<X>
536             + Clone
537             + Default,
538     {
539         use crate::distributions::{Distribution, WeightedIndex};
540         let distr = WeightedIndex::new(self.iter().map(weight))?;
541         Ok(&self[distr.sample(rng)])
542     }
543 
544     #[cfg(feature = "alloc")]
choose_weighted_mut<R, F, B, X>( &mut self, rng: &mut R, weight: F, ) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default,545     fn choose_weighted_mut<R, F, B, X>(
546         &mut self, rng: &mut R, weight: F,
547     ) -> Result<&mut Self::Item, WeightedError>
548     where
549         R: Rng + ?Sized,
550         F: Fn(&Self::Item) -> B,
551         B: SampleBorrow<X>,
552         X: SampleUniform
553             + for<'a> ::core::ops::AddAssign<&'a X>
554             + ::core::cmp::PartialOrd<X>
555             + Clone
556             + Default,
557     {
558         use crate::distributions::{Distribution, WeightedIndex};
559         let distr = WeightedIndex::new(self.iter().map(weight))?;
560         Ok(&mut self[distr.sample(rng)])
561     }
562 
563     #[cfg(feature = "std")]
choose_multiple_weighted<R, F, X>( &self, rng: &mut R, amount: usize, weight: F, ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> X, X: Into<f64>,564     fn choose_multiple_weighted<R, F, X>(
565         &self, rng: &mut R, amount: usize, weight: F,
566     ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
567     where
568         R: Rng + ?Sized,
569         F: Fn(&Self::Item) -> X,
570         X: Into<f64>,
571     {
572         let amount = ::core::cmp::min(amount, self.len());
573         Ok(SliceChooseIter {
574             slice: self,
575             _phantom: Default::default(),
576             indices: index::sample_weighted(
577                 rng,
578                 self.len(),
579                 |idx| weight(&self[idx]).into(),
580                 amount,
581             )?
582             .into_iter(),
583         })
584     }
585 
shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized586     fn shuffle<R>(&mut self, rng: &mut R)
587     where R: Rng + ?Sized {
588         for i in (1..self.len()).rev() {
589             // invariant: elements with index > i have been locked in place.
590             self.swap(i, gen_index(rng, i + 1));
591         }
592     }
593 
partial_shuffle<R>( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized594     fn partial_shuffle<R>(
595         &mut self, rng: &mut R, amount: usize,
596     ) -> (&mut [Self::Item], &mut [Self::Item])
597     where R: Rng + ?Sized {
598         // This applies Durstenfeld's algorithm for the
599         // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
600         // for an unbiased permutation, but exits early after choosing `amount`
601         // elements.
602 
603         let len = self.len();
604         let end = if amount >= len { 0 } else { len - amount };
605 
606         for i in (end..len).rev() {
607             // invariant: elements with index > i have been locked in place.
608             self.swap(i, gen_index(rng, i + 1));
609         }
610         let r = self.split_at_mut(end);
611         (r.1, r.0)
612     }
613 }
614 
615 impl<I> IteratorRandom for I where I: Iterator + Sized {}
616 
617 
/// Iterator yielding references to several randomly selected slice elements.
///
/// This struct is created by
/// [`SliceRandom::choose_multiple`](trait.SliceRandom.html#tymethod.choose_multiple).
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S, T>
where
    S: ?Sized + 'a,
    T: 'a,
{
    // The collection that the sampled indices point into.
    slice: &'a S,
    // Records the element type `T` without storing a value of it.
    _phantom: ::core::marker::PhantomData<T>,
    // The indices chosen in advance; one is consumed per yielded element.
    indices: index::IndexVecIntoIter,
}
630 
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    // Takes the next pre-sampled index and resolves it against the slice;
    // yields `None` once the indices run out.
    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        let i = self.indices.next()?;
        Some(&self.slice[i as usize])
    }

    // The number of indices left is known, so the hint is exact.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
644 
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    // One element is yielded per remaining index.
    fn len(&self) -> usize {
        self.indices.len()
    }
}
653 
654 
655 // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where
656 // possible, primarily in order to produce the same output on 32-bit and 64-bit
657 // platforms.
658 #[inline]
gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize659 fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize {
660     if ubound <= (core::u32::MAX as usize) {
661         rng.gen_range(0..ubound as u32) as usize
662     } else {
663         rng.gen_range(0..ubound)
664     }
665 }
666 
667 
668 #[cfg(test)]
669 mod test {
670     use super::*;
671     #[cfg(feature = "alloc")] use crate::Rng;
672     #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec;
673 
674     #[test]
test_slice_choose()675     fn test_slice_choose() {
676         let mut r = crate::test::rng(107);
677         let chars = [
678             'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
679         ];
680         let mut chosen = [0i32; 14];
681         // The below all use a binomial distribution with n=1000, p=1/14.
682         // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
683         for _ in 0..1000 {
684             let picked = *chars.choose(&mut r).unwrap();
685             chosen[(picked as usize) - ('a' as usize)] += 1;
686         }
687         for count in chosen.iter() {
688             assert!(40 < *count && *count < 106);
689         }
690 
691         chosen.iter_mut().for_each(|x| *x = 0);
692         for _ in 0..1000 {
693             *chosen.choose_mut(&mut r).unwrap() += 1;
694         }
695         for count in chosen.iter() {
696             assert!(40 < *count && *count < 106);
697         }
698 
699         let mut v: [isize; 0] = [];
700         assert_eq!(v.choose(&mut r), None);
701         assert_eq!(v.choose_mut(&mut r), None);
702     }
703 
704     #[test]
value_stability_slice()705     fn value_stability_slice() {
706         let mut r = crate::test::rng(413);
707         let chars = [
708             'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
709         ];
710         let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
711 
712         assert_eq!(chars.choose(&mut r), Some(&'l'));
713         assert_eq!(nums.choose_mut(&mut r), Some(&mut 10));
714 
715         #[cfg(feature = "alloc")]
716         assert_eq!(
717             &chars
718                 .choose_multiple(&mut r, 8)
719                 .cloned()
720                 .collect::<Vec<char>>(),
721             &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e']
722         );
723 
724         #[cfg(feature = "alloc")]
725         assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f'));
726         #[cfg(feature = "alloc")]
727         assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5));
728 
729         let mut r = crate::test::rng(414);
730         nums.shuffle(&mut r);
731         assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);
732         nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
733         let res = nums.partial_shuffle(&mut r, 6);
734         assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]);
735         assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]);
736     }
737 
738     #[derive(Clone)]
739     struct UnhintedIterator<I: Iterator + Clone> {
740         iter: I,
741     }
742     impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
743         type Item = I::Item;
744 
next(&mut self) -> Option<Self::Item>745         fn next(&mut self) -> Option<Self::Item> {
746             self.iter.next()
747         }
748     }
749 
750     #[derive(Clone)]
751     struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
752         iter: I,
753         chunk_remaining: usize,
754         chunk_size: usize,
755         hint_total_size: bool,
756     }
757     impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
758         type Item = I::Item;
759 
next(&mut self) -> Option<Self::Item>760         fn next(&mut self) -> Option<Self::Item> {
761             if self.chunk_remaining == 0 {
762                 self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len());
763             }
764             self.chunk_remaining = self.chunk_remaining.saturating_sub(1);
765 
766             self.iter.next()
767         }
768 
size_hint(&self) -> (usize, Option<usize>)769         fn size_hint(&self) -> (usize, Option<usize>) {
770             (
771                 self.chunk_remaining,
772                 if self.hint_total_size {
773                     Some(self.iter.len())
774                 } else {
775                     None
776                 },
777             )
778         }
779     }
780 
781     #[derive(Clone)]
782     struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
783         iter: I,
784         window_size: usize,
785         hint_total_size: bool,
786     }
787     impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
788         type Item = I::Item;
789 
next(&mut self) -> Option<Self::Item>790         fn next(&mut self) -> Option<Self::Item> {
791             self.iter.next()
792         }
793 
size_hint(&self) -> (usize, Option<usize>)794         fn size_hint(&self) -> (usize, Option<usize>) {
795             (
796                 ::core::cmp::min(self.iter.len(), self.window_size),
797                 if self.hint_total_size {
798                     Some(self.iter.len())
799                 } else {
800                     None
801                 },
802             )
803         }
804     }
805 
806     #[test]
807     #[cfg_attr(miri, ignore)] // Miri is too slow
test_iterator_choose()808     fn test_iterator_choose() {
809         let r = &mut crate::test::rng(109);
810         fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
811             let mut chosen = [0i32; 9];
812             for _ in 0..1000 {
813                 let picked = iter.clone().choose(r).unwrap();
814                 chosen[picked] += 1;
815             }
816             for count in chosen.iter() {
817                 // Samples should follow Binomial(1000, 1/9)
818                 // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
819                 // Note: have seen 153, which is unlikely but not impossible.
820                 assert!(
821                     72 < *count && *count < 154,
822                     "count not close to 1000/9: {}",
823                     count
824                 );
825             }
826         }
827 
828         test_iter(r, 0..9);
829         test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
830         #[cfg(feature = "alloc")]
831         test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
832         test_iter(r, UnhintedIterator { iter: 0..9 });
833         test_iter(r, ChunkHintedIterator {
834             iter: 0..9,
835             chunk_size: 4,
836             chunk_remaining: 4,
837             hint_total_size: false,
838         });
839         test_iter(r, ChunkHintedIterator {
840             iter: 0..9,
841             chunk_size: 4,
842             chunk_remaining: 4,
843             hint_total_size: true,
844         });
845         test_iter(r, WindowHintedIterator {
846             iter: 0..9,
847             window_size: 2,
848             hint_total_size: false,
849         });
850         test_iter(r, WindowHintedIterator {
851             iter: 0..9,
852             window_size: 2,
853             hint_total_size: true,
854         });
855 
856         assert_eq!((0..0).choose(r), None);
857         assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
858     }
859 
860     #[test]
861     #[cfg_attr(miri, ignore)] // Miri is too slow
test_iterator_choose_stable()862     fn test_iterator_choose_stable() {
863         let r = &mut crate::test::rng(109);
864         fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
865             let mut chosen = [0i32; 9];
866             for _ in 0..1000 {
867                 let picked = iter.clone().choose_stable(r).unwrap();
868                 chosen[picked] += 1;
869             }
870             for count in chosen.iter() {
871                 // Samples should follow Binomial(1000, 1/9)
872                 // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
873                 // Note: have seen 153, which is unlikely but not impossible.
874                 assert!(
875                     72 < *count && *count < 154,
876                     "count not close to 1000/9: {}",
877                     count
878                 );
879             }
880         }
881 
882         test_iter(r, 0..9);
883         test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
884         #[cfg(feature = "alloc")]
885         test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
886         test_iter(r, UnhintedIterator { iter: 0..9 });
887         test_iter(r, ChunkHintedIterator {
888             iter: 0..9,
889             chunk_size: 4,
890             chunk_remaining: 4,
891             hint_total_size: false,
892         });
893         test_iter(r, ChunkHintedIterator {
894             iter: 0..9,
895             chunk_size: 4,
896             chunk_remaining: 4,
897             hint_total_size: true,
898         });
899         test_iter(r, WindowHintedIterator {
900             iter: 0..9,
901             window_size: 2,
902             hint_total_size: false,
903         });
904         test_iter(r, WindowHintedIterator {
905             iter: 0..9,
906             window_size: 2,
907             hint_total_size: true,
908         });
909 
910         assert_eq!((0..0).choose(r), None);
911         assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
912     }
913 
914     #[test]
915     #[cfg_attr(miri, ignore)] // Miri is too slow
test_iterator_choose_stable_stability()916     fn test_iterator_choose_stable_stability() {
917         fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
918             let r = &mut crate::test::rng(109);
919             let mut chosen = [0i32; 9];
920             for _ in 0..1000 {
921                 let picked = iter.clone().choose_stable(r).unwrap();
922                 chosen[picked] += 1;
923             }
924             chosen
925         }
926 
927         let reference = test_iter(0..9);
928         assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference);
929 
930         #[cfg(feature = "alloc")]
931         assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
932         assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
933         assert_eq!(test_iter(ChunkHintedIterator {
934             iter: 0..9,
935             chunk_size: 4,
936             chunk_remaining: 4,
937             hint_total_size: false,
938         }), reference);
939         assert_eq!(test_iter(ChunkHintedIterator {
940             iter: 0..9,
941             chunk_size: 4,
942             chunk_remaining: 4,
943             hint_total_size: true,
944         }), reference);
945         assert_eq!(test_iter(WindowHintedIterator {
946             iter: 0..9,
947             window_size: 2,
948             hint_total_size: false,
949         }), reference);
950         assert_eq!(test_iter(WindowHintedIterator {
951             iter: 0..9,
952             window_size: 2,
953             hint_total_size: true,
954         }), reference);
955     }
956 
957     #[test]
958     #[cfg_attr(miri, ignore)] // Miri is too slow
test_shuffle()959     fn test_shuffle() {
960         let mut r = crate::test::rng(108);
961         let empty: &mut [isize] = &mut [];
962         empty.shuffle(&mut r);
963         let mut one = [1];
964         one.shuffle(&mut r);
965         let b: &[_] = &[1];
966         assert_eq!(one, b);
967 
968         let mut two = [1, 2];
969         two.shuffle(&mut r);
970         assert!(two == [1, 2] || two == [2, 1]);
971 
972         fn move_last(slice: &mut [usize], pos: usize) {
973             // use slice[pos..].rotate_left(1); once we can use that
974             let last_val = slice[pos];
975             for i in pos..slice.len() - 1 {
976                 slice[i] = slice[i + 1];
977             }
978             *slice.last_mut().unwrap() = last_val;
979         }
980         let mut counts = [0i32; 24];
981         for _ in 0..10000 {
982             let mut arr: [usize; 4] = [0, 1, 2, 3];
983             arr.shuffle(&mut r);
984             let mut permutation = 0usize;
985             let mut pos_value = counts.len();
986             for i in 0..4 {
987                 pos_value /= 4 - i;
988                 let pos = arr.iter().position(|&x| x == i).unwrap();
989                 assert!(pos < (4 - i));
990                 permutation += pos * pos_value;
991                 move_last(&mut arr, pos);
992                 assert_eq!(arr[3], i);
993             }
994             for (i, &a) in arr.iter().enumerate() {
995                 assert_eq!(a, i);
996             }
997             counts[permutation] += 1;
998         }
999         for count in counts.iter() {
1000             // Binomial(10000, 1/24) with average 416.667
1001             // Octave: binocdf(n, 10000, 1/24)
1002             // 99.9% chance samples lie within this range:
1003             assert!(352 <= *count && *count <= 483, "count: {}", count);
1004         }
1005     }
1006 
1007     #[test]
test_partial_shuffle()1008     fn test_partial_shuffle() {
1009         let mut r = crate::test::rng(118);
1010 
1011         let mut empty: [u32; 0] = [];
1012         let res = empty.partial_shuffle(&mut r, 10);
1013         assert_eq!((res.0.len(), res.1.len()), (0, 0));
1014 
1015         let mut v = [1, 2, 3, 4, 5];
1016         let res = v.partial_shuffle(&mut r, 2);
1017         assert_eq!((res.0.len(), res.1.len()), (2, 3));
1018         assert!(res.0[0] != res.0[1]);
1019         // First elements are only modified if selected, so at least one isn't modified:
1020         assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3);
1021     }
1022 
1023     #[test]
1024     #[cfg(feature = "alloc")]
test_sample_iter()1025     fn test_sample_iter() {
1026         let min_val = 1;
1027         let max_val = 100;
1028 
1029         let mut r = crate::test::rng(401);
1030         let vals = (min_val..max_val).collect::<Vec<i32>>();
1031         let small_sample = vals.iter().choose_multiple(&mut r, 5);
1032         let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);
1033 
1034         assert_eq!(small_sample.len(), 5);
1035         assert_eq!(large_sample.len(), vals.len());
1036         // no randomization happens when amount >= len
1037         assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());
1038 
1039         assert!(small_sample
1040             .iter()
1041             .all(|e| { **e >= min_val && **e <= max_val }));
1042     }
1043 
1044     #[test]
1045     #[cfg(feature = "alloc")]
1046     #[cfg_attr(miri, ignore)] // Miri is too slow
test_weighted()1047     fn test_weighted() {
1048         let mut r = crate::test::rng(406);
1049         const N_REPS: u32 = 3000;
1050         let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
1051         let total_weight = weights.iter().sum::<u32>() as f32;
1052 
1053         let verify = |result: [i32; 14]| {
1054             for (i, count) in result.iter().enumerate() {
1055                 let exp = (weights[i] * N_REPS) as f32 / total_weight;
1056                 let mut err = (*count as f32 - exp).abs();
1057                 if err != 0.0 {
1058                     err /= exp;
1059                 }
1060                 assert!(err <= 0.25);
1061             }
1062         };
1063 
1064         // choose_weighted
1065         fn get_weight<T>(item: &(u32, T)) -> u32 {
1066             item.0
1067         }
1068         let mut chosen = [0i32; 14];
1069         let mut items = [(0u32, 0usize); 14]; // (weight, index)
1070         for (i, item) in items.iter_mut().enumerate() {
1071             *item = (weights[i], i);
1072         }
1073         for _ in 0..N_REPS {
1074             let item = items.choose_weighted(&mut r, get_weight).unwrap();
1075             chosen[item.1] += 1;
1076         }
1077         verify(chosen);
1078 
1079         // choose_weighted_mut
1080         let mut items = [(0u32, 0i32); 14]; // (weight, count)
1081         for (i, item) in items.iter_mut().enumerate() {
1082             *item = (weights[i], 0);
1083         }
1084         for _ in 0..N_REPS {
1085             items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1;
1086         }
1087         for (ch, item) in chosen.iter_mut().zip(items.iter()) {
1088             *ch = item.1;
1089         }
1090         verify(chosen);
1091 
1092         // Check error cases
1093         let empty_slice = &mut [10][0..0];
1094         assert_eq!(
1095             empty_slice.choose_weighted(&mut r, |_| 1),
1096             Err(WeightedError::NoItem)
1097         );
1098         assert_eq!(
1099             empty_slice.choose_weighted_mut(&mut r, |_| 1),
1100             Err(WeightedError::NoItem)
1101         );
1102         assert_eq!(
1103             ['x'].choose_weighted_mut(&mut r, |_| 0),
1104             Err(WeightedError::AllWeightsZero)
1105         );
1106         assert_eq!(
1107             [0, -1].choose_weighted_mut(&mut r, |x| *x),
1108             Err(WeightedError::InvalidWeight)
1109         );
1110         assert_eq!(
1111             [-1, 0].choose_weighted_mut(&mut r, |x| *x),
1112             Err(WeightedError::InvalidWeight)
1113         );
1114     }
1115 
1116     #[test]
value_stability_choose()1117     fn value_stability_choose() {
1118         fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
1119             let mut rng = crate::test::rng(411);
1120             iter.choose(&mut rng)
1121         }
1122 
1123         assert_eq!(choose([].iter().cloned()), None);
1124         assert_eq!(choose(0..100), Some(33));
1125         assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40));
1126         assert_eq!(
1127             choose(ChunkHintedIterator {
1128                 iter: 0..100,
1129                 chunk_size: 32,
1130                 chunk_remaining: 32,
1131                 hint_total_size: false,
1132             }),
1133             Some(39)
1134         );
1135         assert_eq!(
1136             choose(ChunkHintedIterator {
1137                 iter: 0..100,
1138                 chunk_size: 32,
1139                 chunk_remaining: 32,
1140                 hint_total_size: true,
1141             }),
1142             Some(39)
1143         );
1144         assert_eq!(
1145             choose(WindowHintedIterator {
1146                 iter: 0..100,
1147                 window_size: 32,
1148                 hint_total_size: false,
1149             }),
1150             Some(90)
1151         );
1152         assert_eq!(
1153             choose(WindowHintedIterator {
1154                 iter: 0..100,
1155                 window_size: 32,
1156                 hint_total_size: true,
1157             }),
1158             Some(90)
1159         );
1160     }
1161 
1162     #[test]
value_stability_choose_stable()1163     fn value_stability_choose_stable() {
1164         fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
1165             let mut rng = crate::test::rng(411);
1166             iter.choose_stable(&mut rng)
1167         }
1168 
1169         assert_eq!(choose([].iter().cloned()), None);
1170         assert_eq!(choose(0..100), Some(40));
1171         assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40));
1172         assert_eq!(
1173             choose(ChunkHintedIterator {
1174                 iter: 0..100,
1175                 chunk_size: 32,
1176                 chunk_remaining: 32,
1177                 hint_total_size: false,
1178             }),
1179             Some(40)
1180         );
1181         assert_eq!(
1182             choose(ChunkHintedIterator {
1183                 iter: 0..100,
1184                 chunk_size: 32,
1185                 chunk_remaining: 32,
1186                 hint_total_size: true,
1187             }),
1188             Some(40)
1189         );
1190         assert_eq!(
1191             choose(WindowHintedIterator {
1192                 iter: 0..100,
1193                 window_size: 32,
1194                 hint_total_size: false,
1195             }),
1196             Some(40)
1197         );
1198         assert_eq!(
1199             choose(WindowHintedIterator {
1200                 iter: 0..100,
1201                 window_size: 32,
1202                 hint_total_size: true,
1203             }),
1204             Some(40)
1205         );
1206     }
1207 
1208     #[test]
value_stability_choose_multiple()1209     fn value_stability_choose_multiple() {
1210         fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) {
1211             let mut rng = crate::test::rng(412);
1212             let mut buf = [0u32; 8];
1213             assert_eq!(iter.choose_multiple_fill(&mut rng, &mut buf), v.len());
1214             assert_eq!(&buf[0..v.len()], v);
1215         }
1216 
1217         do_test(0..4, &[0, 1, 2, 3]);
1218         do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
1219         do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]);
1220 
1221         #[cfg(feature = "alloc")]
1222         {
1223             fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) {
1224                 let mut rng = crate::test::rng(412);
1225                 assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
1226             }
1227 
1228             do_test(0..4, &[0, 1, 2, 3]);
1229             do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
1230             do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]);
1231         }
1232     }
1233 
1234     #[test]
1235     #[cfg(feature = "std")]
test_multiple_weighted_edge_cases()1236     fn test_multiple_weighted_edge_cases() {
1237         use super::*;
1238 
1239         let mut rng = crate::test::rng(413);
1240 
1241         // Case 1: One of the weights is 0
1242         let choices = [('a', 2), ('b', 1), ('c', 0)];
1243         for _ in 0..100 {
1244             let result = choices
1245                 .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1246                 .unwrap()
1247                 .collect::<Vec<_>>();
1248 
1249             assert_eq!(result.len(), 2);
1250             assert!(!result.iter().any(|val| val.0 == 'c'));
1251         }
1252 
1253         // Case 2: All of the weights are 0
1254         let choices = [('a', 0), ('b', 0), ('c', 0)];
1255 
1256         assert_eq!(choices
1257             .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1258             .unwrap().count(), 2);
1259 
1260         // Case 3: Negative weights
1261         let choices = [('a', -1), ('b', 1), ('c', 1)];
1262         assert_eq!(
1263             choices
1264                 .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1265                 .unwrap_err(),
1266             WeightedError::InvalidWeight
1267         );
1268 
1269         // Case 4: Empty list
1270         let choices = [];
1271         assert_eq!(choices
1272             .choose_multiple_weighted(&mut rng, 0, |_: &()| 0)
1273             .unwrap().count(), 0);
1274 
1275         // Case 5: NaN weights
1276         let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];
1277         assert_eq!(
1278             choices
1279                 .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1280                 .unwrap_err(),
1281             WeightedError::InvalidWeight
1282         );
1283 
1284         // Case 6: +infinity weights
1285         let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)];
1286         for _ in 0..100 {
1287             let result = choices
1288                 .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1289                 .unwrap()
1290                 .collect::<Vec<_>>();
1291             assert_eq!(result.len(), 2);
1292             assert!(result.iter().any(|val| val.0 == 'a'));
1293         }
1294 
1295         // Case 7: -infinity weights
1296         let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
1297         assert_eq!(
1298             choices
1299                 .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1300                 .unwrap_err(),
1301             WeightedError::InvalidWeight
1302         );
1303 
1304         // Case 8: -0 weights
1305         let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
1306         assert!(choices
1307             .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1308             .is_ok());
1309     }
1310 
1311     #[test]
1312     #[cfg(feature = "std")]
test_multiple_weighted_distributions()1313     fn test_multiple_weighted_distributions() {
1314         use super::*;
1315 
1316         // The theoretical probabilities of the different outcomes are:
1317         // AB: 0.5  * 0.5  = 0.250
1318         // AC: 0.5  * 0.5  = 0.250
1319         // BA: 0.25 * 0.67 = 0.167
1320         // BC: 0.25 * 0.33 = 0.082
1321         // CA: 0.25 * 0.67 = 0.167
1322         // CB: 0.25 * 0.33 = 0.082
1323         let choices = [('a', 2), ('b', 1), ('c', 1)];
1324         let mut rng = crate::test::rng(414);
1325 
1326         let mut results = [0i32; 3];
1327         let expected_results = [4167, 4167, 1666];
1328         for _ in 0..10000 {
1329             let result = choices
1330                 .choose_multiple_weighted(&mut rng, 2, |item| item.1)
1331                 .unwrap()
1332                 .collect::<Vec<_>>();
1333 
1334             assert_eq!(result.len(), 2);
1335 
1336             match (result[0].0, result[1].0) {
1337                 ('a', 'b') | ('b', 'a') => {
1338                     results[0] += 1;
1339                 }
1340                 ('a', 'c') | ('c', 'a') => {
1341                     results[1] += 1;
1342                 }
1343                 ('b', 'c') | ('c', 'b') => {
1344                     results[2] += 1;
1345                 }
1346                 (_, _) => panic!("unexpected result"),
1347             }
1348         }
1349 
1350         let mut diffs = results
1351             .iter()
1352             .zip(&expected_results)
1353             .map(|(a, b)| (a - b).abs());
1354         assert!(!diffs.any(|deviation| deviation > 100));
1355     }
1356 }
1357