1 //! Rayon extensions for `HashSet`.
2 
3 use super::map;
4 use crate::hash_set::HashSet;
5 use crate::raw::{Allocator, Global};
6 use core::hash::{BuildHasher, Hash};
7 use rayon::iter::plumbing::UnindexedConsumer;
8 use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
9 
/// Parallel iterator over elements of a consumed set.
///
/// This iterator is created by the [`into_par_iter`] method on [`HashSet`]
/// (provided by the [`IntoParallelIterator`] trait).
/// See its documentation for more.
///
/// Elements are yielded by value, as `T`.
///
/// [`into_par_iter`]: /hashbrown/struct.HashSet.html#method.into_par_iter
/// [`HashSet`]: /hashbrown/struct.HashSet.html
/// [`IntoParallelIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelIterator.html
pub struct IntoParIter<T, A: Allocator + Clone = Global> {
    // Consuming parallel iterator of the backing map, whose entries are
    // `(T, ())` pairs: the set's elements are the keys and every value is `()`.
    inner: map::IntoParIter<T, (), A>,
}
22 
23 impl<T: Send, A: Allocator + Clone + Send> ParallelIterator for IntoParIter<T, A> {
24     type Item = T;
25 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,26     fn drive_unindexed<C>(self, consumer: C) -> C::Result
27     where
28         C: UnindexedConsumer<Self::Item>,
29     {
30         self.inner.map(|(k, _)| k).drive_unindexed(consumer)
31     }
32 }
33 
/// Parallel draining iterator over entries of a set.
///
/// This iterator is created by the [`par_drain`] method on [`HashSet`].
/// See its documentation for more.
///
/// Elements are yielded by value, as `T`. The lifetime `'a` is that of the
/// mutable borrow of the set taken by `par_drain`.
///
/// [`par_drain`]: /hashbrown/struct.HashSet.html#method.par_drain
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParDrain<'a, T, A: Allocator + Clone = Global> {
    // Draining parallel iterator of the backing map, whose entries are
    // `(T, ())` pairs; only the key of each entry is handed out.
    inner: map::ParDrain<'a, T, (), A>,
}
44 
45 impl<T: Send, A: Allocator + Clone + Send + Sync> ParallelIterator for ParDrain<'_, T, A> {
46     type Item = T;
47 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,48     fn drive_unindexed<C>(self, consumer: C) -> C::Result
49     where
50         C: UnindexedConsumer<Self::Item>,
51     {
52         self.inner.map(|(k, _)| k).drive_unindexed(consumer)
53     }
54 }
55 
/// Parallel iterator over shared references to elements in a set.
///
/// This iterator is created by the [`par_iter`] method on [`HashSet`]
/// (provided by the [`IntoParallelRefIterator`] trait).
/// See its documentation for more.
///
/// Elements are yielded as `&'a T`.
///
/// [`par_iter`]: /hashbrown/struct.HashSet.html#method.par_iter
/// [`HashSet`]: /hashbrown/struct.HashSet.html
/// [`IntoParallelRefIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelRefIterator.html
pub struct ParIter<'a, T> {
    // Parallel iterator over the keys of the backing map; those keys are the
    // set's elements.
    inner: map::ParKeys<'a, T, ()>,
}
68 
69 impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
70     type Item = &'a T;
71 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,72     fn drive_unindexed<C>(self, consumer: C) -> C::Result
73     where
74         C: UnindexedConsumer<Self::Item>,
75     {
76         self.inner.drive_unindexed(consumer)
77     }
78 }
79 
/// Parallel iterator over shared references to elements in the difference of
/// sets.
///
/// This iterator is created by the [`par_difference`] method on [`HashSet`].
/// See its documentation for more.
///
/// It yields the elements of the first set that the second set does not
/// contain.
///
/// [`par_difference`]: /hashbrown/struct.HashSet.html#method.par_difference
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParDifference<'a, T, S, A: Allocator + Clone = Global> {
    // Set whose elements are iterated.
    a: &'a HashSet<T, S, A>,
    // Set whose elements are excluded from the output.
    b: &'a HashSet<T, S, A>,
}
92 
93 impl<'a, T, S, A> ParallelIterator for ParDifference<'a, T, S, A>
94 where
95     T: Eq + Hash + Sync,
96     S: BuildHasher + Sync,
97     A: Allocator + Clone + Sync,
98 {
99     type Item = &'a T;
100 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,101     fn drive_unindexed<C>(self, consumer: C) -> C::Result
102     where
103         C: UnindexedConsumer<Self::Item>,
104     {
105         self.a
106             .into_par_iter()
107             .filter(|&x| !self.b.contains(x))
108             .drive_unindexed(consumer)
109     }
110 }
111 
/// Parallel iterator over shared references to elements in the symmetric
/// difference of sets.
///
/// This iterator is created by the [`par_symmetric_difference`] method on
/// [`HashSet`].
/// See its documentation for more.
///
/// It yields the elements of the first set missing from the second, chained
/// with the elements of the second set missing from the first.
///
/// [`par_symmetric_difference`]: /hashbrown/struct.HashSet.html#method.par_symmetric_difference
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParSymmetricDifference<'a, T, S, A: Allocator + Clone = Global> {
    // First operand of the symmetric difference.
    a: &'a HashSet<T, S, A>,
    // Second operand of the symmetric difference.
    b: &'a HashSet<T, S, A>,
}
125 
126 impl<'a, T, S, A> ParallelIterator for ParSymmetricDifference<'a, T, S, A>
127 where
128     T: Eq + Hash + Sync,
129     S: BuildHasher + Sync,
130     A: Allocator + Clone + Sync,
131 {
132     type Item = &'a T;
133 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,134     fn drive_unindexed<C>(self, consumer: C) -> C::Result
135     where
136         C: UnindexedConsumer<Self::Item>,
137     {
138         self.a
139             .par_difference(self.b)
140             .chain(self.b.par_difference(self.a))
141             .drive_unindexed(consumer)
142     }
143 }
144 
/// Parallel iterator over shared references to elements in the intersection of
/// sets.
///
/// This iterator is created by the [`par_intersection`] method on [`HashSet`].
/// See its documentation for more.
///
/// It yields the elements of the first set that the second set also
/// contains, so every yielded reference points into the first set.
///
/// [`par_intersection`]: /hashbrown/struct.HashSet.html#method.par_intersection
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParIntersection<'a, T, S, A: Allocator + Clone = Global> {
    // Set whose elements are iterated.
    a: &'a HashSet<T, S, A>,
    // Set used only for membership lookups.
    b: &'a HashSet<T, S, A>,
}
157 
158 impl<'a, T, S, A> ParallelIterator for ParIntersection<'a, T, S, A>
159 where
160     T: Eq + Hash + Sync,
161     S: BuildHasher + Sync,
162     A: Allocator + Clone + Sync,
163 {
164     type Item = &'a T;
165 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,166     fn drive_unindexed<C>(self, consumer: C) -> C::Result
167     where
168         C: UnindexedConsumer<Self::Item>,
169     {
170         self.a
171             .into_par_iter()
172             .filter(|&x| self.b.contains(x))
173             .drive_unindexed(consumer)
174     }
175 }
176 
/// Parallel iterator over shared references to elements in the union of sets.
///
/// This iterator is created by the [`par_union`] method on [`HashSet`].
/// See its documentation for more.
///
/// The larger of the two sets is iterated in full, followed by the elements
/// of the smaller set that the larger one does not contain.
///
/// [`par_union`]: /hashbrown/struct.HashSet.html#method.par_union
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParUnion<'a, T, S, A: Allocator + Clone = Global> {
    // First operand of the union.
    a: &'a HashSet<T, S, A>,
    // Second operand of the union.
    b: &'a HashSet<T, S, A>,
}
188 
189 impl<'a, T, S, A> ParallelIterator for ParUnion<'a, T, S, A>
190 where
191     T: Eq + Hash + Sync,
192     S: BuildHasher + Sync,
193     A: Allocator + Clone + Sync,
194 {
195     type Item = &'a T;
196 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,197     fn drive_unindexed<C>(self, consumer: C) -> C::Result
198     where
199         C: UnindexedConsumer<Self::Item>,
200     {
201         // We'll iterate one set in full, and only the remaining difference from the other.
202         // Use the smaller set for the difference in order to reduce hash lookups.
203         let (smaller, larger) = if self.a.len() <= self.b.len() {
204             (self.a, self.b)
205         } else {
206             (self.b, self.a)
207         };
208         larger
209             .into_par_iter()
210             .chain(smaller.par_difference(larger))
211             .drive_unindexed(consumer)
212     }
213 }
214 
215 impl<T, S, A> HashSet<T, S, A>
216 where
217     T: Eq + Hash + Sync,
218     S: BuildHasher + Sync,
219     A: Allocator + Clone + Sync,
220 {
221     /// Visits (potentially in parallel) the values representing the union,
222     /// i.e. all the values in `self` or `other`, without duplicates.
223     #[cfg_attr(feature = "inline-more", inline)]
par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S, A>224     pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S, A> {
225         ParUnion { a: self, b: other }
226     }
227 
228     /// Visits (potentially in parallel) the values representing the difference,
229     /// i.e. the values that are in `self` but not in `other`.
230     #[cfg_attr(feature = "inline-more", inline)]
par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S, A>231     pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S, A> {
232         ParDifference { a: self, b: other }
233     }
234 
235     /// Visits (potentially in parallel) the values representing the symmetric
236     /// difference, i.e. the values that are in `self` or in `other` but not in both.
237     #[cfg_attr(feature = "inline-more", inline)]
par_symmetric_difference<'a>( &'a self, other: &'a Self, ) -> ParSymmetricDifference<'a, T, S, A>238     pub fn par_symmetric_difference<'a>(
239         &'a self,
240         other: &'a Self,
241     ) -> ParSymmetricDifference<'a, T, S, A> {
242         ParSymmetricDifference { a: self, b: other }
243     }
244 
245     /// Visits (potentially in parallel) the values representing the
246     /// intersection, i.e. the values that are both in `self` and `other`.
247     #[cfg_attr(feature = "inline-more", inline)]
par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S, A>248     pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S, A> {
249         ParIntersection { a: self, b: other }
250     }
251 
252     /// Returns `true` if `self` has no elements in common with `other`.
253     /// This is equivalent to checking for an empty intersection.
254     ///
255     /// This method runs in a potentially parallel fashion.
par_is_disjoint(&self, other: &Self) -> bool256     pub fn par_is_disjoint(&self, other: &Self) -> bool {
257         self.into_par_iter().all(|x| !other.contains(x))
258     }
259 
260     /// Returns `true` if the set is a subset of another,
261     /// i.e. `other` contains at least all the values in `self`.
262     ///
263     /// This method runs in a potentially parallel fashion.
par_is_subset(&self, other: &Self) -> bool264     pub fn par_is_subset(&self, other: &Self) -> bool {
265         if self.len() <= other.len() {
266             self.into_par_iter().all(|x| other.contains(x))
267         } else {
268             false
269         }
270     }
271 
272     /// Returns `true` if the set is a superset of another,
273     /// i.e. `self` contains at least all the values in `other`.
274     ///
275     /// This method runs in a potentially parallel fashion.
par_is_superset(&self, other: &Self) -> bool276     pub fn par_is_superset(&self, other: &Self) -> bool {
277         other.par_is_subset(self)
278     }
279 
280     /// Returns `true` if the set is equal to another,
281     /// i.e. both sets contain the same values.
282     ///
283     /// This method runs in a potentially parallel fashion.
par_eq(&self, other: &Self) -> bool284     pub fn par_eq(&self, other: &Self) -> bool {
285         self.len() == other.len() && self.par_is_subset(other)
286     }
287 }
288 
289 impl<T, S, A> HashSet<T, S, A>
290 where
291     T: Eq + Hash + Send,
292     A: Allocator + Clone + Send,
293 {
294     /// Consumes (potentially in parallel) all values in an arbitrary order,
295     /// while preserving the set's allocated memory for reuse.
296     #[cfg_attr(feature = "inline-more", inline)]
par_drain(&mut self) -> ParDrain<'_, T, A>297     pub fn par_drain(&mut self) -> ParDrain<'_, T, A> {
298         ParDrain {
299             inner: self.map.par_drain(),
300         }
301     }
302 }
303 
304 impl<T: Send, S, A: Allocator + Clone + Send> IntoParallelIterator for HashSet<T, S, A> {
305     type Item = T;
306     type Iter = IntoParIter<T, A>;
307 
308     #[cfg_attr(feature = "inline-more", inline)]
into_par_iter(self) -> Self::Iter309     fn into_par_iter(self) -> Self::Iter {
310         IntoParIter {
311             inner: self.map.into_par_iter(),
312         }
313     }
314 }
315 
316 impl<'a, T: Sync, S, A: Allocator + Clone> IntoParallelIterator for &'a HashSet<T, S, A> {
317     type Item = &'a T;
318     type Iter = ParIter<'a, T>;
319 
320     #[cfg_attr(feature = "inline-more", inline)]
into_par_iter(self) -> Self::Iter321     fn into_par_iter(self) -> Self::Iter {
322         ParIter {
323             inner: self.map.par_keys(),
324         }
325     }
326 }
327 
328 /// Collect values from a parallel iterator into a hashset.
329 impl<T, S> FromParallelIterator<T> for HashSet<T, S, Global>
330 where
331     T: Eq + Hash + Send,
332     S: BuildHasher + Default,
333 {
from_par_iter<P>(par_iter: P) -> Self where P: IntoParallelIterator<Item = T>,334     fn from_par_iter<P>(par_iter: P) -> Self
335     where
336         P: IntoParallelIterator<Item = T>,
337     {
338         let mut set = HashSet::default();
339         set.par_extend(par_iter);
340         set
341     }
342 }
343 
344 /// Extend a hash set with items from a parallel iterator.
345 impl<T, S> ParallelExtend<T> for HashSet<T, S, Global>
346 where
347     T: Eq + Hash + Send,
348     S: BuildHasher,
349 {
par_extend<I>(&mut self, par_iter: I) where I: IntoParallelIterator<Item = T>,350     fn par_extend<I>(&mut self, par_iter: I)
351     where
352         I: IntoParallelIterator<Item = T>,
353     {
354         extend(self, par_iter);
355     }
356 }
357 
358 /// Extend a hash set with copied items from a parallel iterator.
359 impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S, Global>
360 where
361     T: 'a + Copy + Eq + Hash + Sync,
362     S: BuildHasher,
363 {
par_extend<I>(&mut self, par_iter: I) where I: IntoParallelIterator<Item = &'a T>,364     fn par_extend<I>(&mut self, par_iter: I)
365     where
366         I: IntoParallelIterator<Item = &'a T>,
367     {
368         extend(self, par_iter);
369     }
370 }
371 
372 // This is equal to the normal `HashSet` -- no custom advantage.
extend<T, S, I, A>(set: &mut HashSet<T, S, A>, par_iter: I) where T: Eq + Hash, S: BuildHasher, A: Allocator + Clone, I: IntoParallelIterator, HashSet<T, S, A>: Extend<I::Item>,373 fn extend<T, S, I, A>(set: &mut HashSet<T, S, A>, par_iter: I)
374 where
375     T: Eq + Hash,
376     S: BuildHasher,
377     A: Allocator + Clone,
378     I: IntoParallelIterator,
379     HashSet<T, S, A>: Extend<I::Item>,
380 {
381     let (list, len) = super::helpers::collect(par_iter);
382 
383     // Values may be already present or show multiple times in the iterator.
384     // Reserve the entire length if the set is empty.
385     // Otherwise reserve half the length (rounded up), so the set
386     // will only resize twice in the worst case.
387     let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
388     set.reserve(reserve);
389     for vec in list {
390         set.extend(vec);
391     }
392 }
393 
// Unit tests for the parallel `HashSet` API defined in this file.
#[cfg(test)]
mod test_par_set {
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use rayon::prelude::*;

