1 // Copyright 2018 Developers of the Rand project.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8
9 //! Sequence-related functionality
10 //!
11 //! This module provides:
12 //!
13 //! * [`SliceRandom`] slice sampling and mutation
14 //! * [`IteratorRandom`] iterator sampling
15 //! * [`index::sample`] low-level API to choose multiple indices from
16 //! `0..length`
17 //!
18 //! Also see:
19 //!
20 //! * [`crate::distributions::WeightedIndex`] distribution which provides
21 //! weighted index sampling.
22 //!
//! In order to make results reproducible across 32-bit and 64-bit
//! architectures, all `usize` indices are sampled as a `u32` where possible
//! (also providing a small performance boost in some cases).
26
27
28 #[cfg(feature = "alloc")]
29 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
30 pub mod index;
31
32 #[cfg(feature = "alloc")] use core::ops::Index;
33
34 #[cfg(feature = "alloc")] use alloc::vec::Vec;
35
36 #[cfg(feature = "alloc")]
37 use crate::distributions::uniform::{SampleBorrow, SampleUniform};
38 #[cfg(feature = "alloc")] use crate::distributions::WeightedError;
39 use crate::Rng;
40
41 /// Extension trait on slices, providing random mutation and sampling methods.
42 ///
43 /// This trait is implemented on all `[T]` slice types, providing several
44 /// methods for choosing and shuffling elements. You must `use` this trait:
45 ///
46 /// ```
47 /// use rand::seq::SliceRandom;
48 ///
49 /// let mut rng = rand::thread_rng();
50 /// let mut bytes = "Hello, random!".to_string().into_bytes();
51 /// bytes.shuffle(&mut rng);
52 /// let str = String::from_utf8(bytes).unwrap();
53 /// println!("{}", str);
54 /// ```
55 /// Example output (non-deterministic):
56 /// ```none
57 /// l,nmroHado !le
58 /// ```
59 pub trait SliceRandom {
60 /// The element type.
61 type Item;
62
63 /// Returns a reference to one random element of the slice, or `None` if the
64 /// slice is empty.
65 ///
66 /// For slices, complexity is `O(1)`.
67 ///
68 /// # Example
69 ///
70 /// ```
71 /// use rand::thread_rng;
72 /// use rand::seq::SliceRandom;
73 ///
74 /// let choices = [1, 2, 4, 8, 16, 32];
75 /// let mut rng = thread_rng();
76 /// println!("{:?}", choices.choose(&mut rng));
77 /// assert_eq!(choices[..0].choose(&mut rng), None);
78 /// ```
choose<R>(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized79 fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item>
80 where R: Rng + ?Sized;
81
82 /// Returns a mutable reference to one random element of the slice, or
83 /// `None` if the slice is empty.
84 ///
85 /// For slices, complexity is `O(1)`.
choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized86 fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item>
87 where R: Rng + ?Sized;
88
89 /// Chooses `amount` elements from the slice at random, without repetition,
90 /// and in random order. The returned iterator is appropriate both for
91 /// collection into a `Vec` and filling an existing buffer (see example).
92 ///
93 /// In case this API is not sufficiently flexible, use [`index::sample`].
94 ///
95 /// For slices, complexity is the same as [`index::sample`].
96 ///
97 /// # Example
98 /// ```
99 /// use rand::seq::SliceRandom;
100 ///
101 /// let mut rng = &mut rand::thread_rng();
102 /// let sample = "Hello, audience!".as_bytes();
103 ///
104 /// // collect the results into a vector:
105 /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect();
106 ///
107 /// // store in a buffer:
108 /// let mut buf = [0u8; 5];
109 /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) {
110 /// *slot = *b;
111 /// }
112 /// ```
113 #[cfg(feature = "alloc")]
114 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> where R: Rng + ?Sized115 fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
116 where R: Rng + ?Sized;
117
118 /// Similar to [`choose`], but where the likelihood of each outcome may be
119 /// specified.
120 ///
121 /// The specified function `weight` maps each item `x` to a relative
122 /// likelihood `weight(x)`. The probability of each item being selected is
123 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
124 ///
125 /// For slices of length `n`, complexity is `O(n)`.
126 /// See also [`choose_weighted_mut`], [`distributions::weighted`].
127 ///
128 /// # Example
129 ///
130 /// ```
131 /// use rand::prelude::*;
132 ///
133 /// let choices = [('a', 2), ('b', 1), ('c', 1)];
134 /// let mut rng = thread_rng();
135 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
136 /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0);
137 /// ```
138 /// [`choose`]: SliceRandom::choose
139 /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut
140 /// [`distributions::weighted`]: crate::distributions::weighted
141 #[cfg(feature = "alloc")]
142 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_weighted<R, F, B, X>( &self, rng: &mut R, weight: F, ) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default143 fn choose_weighted<R, F, B, X>(
144 &self, rng: &mut R, weight: F,
145 ) -> Result<&Self::Item, WeightedError>
146 where
147 R: Rng + ?Sized,
148 F: Fn(&Self::Item) -> B,
149 B: SampleBorrow<X>,
150 X: SampleUniform
151 + for<'a> ::core::ops::AddAssign<&'a X>
152 + ::core::cmp::PartialOrd<X>
153 + Clone
154 + Default;
155
156 /// Similar to [`choose_mut`], but where the likelihood of each outcome may
157 /// be specified.
158 ///
159 /// The specified function `weight` maps each item `x` to a relative
160 /// likelihood `weight(x)`. The probability of each item being selected is
161 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
162 ///
163 /// For slices of length `n`, complexity is `O(n)`.
164 /// See also [`choose_weighted`], [`distributions::weighted`].
165 ///
166 /// [`choose_mut`]: SliceRandom::choose_mut
167 /// [`choose_weighted`]: SliceRandom::choose_weighted
168 /// [`distributions::weighted`]: crate::distributions::weighted
169 #[cfg(feature = "alloc")]
170 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_weighted_mut<R, F, B, X>( &mut self, rng: &mut R, weight: F, ) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default171 fn choose_weighted_mut<R, F, B, X>(
172 &mut self, rng: &mut R, weight: F,
173 ) -> Result<&mut Self::Item, WeightedError>
174 where
175 R: Rng + ?Sized,
176 F: Fn(&Self::Item) -> B,
177 B: SampleBorrow<X>,
178 X: SampleUniform
179 + for<'a> ::core::ops::AddAssign<&'a X>
180 + ::core::cmp::PartialOrd<X>
181 + Clone
182 + Default;
183
184 /// Similar to [`choose_multiple`], but where the likelihood of each element's
185 /// inclusion in the output may be specified. The elements are returned in an
186 /// arbitrary, unspecified order.
187 ///
188 /// The specified function `weight` maps each item `x` to a relative
189 /// likelihood `weight(x)`. The probability of each item being selected is
190 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
191 ///
192 /// If all of the weights are equal, even if they are all zero, each element has
193 /// an equal likelihood of being selected.
194 ///
195 /// The complexity of this method depends on the feature `partition_at_index`.
196 /// If the feature is enabled, then for slices of length `n`, the complexity
197 /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and
198 /// `O(n * log amount)` time.
199 ///
200 /// # Example
201 ///
202 /// ```
203 /// use rand::prelude::*;
204 ///
205 /// let choices = [('a', 2), ('b', 1), ('c', 1)];
206 /// let mut rng = thread_rng();
207 /// // First Draw * Second Draw = total odds
208 /// // -----------------------
209 /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
210 /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
211 /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
212 /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
213 /// ```
214 /// [`choose_multiple`]: SliceRandom::choose_multiple
215 #[cfg(feature = "alloc")]
choose_multiple_weighted<R, F, X>( &self, rng: &mut R, amount: usize, weight: F, ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> X, X: Into<f64>216 fn choose_multiple_weighted<R, F, X>(
217 &self, rng: &mut R, amount: usize, weight: F,
218 ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
219 where
220 R: Rng + ?Sized,
221 F: Fn(&Self::Item) -> X,
222 X: Into<f64>;
223
224 /// Shuffle a mutable slice in place.
225 ///
226 /// For slices of length `n`, complexity is `O(n)`.
227 ///
228 /// # Example
229 ///
230 /// ```
231 /// use rand::seq::SliceRandom;
232 /// use rand::thread_rng;
233 ///
234 /// let mut rng = thread_rng();
235 /// let mut y = [1, 2, 3, 4, 5];
236 /// println!("Unshuffled: {:?}", y);
237 /// y.shuffle(&mut rng);
238 /// println!("Shuffled: {:?}", y);
239 /// ```
shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized240 fn shuffle<R>(&mut self, rng: &mut R)
241 where R: Rng + ?Sized;
242
243 /// Shuffle a slice in place, but exit early.
244 ///
245 /// Returns two mutable slices from the source slice. The first contains
246 /// `amount` elements randomly permuted. The second has the remaining
247 /// elements that are not fully shuffled.
248 ///
249 /// This is an efficient method to select `amount` elements at random from
250 /// the slice, provided the slice may be mutated.
251 ///
252 /// If you only need to choose elements randomly and `amount > self.len()/2`
253 /// then you may improve performance by taking
254 /// `amount = values.len() - amount` and using only the second slice.
255 ///
256 /// If `amount` is greater than the number of elements in the slice, this
257 /// will perform a full shuffle.
258 ///
259 /// For slices, complexity is `O(m)` where `m = amount`.
partial_shuffle<R>( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized260 fn partial_shuffle<R>(
261 &mut self, rng: &mut R, amount: usize,
262 ) -> (&mut [Self::Item], &mut [Self::Item])
263 where R: Rng + ?Sized;
264 }
265
266 /// Extension trait on iterators, providing random sampling methods.
267 ///
268 /// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
269 /// and provides methods for
270 /// choosing one or more elements. You must `use` this trait:
271 ///
272 /// ```
273 /// use rand::seq::IteratorRandom;
274 ///
275 /// let mut rng = rand::thread_rng();
276 ///
277 /// let faces = "";
278 /// println!("I am {}!", faces.chars().choose(&mut rng).unwrap());
279 /// ```
280 /// Example output (non-deterministic):
281 /// ```none
282 /// I am !
283 /// ```
284 pub trait IteratorRandom: Iterator + Sized {
285 /// Choose one element at random from the iterator.
286 ///
287 /// Returns `None` if and only if the iterator is empty.
288 ///
289 /// This method uses [`Iterator::size_hint`] for optimisation. With an
290 /// accurate hint and where [`Iterator::nth`] is a constant-time operation
291 /// this method can offer `O(1)` performance. Where no size hint is
292 /// available, complexity is `O(n)` where `n` is the iterator length.
293 /// Partial hints (where `lower > 0`) also improve performance.
294 ///
295 /// Note that the output values and the number of RNG samples used
296 /// depends on size hints. In particular, `Iterator` combinators that don't
297 /// change the values yielded but change the size hints may result in
298 /// `choose` returning different elements. If you want consistent results
299 /// and RNG usage consider using [`IteratorRandom::choose_stable`].
choose<R>(mut self, rng: &mut R) -> Option<Self::Item> where R: Rng + ?Sized300 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
301 where R: Rng + ?Sized {
302 let (mut lower, mut upper) = self.size_hint();
303 let mut consumed = 0;
304 let mut result = None;
305
306 // Handling for this condition outside the loop allows the optimizer to eliminate the loop
307 // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
308 // seq_iter_choose_from_1000.
309 if upper == Some(lower) {
310 return if lower == 0 {
311 None
312 } else {
313 self.nth(gen_index(rng, lower))
314 };
315 }
316
317 // Continue until the iterator is exhausted
318 loop {
319 if lower > 1 {
320 let ix = gen_index(rng, lower + consumed);
321 let skip = if ix < lower {
322 result = self.nth(ix);
323 lower - (ix + 1)
324 } else {
325 lower
326 };
327 if upper == Some(lower) {
328 return result;
329 }
330 consumed += lower;
331 if skip > 0 {
332 self.nth(skip - 1);
333 }
334 } else {
335 let elem = self.next();
336 if elem.is_none() {
337 return result;
338 }
339 consumed += 1;
340 if gen_index(rng, consumed) == 0 {
341 result = elem;
342 }
343 }
344
345 let hint = self.size_hint();
346 lower = hint.0;
347 upper = hint.1;
348 }
349 }
350
351 /// Choose one element at random from the iterator.
352 ///
353 /// Returns `None` if and only if the iterator is empty.
354 ///
355 /// This method is very similar to [`choose`] except that the result
356 /// only depends on the length of the iterator and the values produced by
357 /// `rng`. Notably for any iterator of a given length this will make the
358 /// same requests to `rng` and if the same sequence of values are produced
359 /// the same index will be selected from `self`. This may be useful if you
360 /// need consistent results no matter what type of iterator you are working
361 /// with. If you do not need this stability prefer [`choose`].
362 ///
363 /// Note that this method still uses [`Iterator::size_hint`] to skip
364 /// constructing elements where possible, however the selection and `rng`
365 /// calls are the same in the face of this optimization. If you want to
366 /// force every element to be created regardless call `.inspect(|e| ())`.
367 ///
368 /// [`choose`]: IteratorRandom::choose
choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item> where R: Rng + ?Sized369 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
370 where R: Rng + ?Sized {
371 let mut consumed = 0;
372 let mut result = None;
373
374 loop {
375 // Currently the only way to skip elements is `nth()`. So we need to
376 // store what index to access next here.
377 // This should be replaced by `advance_by()` once it is stable:
378 // https://github.com/rust-lang/rust/issues/77404
379 let mut next = 0;
380
381 let (lower, _) = self.size_hint();
382 if lower >= 2 {
383 let highest_selected = (0..lower)
384 .filter(|ix| gen_index(rng, consumed+ix+1) == 0)
385 .last();
386
387 consumed += lower;
388 next = lower;
389
390 if let Some(ix) = highest_selected {
391 result = self.nth(ix);
392 next -= ix + 1;
393 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
394 }
395 }
396
397 let elem = self.nth(next);
398 if elem.is_none() {
399 return result
400 }
401
402 if gen_index(rng, consumed+1) == 0 {
403 result = elem;
404 }
405 consumed += 1;
406 }
407 }
408
409 /// Collects values at random from the iterator into a supplied buffer
410 /// until that buffer is filled.
411 ///
412 /// Although the elements are selected randomly, the order of elements in
413 /// the buffer is neither stable nor fully random. If random ordering is
414 /// desired, shuffle the result.
415 ///
416 /// Returns the number of elements added to the buffer. This equals the length
417 /// of the buffer unless the iterator contains insufficient elements, in which
418 /// case this equals the number of elements available.
419 ///
420 /// Complexity is `O(n)` where `n` is the length of the iterator.
421 /// For slices, prefer [`SliceRandom::choose_multiple`].
choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize where R: Rng + ?Sized422 fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
423 where R: Rng + ?Sized {
424 let amount = buf.len();
425 let mut len = 0;
426 while len < amount {
427 if let Some(elem) = self.next() {
428 buf[len] = elem;
429 len += 1;
430 } else {
431 // Iterator exhausted; stop early
432 return len;
433 }
434 }
435
436 // Continue, since the iterator was not exhausted
437 for (i, elem) in self.enumerate() {
438 let k = gen_index(rng, i + 1 + amount);
439 if let Some(slot) = buf.get_mut(k) {
440 *slot = elem;
441 }
442 }
443 len
444 }
445
446 /// Collects `amount` values at random from the iterator into a vector.
447 ///
448 /// This is equivalent to `choose_multiple_fill` except for the result type.
449 ///
450 /// Although the elements are selected randomly, the order of elements in
451 /// the buffer is neither stable nor fully random. If random ordering is
452 /// desired, shuffle the result.
453 ///
454 /// The length of the returned vector equals `amount` unless the iterator
455 /// contains insufficient elements, in which case it equals the number of
456 /// elements available.
457 ///
458 /// Complexity is `O(n)` where `n` is the length of the iterator.
459 /// For slices, prefer [`SliceRandom::choose_multiple`].
460 #[cfg(feature = "alloc")]
461 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item> where R: Rng + ?Sized462 fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
463 where R: Rng + ?Sized {
464 let mut reservoir = Vec::with_capacity(amount);
465 reservoir.extend(self.by_ref().take(amount));
466
467 // Continue unless the iterator was exhausted
468 //
469 // note: this prevents iterators that "restart" from causing problems.
470 // If the iterator stops once, then so do we.
471 if reservoir.len() == amount {
472 for (i, elem) in self.enumerate() {
473 let k = gen_index(rng, i + 1 + amount);
474 if let Some(slot) = reservoir.get_mut(k) {
475 *slot = elem;
476 }
477 }
478 } else {
479 // Don't hang onto extra memory. There is a corner case where
480 // `amount` was much less than `self.len()`.
481 reservoir.shrink_to_fit();
482 }
483 reservoir
484 }
485 }
486
487
488 impl<T> SliceRandom for [T] {
489 type Item = T;
490
choose<R>(&self, rng: &mut R) -> Option<&Self::Item> where R: Rng + ?Sized491 fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item>
492 where R: Rng + ?Sized {
493 if self.is_empty() {
494 None
495 } else {
496 Some(&self[gen_index(rng, self.len())])
497 }
498 }
499
choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> where R: Rng + ?Sized500 fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item>
501 where R: Rng + ?Sized {
502 if self.is_empty() {
503 None
504 } else {
505 let len = self.len();
506 Some(&mut self[gen_index(rng, len)])
507 }
508 }
509
510 #[cfg(feature = "alloc")]
choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> where R: Rng + ?Sized511 fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
512 where R: Rng + ?Sized {
513 let amount = ::core::cmp::min(amount, self.len());
514 SliceChooseIter {
515 slice: self,
516 _phantom: Default::default(),
517 indices: index::sample(rng, self.len(), amount).into_iter(),
518 }
519 }
520
521 #[cfg(feature = "alloc")]
choose_weighted<R, F, B, X>( &self, rng: &mut R, weight: F, ) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default,522 fn choose_weighted<R, F, B, X>(
523 &self, rng: &mut R, weight: F,
524 ) -> Result<&Self::Item, WeightedError>
525 where
526 R: Rng + ?Sized,
527 F: Fn(&Self::Item) -> B,
528 B: SampleBorrow<X>,
529 X: SampleUniform
530 + for<'a> ::core::ops::AddAssign<&'a X>
531 + ::core::cmp::PartialOrd<X>
532 + Clone
533 + Default,
534 {
535 use crate::distributions::{Distribution, WeightedIndex};
536 let distr = WeightedIndex::new(self.iter().map(weight))?;
537 Ok(&self[distr.sample(rng)])
538 }
539
540 #[cfg(feature = "alloc")]
choose_weighted_mut<R, F, B, X>( &mut self, rng: &mut R, weight: F, ) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow<X>, X: SampleUniform + for<'a> ::core::ops::AddAssign<&'a X> + ::core::cmp::PartialOrd<X> + Clone + Default,541 fn choose_weighted_mut<R, F, B, X>(
542 &mut self, rng: &mut R, weight: F,
543 ) -> Result<&mut Self::Item, WeightedError>
544 where
545 R: Rng + ?Sized,
546 F: Fn(&Self::Item) -> B,
547 B: SampleBorrow<X>,
548 X: SampleUniform
549 + for<'a> ::core::ops::AddAssign<&'a X>
550 + ::core::cmp::PartialOrd<X>
551 + Clone
552 + Default,
553 {
554 use crate::distributions::{Distribution, WeightedIndex};
555 let distr = WeightedIndex::new(self.iter().map(weight))?;
556 Ok(&mut self[distr.sample(rng)])
557 }
558
559 #[cfg(feature = "alloc")]
choose_multiple_weighted<R, F, X>( &self, rng: &mut R, amount: usize, weight: F, ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> X, X: Into<f64>,560 fn choose_multiple_weighted<R, F, X>(
561 &self, rng: &mut R, amount: usize, weight: F,
562 ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
563 where
564 R: Rng + ?Sized,
565 F: Fn(&Self::Item) -> X,
566 X: Into<f64>,
567 {
568 let amount = ::core::cmp::min(amount, self.len());
569 Ok(SliceChooseIter {
570 slice: self,
571 _phantom: Default::default(),
572 indices: index::sample_weighted(
573 rng,
574 self.len(),
575 |idx| weight(&self[idx]).into(),
576 amount,
577 )?
578 .into_iter(),
579 })
580 }
581
shuffle<R>(&mut self, rng: &mut R) where R: Rng + ?Sized582 fn shuffle<R>(&mut self, rng: &mut R)
583 where R: Rng + ?Sized {
584 for i in (1..self.len()).rev() {
585 // invariant: elements with index > i have been locked in place.
586 self.swap(i, gen_index(rng, i + 1));
587 }
588 }
589
partial_shuffle<R>( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) where R: Rng + ?Sized590 fn partial_shuffle<R>(
591 &mut self, rng: &mut R, amount: usize,
592 ) -> (&mut [Self::Item], &mut [Self::Item])
593 where R: Rng + ?Sized {
594 // This applies Durstenfeld's algorithm for the
595 // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
596 // for an unbiased permutation, but exits early after choosing `amount`
597 // elements.
598
599 let len = self.len();
600 let end = if amount >= len { 0 } else { len - amount };
601
602 for i in (end..len).rev() {
603 // invariant: elements with index > i have been locked in place.
604 self.swap(i, gen_index(rng, i + 1));
605 }
606 let r = self.split_at_mut(end);
607 (r.1, r.0)
608 }
609 }
610
611 impl<I> IteratorRandom for I where I: Iterator + Sized {}
612
613
/// An iterator over multiple slice elements.
///
/// This struct is created by [`SliceRandom::choose_multiple`] and by
/// [`SliceRandom::choose_multiple_weighted`].
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> {
    /// The indexable container the yielded elements are borrowed from.
    slice: &'a S,
    /// Ties the element type `T` to the iterator without storing a `T`.
    _phantom: ::core::marker::PhantomData<T>,
    /// The remaining indices into `slice`, in the order they will be yielded.
    indices: index::IndexVecIntoIter,
}
626
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    /// Yields the element at the next pre-sampled index.
    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        let slice = self.slice;
        let idx = self.indices.next()?;
        Some(&slice[idx])
    }

    /// Exact: the number of remaining pre-sampled indices.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
640
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    /// Number of elements still to be yielded, taken from the index iterator.
    fn len(&self) -> usize {
        let SliceChooseIter { indices, .. } = self;
        indices.len()
    }
}
649
650
651 // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where
652 // possible, primarily in order to produce the same output on 32-bit and 64-bit
653 // platforms.
654 #[inline]
gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize655 fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize {
656 if ubound <= (core::u32::MAX as usize) {
657 rng.gen_range(0..ubound as u32) as usize
658 } else {
659 rng.gen_range(0..ubound)
660 }
661 }
662
663
664 #[cfg(test)]
665 mod test {
666 use super::*;
667 #[cfg(feature = "alloc")] use crate::Rng;
668 #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec;
669
670 #[test]
test_slice_choose()671 fn test_slice_choose() {
672 let mut r = crate::test::rng(107);
673 let chars = [
674 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
675 ];
676 let mut chosen = [0i32; 14];
677 // The below all use a binomial distribution with n=1000, p=1/14.
678 // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
679 for _ in 0..1000 {
680 let picked = *chars.choose(&mut r).unwrap();
681 chosen[(picked as usize) - ('a' as usize)] += 1;
682 }
683 for count in chosen.iter() {
684 assert!(40 < *count && *count < 106);
685 }
686
687 chosen.iter_mut().for_each(|x| *x = 0);
688 for _ in 0..1000 {
689 *chosen.choose_mut(&mut r).unwrap() += 1;
690 }
691 for count in chosen.iter() {
692 assert!(40 < *count && *count < 106);
693 }
694
695 let mut v: [isize; 0] = [];
696 assert_eq!(v.choose(&mut r), None);
697 assert_eq!(v.choose_mut(&mut r), None);
698 }
699
700 #[test]
value_stability_slice()701 fn value_stability_slice() {
702 let mut r = crate::test::rng(413);
703 let chars = [
704 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
705 ];
706 let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
707
708 assert_eq!(chars.choose(&mut r), Some(&'l'));
709 assert_eq!(nums.choose_mut(&mut r), Some(&mut 10));
710
711 #[cfg(feature = "alloc")]
712 assert_eq!(
713 &chars
714 .choose_multiple(&mut r, 8)
715 .cloned()
716 .collect::<Vec<char>>(),
717 &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e']
718 );
719
720 #[cfg(feature = "alloc")]
721 assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f'));
722 #[cfg(feature = "alloc")]
723 assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5));
724
725 let mut r = crate::test::rng(414);
726 nums.shuffle(&mut r);
727 assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);
728 nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
729 let res = nums.partial_shuffle(&mut r, 6);
730 assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]);
731 assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]);
732 }
733
734 #[derive(Clone)]
735 struct UnhintedIterator<I: Iterator + Clone> {
736 iter: I,
737 }
738 impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
739 type Item = I::Item;
740
next(&mut self) -> Option<Self::Item>741 fn next(&mut self) -> Option<Self::Item> {
742 self.iter.next()
743 }
744 }
745
746 #[derive(Clone)]
747 struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
748 iter: I,
749 chunk_remaining: usize,
750 chunk_size: usize,
751 hint_total_size: bool,
752 }
753 impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
754 type Item = I::Item;
755
next(&mut self) -> Option<Self::Item>756 fn next(&mut self) -> Option<Self::Item> {
757 if self.chunk_remaining == 0 {
758 self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len());
759 }
760 self.chunk_remaining = self.chunk_remaining.saturating_sub(1);
761
762 self.iter.next()
763 }
764
size_hint(&self) -> (usize, Option<usize>)765 fn size_hint(&self) -> (usize, Option<usize>) {
766 (
767 self.chunk_remaining,
768 if self.hint_total_size {
769 Some(self.iter.len())
770 } else {
771 None
772 },
773 )
774 }
775 }
776
777 #[derive(Clone)]
778 struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
779 iter: I,
780 window_size: usize,
781 hint_total_size: bool,
782 }
783 impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
784 type Item = I::Item;
785
next(&mut self) -> Option<Self::Item>786 fn next(&mut self) -> Option<Self::Item> {
787 self.iter.next()
788 }
789
size_hint(&self) -> (usize, Option<usize>)790 fn size_hint(&self) -> (usize, Option<usize>) {
791 (
792 ::core::cmp::min(self.iter.len(), self.window_size),
793 if self.hint_total_size {
794 Some(self.iter.len())
795 } else {
796 None
797 },
798 )
799 }
800 }
801
802 #[test]
803 #[cfg_attr(miri, ignore)] // Miri is too slow
test_iterator_choose()804 fn test_iterator_choose() {
805 let r = &mut crate::test::rng(109);
806 fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
807 let mut chosen = [0i32; 9];
808 for _ in 0..1000 {
809 let picked = iter.clone().choose(r).unwrap();
810 chosen[picked] += 1;
811 }
812 for count in chosen.iter() {
813 // Samples should follow Binomial(1000, 1/9)
814 // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
815 // Note: have seen 153, which is unlikely but not impossible.
816 assert!(
817 72 < *count && *count < 154,
818 "count not close to 1000/9: {}",
819 count
820 );
821 }
822 }
823
824 test_iter(r, 0..9);
825 test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
826 #[cfg(feature = "alloc")]
827 test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
828 test_iter(r, UnhintedIterator { iter: 0..9 });
829 test_iter(r, ChunkHintedIterator {
830 iter: 0..9,
831 chunk_size: 4,
832 chunk_remaining: 4,
833 hint_total_size: false,
834 });
835 test_iter(r, ChunkHintedIterator {
836 iter: 0..9,
837 chunk_size: 4,
838 chunk_remaining: 4,
839 hint_total_size: true,
840 });
841 test_iter(r, WindowHintedIterator {
842 iter: 0..9,
843 window_size: 2,
844 hint_total_size: false,
845 });
846 test_iter(r, WindowHintedIterator {
847 iter: 0..9,
848 window_size: 2,
849 hint_total_size: true,
850 });
851
852 assert_eq!((0..0).choose(r), None);
853 assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
854 }
855
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable() {
    let r = &mut crate::test::rng(109);
    // Draws 1000 samples with `choose_stable` from an iterator yielding 0..9 and
    // checks that every value is picked a plausible number of times.
    fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose_stable(r).unwrap();
            chosen[picked] += 1;
        }
        for count in chosen.iter() {
            // Samples should follow Binomial(1000, 1/9)
            // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
            // Note: have seen 153, which is unlikely but not impossible.
            assert!(
                72 < *count && *count < 154,
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    // Plain, unhinted, and partially size-hinted iterators over the same values.
    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
    test_iter(r, UnhintedIterator { iter: 0..9 });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: false,
    });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: true,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: false,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: true,
    });

    // An empty iterator has nothing to pick. Check the method under test,
    // `choose_stable`, rather than `choose`.
    assert_eq!((0..0).choose_stable(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
}
909
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable_stability() {
    // Histogram of 1000 `choose_stable` picks over an iterator yielding 0..9,
    // always starting from the same freshly seeded RNG.
    fn histogram(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
        let rng = &mut crate::test::rng(109);
        let mut tally = [0i32; 9];
        for _ in 0..1000 {
            tally[iter.clone().choose_stable(rng).unwrap()] += 1;
        }
        tally
    }

    // Every histogram must equal the one from the plain range, regardless of
    // how the iterator reports its size.
    let reference = histogram(0..9);
    assert_eq!(histogram([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference);

    #[cfg(feature = "alloc")]
    assert_eq!(histogram((0..9).collect::<Vec<_>>().into_iter()), reference);
    assert_eq!(histogram(UnhintedIterator { iter: 0..9 }), reference);

    for &hint_total_size in &[false, true] {
        assert_eq!(
            histogram(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size,
            }),
            reference
        );
        assert_eq!(
            histogram(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size,
            }),
            reference
        );
    }
}
952
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_shuffle() {
    let mut r = crate::test::rng(108);
    let empty: &mut [isize] = &mut [];
    empty.shuffle(&mut r);
    let mut one = [1];
    one.shuffle(&mut r);
    let b: &[_] = &[1];
    assert_eq!(one, b);

    let mut two = [1, 2];
    two.shuffle(&mut r);
    assert!(two == [1, 2] || two == [2, 1]);

    // Tally how often each of the 4! = 24 permutations of [0, 1, 2, 3] occurs.
    // Each shuffled array is mapped to a unique index in 0..24 (its Lehmer code):
    // for i = 0..4, the position of `i` within the first `4 - i` slots is a
    // digit in base `4 - i`. `i` is then rotated to the end of the array, so the
    // not-yet-consumed values always occupy the leading slots.
    let mut counts = [0i32; 24];
    for _ in 0..10000 {
        let mut arr: [usize; 4] = [0, 1, 2, 3];
        arr.shuffle(&mut r);
        let mut permutation = 0usize;
        let mut pos_value = counts.len();
        for i in 0..4 {
            pos_value /= 4 - i;
            let pos = arr.iter().position(|&x| x == i).unwrap();
            assert!(pos < (4 - i));
            permutation += pos * pos_value;
            // Move `arr[pos]` to the end, shifting the elements after it left by one.
            arr[pos..].rotate_left(1);
            assert_eq!(arr[3], i);
        }
        // After all rotations the values are back in ascending order.
        assert_eq!(arr, [0, 1, 2, 3]);
        counts[permutation] += 1;
    }
    for count in counts.iter() {
        // Binomial(10000, 1/24) with average 416.667
        // Octave: binocdf(n, 10000, 1/24)
        // 99.9% chance samples lie within this range:
        assert!(352 <= *count && *count <= 483, "count: {}", count);
    }
}
1002
#[test]
fn test_partial_shuffle() {
    let mut r = crate::test::rng(118);

    // Requesting more elements than an empty slice holds yields two empty halves.
    let mut empty: [u32; 0] = [];
    let (chosen, rest) = empty.partial_shuffle(&mut r, 10);
    assert_eq!((chosen.len(), rest.len()), (0, 0));

    // Select 2 of 5: the chosen part holds 2 distinct values, the rest holds 3.
    let mut v = [1, 2, 3, 4, 5];
    let (chosen, rest) = v.partial_shuffle(&mut r, 2);
    assert_eq!((chosen.len(), rest.len()), (2, 3));
    assert!(chosen[0] != chosen[1]);
    // Leading elements are only modified when selected, and at most two were
    // selected, so at least one of the three remaining slots is untouched.
    assert!(rest[0] == 1 || rest[1] == 2 || rest[2] == 3);
}
1018
#[test]
#[cfg(feature = "alloc")]
fn test_sample_iter() {
    let min_val = 1;
    let max_val = 100;

    let mut r = crate::test::rng(401);
    // `min_val..max_val` is half-open, so the population is 1..=99.
    let vals = (min_val..max_val).collect::<Vec<i32>>();
    let small_sample = vals.iter().choose_multiple(&mut r, 5);
    let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

    assert_eq!(small_sample.len(), 5);
    assert_eq!(large_sample.len(), vals.len());
    // no randomization happens when amount >= len
    assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

    // Every sampled value must lie in the half-open population range; the
    // upper bound is exclusive, matching how `vals` was built.
    assert!(small_sample
        .iter()
        .all(|e| { **e >= min_val && **e < max_val }));
    // Each input element is taken at most once, and `vals` holds distinct
    // values, so no value may repeat in the sample.
    for (i, a) in small_sample.iter().enumerate() {
        assert!(small_sample[i + 1..].iter().all(|b| a != b));
    }
}
1039
#[test]
#[cfg(feature = "alloc")]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weighted() {
    let mut r = crate::test::rng(406);
    const N_REPS: u32 = 3000;
    let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
    let total_weight = weights.iter().sum::<u32>() as f32;

    // Each observed count must be within 25% (relative) of the count implied by
    // its weight. For a zero weight the relative error is infinite unless the
    // count is also zero, so such an index must never be picked.
    let verify = |result: [i32; 14]| {
        for (&count, &w) in result.iter().zip(weights.iter()) {
            let expected = (w * N_REPS) as f32 / total_weight;
            let abs_err = (count as f32 - expected).abs();
            let rel_err = if abs_err == 0.0 { abs_err } else { abs_err / expected };
            assert!(rel_err <= 0.25);
        }
    };

    // choose_weighted: each item carries its own index.
    fn get_weight<T>(item: &(u32, T)) -> u32 {
        item.0
    }
    let mut items = [(0u32, 0usize); 14]; // (weight, index)
    for (i, (slot, &w)) in items.iter_mut().zip(weights.iter()).enumerate() {
        *slot = (w, i);
    }
    let mut chosen = [0i32; 14];
    for _ in 0..N_REPS {
        let (_, idx) = *items.choose_weighted(&mut r, get_weight).unwrap();
        chosen[idx] += 1;
    }
    verify(chosen);

    // choose_weighted_mut: each item accumulates its own hit count.
    let mut items = [(0u32, 0i32); 14]; // (weight, count)
    for (slot, &w) in items.iter_mut().zip(weights.iter()) {
        *slot = (w, 0);
    }
    for _ in 0..N_REPS {
        items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1;
    }
    let mut tallies = [0i32; 14];
    for (tally, item) in tallies.iter_mut().zip(items.iter()) {
        *tally = item.1;
    }
    verify(tallies);

    // Error cases: no items, all-zero weights, and negative weights.
    let empty_slice = &mut [10][0..0];
    assert_eq!(
        empty_slice.choose_weighted(&mut r, |_| 1),
        Err(WeightedError::NoItem)
    );
    assert_eq!(
        empty_slice.choose_weighted_mut(&mut r, |_| 1),
        Err(WeightedError::NoItem)
    );
    assert_eq!(
        ['x'].choose_weighted_mut(&mut r, |_| 0),
        Err(WeightedError::AllWeightsZero)
    );
    assert_eq!(
        [0, -1].choose_weighted_mut(&mut r, |x| *x),
        Err(WeightedError::InvalidWeight)
    );
    assert_eq!(
        [-1, 0].choose_weighted_mut(&mut r, |x| *x),
        Err(WeightedError::InvalidWeight)
    );
}
1111
#[test]
fn value_stability_choose() {
    // Every case uses a freshly seeded RNG, so the cases are independent.
    fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
        iter.choose(&mut crate::test::rng(411))
    }

    assert_eq!(choose([].iter().cloned()), None);
    assert_eq!(choose(0..100), Some(33));
    assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40));

    // The expected pick depends on the wrapper type, but not on whether the
    // wrapper hints the total size.
    for &hint_total_size in &[false, true] {
        let chunked = ChunkHintedIterator {
            iter: 0..100,
            chunk_size: 32,
            chunk_remaining: 32,
            hint_total_size,
        };
        assert_eq!(choose(chunked), Some(39));

        let windowed = WindowHintedIterator {
            iter: 0..100,
            window_size: 32,
            hint_total_size,
        };
        assert_eq!(choose(windowed), Some(90));
    }
}
1157
#[test]
fn value_stability_choose_stable() {
    // Every case uses a freshly seeded RNG, so the cases are independent.
    fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
        let mut seeded = crate::test::rng(411);
        iter.choose_stable(&mut seeded)
    }

    assert_eq!(choose([].iter().cloned()), None);

    // Every wrapper around 0..100 is expected to yield the same element,
    // whatever size hints it reports.
    let expected = Some(40);
    assert_eq!(choose(0..100), expected);
    assert_eq!(choose(UnhintedIterator { iter: 0..100 }), expected);
    for &hint_total_size in &[false, true] {
        let chunked = ChunkHintedIterator {
            iter: 0..100,
            chunk_size: 32,
            chunk_remaining: 32,
            hint_total_size,
        };
        assert_eq!(choose(chunked), expected);

        let windowed = WindowHintedIterator {
            iter: 0..100,
            window_size: 32,
            hint_total_size,
        };
        assert_eq!(choose(windowed), expected);
    }
}
1203
#[test]
fn value_stability_choose_multiple() {
    // (input range, expected picks for seed 412). Ranges no longer than the
    // 8-slot buffer are returned unchanged.
    let cases = [
        (0u32..4, &[0u32, 1, 2, 3][..]),
        (0..8, &[0, 1, 2, 3, 4, 5, 6, 7][..]),
        (0..100, &[58, 78, 80, 92, 43, 8, 96, 7][..]),
    ];

    // `choose_multiple_fill` into a fixed 8-element buffer.
    fn check_fill<I: Iterator<Item = u32>>(iter: I, expected: &[u32]) {
        let mut rng = crate::test::rng(412);
        let mut buf = [0u32; 8];
        let filled = iter.choose_multiple_fill(&mut rng, &mut buf);
        assert_eq!(filled, expected.len());
        assert_eq!(&buf[..filled], expected);
    }
    for (range, expected) in cases.iter() {
        check_fill(range.clone(), expected);
    }

    #[cfg(feature = "alloc")]
    {
        // `choose_multiple` into a `Vec` must produce the same picks.
        fn check_vec<I: Iterator<Item = u32>>(iter: I, expected: &[u32]) {
            let mut rng = crate::test::rng(412);
            assert_eq!(iter.choose_multiple(&mut rng, expected.len()), expected);
        }
        for (range, expected) in cases.iter() {
            check_vec(range.clone(), expected);
        }
    }
}
1229
#[test]
#[cfg(feature = "alloc")]
fn test_multiple_weighted_edge_cases() {
    use super::*;

    let mut rng = crate::test::rng(413);

    // Case 1: the zero-weight item 'c' is never among the two picks.
    let choices = [('a', 2), ('b', 1), ('c', 0)];
    for _ in 0..100 {
        let picked: Vec<_> = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect();

        assert_eq!(picked.len(), 2);
        assert!(picked.iter().all(|val| val.0 != 'c'));
    }

    // Case 2: all-zero weights still yield the requested number of items.
    let choices = [('a', 0), ('b', 0), ('c', 0)];
    let picked: Vec<_> = choices
        .choose_multiple_weighted(&mut rng, 2, |item| item.1)
        .unwrap()
        .collect();
    assert_eq!(picked.len(), 2);

    // Case 3: a negative weight is rejected.
    let choices = [('a', -1), ('b', 1), ('c', 1)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(outcome.unwrap_err(), WeightedError::InvalidWeight);

    // Case 4: an empty slice with amount 0 yields nothing.
    let choices: [(); 0] = [];
    let picked: Vec<_> = choices
        .choose_multiple_weighted(&mut rng, 0, |_: &()| 0)
        .unwrap()
        .collect();
    assert!(picked.is_empty());

    // Case 5: a NaN weight is rejected.
    let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(outcome.unwrap_err(), WeightedError::InvalidWeight);

    // Case 6: the +infinity-weight item 'a' is always among the picks.
    let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)];
    for _ in 0..100 {
        let picked: Vec<_> = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect();
        assert_eq!(picked.len(), 2);
        assert!(picked.iter().any(|val| val.0 == 'a'));
    }

    // Case 7: a -infinity weight is rejected.
    let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(outcome.unwrap_err(), WeightedError::InvalidWeight);

    // Case 8: a -0.0 weight is accepted.
    let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
    let outcome = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert!(outcome.is_ok());
}
1309
#[test]
#[cfg(feature = "alloc")]
fn test_multiple_weighted_distributions() {
    use super::*;

    // Drawing 2 of 3 items with weights a=2, b=1, c=1 (total 4), each ordered
    // outcome has probability P(first) * P(second | first):
    //   AB: 2/4 * 1/2 = 1/4  (~0.250)
    //   AC: 2/4 * 1/2 = 1/4  (~0.250)
    //   BA: 1/4 * 2/3 = 1/6  (~0.167)
    //   BC: 1/4 * 1/3 = 1/12 (~0.083)
    //   CA: 1/4 * 2/3 = 1/6  (~0.167)
    //   CB: 1/4 * 1/3 = 1/12 (~0.083)
    // Ignoring order: {a,b} = {a,c} = 5/12 (~0.4167) and {b,c} = 1/6 (~0.1667),
    // which gives the expected counts below for 10000 draws.
    let choices = [('a', 2), ('b', 1), ('c', 1)];
    let mut rng = crate::test::rng(414);

    // Tallies of unordered pairs: [{a,b}, {a,c}, {b,c}].
    let mut results = [0i32; 3];
    let expected_results = [4167, 4167, 1666];
    for _ in 0..10000 {
        let result = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);

        match (result[0].0, result[1].0) {
            ('a', 'b') | ('b', 'a') => {
                results[0] += 1;
            }
            ('a', 'c') | ('c', 'a') => {
                results[1] += 1;
            }
            ('b', 'c') | ('c', 'b') => {
                results[2] += 1;
            }
            (first, second) => panic!("unexpected result: ({}, {})", first, second),
        }
    }

    // Every tally must be within 100 of its expectation. On failure, report
    // both arrays so the offending bucket is visible.
    for (observed, expected) in results.iter().zip(expected_results.iter()) {
        let deviation = (observed - expected).abs();
        assert!(
            deviation <= 100,
            "pair counts {:?} deviate from expected {:?} by more than 100",
            results,
            expected_results
        );
    }
}
1355 }
1356