    use crate::hash_set::HashSet;

    // `par_is_disjoint` must hold in both directions for empty sets and for
    // sets without a shared value, and must fail in both directions as soon
    // as one value (7) is present in both.
    #[test]
    fn test_disjoint() {
        let mut xs = HashSet::new();
        let mut ys = HashSet::new();
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(xs.insert(7));
        assert!(xs.insert(19));
        assert!(xs.insert(4));
        assert!(ys.insert(2));
        assert!(ys.insert(-11));
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));
        assert!(ys.insert(7));
        assert!(!xs.par_is_disjoint(&ys));
        assert!(!ys.par_is_disjoint(&xs));
    }

    // `a` is missing from `b` only by the value 5; once 5 is inserted into
    // `b`, `a` becomes a subset of `b` and `b` a superset of `a`.
    #[test]
    fn test_subset_and_superset() {
        let mut a = HashSet::new();
        assert!(a.insert(0));
        assert!(a.insert(5));
        assert!(a.insert(11));
        assert!(a.insert(7));

        let mut b = HashSet::new();
        assert!(b.insert(0));
        assert!(b.insert(7));
        assert!(b.insert(19));
        assert!(b.insert(250));
        assert!(b.insert(11));
        assert!(b.insert(200));

        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        assert!(b.insert(5));

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    // Every element `k` of `0..32` sets bit `k` of a shared atomic; all 32
    // bits being set shows that each element was visited at least once.
    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        for i in 0..32 {
            assert!(a.insert(i));
        }
        let observed = AtomicUsize::new(0);
        a.par_iter().for_each(|k| {
            observed.fetch_or(1 << *k, Ordering::Relaxed);
        });
        assert_eq!(observed.into_inner(), 0xFFFF_FFFF);
    }

    // Each yielded value must be one of the expected ones, and the number of
    // yielded values must match the expected count.
    #[test]
    fn test_intersection() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(11));
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(77));
        assert!(a.insert(103));
        assert!(a.insert(5));
        assert!(a.insert(-5));

        assert!(b.insert(2));
        assert!(b.insert(11));
        assert!(b.insert(77));
        assert!(b.insert(-9));
        assert!(b.insert(-42));
        assert!(b.insert(5));
        assert!(b.insert(3));

        let expected = [3, 5, 11, 77];
        let i = a
            .par_intersection(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    // Same membership-plus-count check for `a` minus `b`.
    #[test]
    fn test_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));

        assert!(b.insert(3));
        assert!(b.insert(9));

        let expected = [1, 5, 11];
        let i = a
            .par_difference(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    // Same membership-plus-count check for the values found in exactly one
    // of the two sets.
    #[test]
    fn test_symmetric_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));

        assert!(b.insert(-2));
        assert!(b.insert(3));
        assert!(b.insert(9));
        assert!(b.insert(14));
        assert!(b.insert(22));

        let expected = [-2, 1, 5, 11, 14, 22];
        let i = a
            .par_symmetric_difference(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    // Same membership-plus-count check for the union; the count also shows
    // that values shared by both sets are not yielded twice.
    #[test]
    fn test_union() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(a.insert(16));
        assert!(a.insert(19));
        assert!(a.insert(24));

        assert!(b.insert(-2));
        assert!(b.insert(1));
        assert!(b.insert(5));
        assert!(b.insert(9));
        assert!(b.insert(13));
        assert!(b.insert(19));

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let i = a
            .par_union(&b)
            .map(|x| {
                assert!(expected.contains(x));
                1
            })
            .sum::<usize>();
        assert_eq!(i, expected.len());
    }

    // Collecting a parallel iterator into a set must keep every source value.
    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set: HashSet<_> = xs.par_iter().cloned().collect();

        for x in &xs {
            assert!(set.contains(x));
        }
    }

    // Consuming the set in parallel yields both elements; either order is
    // accepted.
    #[test]
    fn test_move_iter() {
        let hs = {
            let mut hs = HashSet::new();

            hs.insert('a');
            hs.insert('b');

            hs
        };

        let v = hs.into_par_iter().collect::<Vec<char>>();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    // `par_eq` is false while `s2` lacks the value 3 and true once it has
    // been inserted.
    #[test]
    fn test_eq() {
        // These constants once happened to expose a bug in insert().
        // I'm keeping them around to prevent a regression.
        let mut s1 = HashSet::new();

        s1.insert(1);
        s1.insert(2);
        s1.insert(3);

        let mut s2 = HashSet::new();

        s2.insert(1);
        s2.insert(2);

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    // Exercises `ParallelExtend<&T>` with two by-reference sources: a slice
    // and another set.
    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        assert!(a.contains(&5));
        assert!(a.contains(&6));
    }
}
660