1 //! Rayon extensions for `HashSet`.
2 
3 use crate::hash_set::HashSet;
4 use core::hash::{BuildHasher, Hash};
5 use rayon::iter::plumbing::UnindexedConsumer;
6 use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
7 
/// Parallel iterator over elements of a consumed set.
///
/// This iterator is created by the [`into_par_iter`] method on [`HashSet`]
/// (provided by the [`IntoParallelIterator`] trait).
/// See its documentation for more.
///
/// The iterator owns the set and yields its elements by value (`T`).
///
/// [`into_par_iter`]: /hashbrown/struct.HashSet.html#method.into_par_iter
/// [`HashSet`]: /hashbrown/struct.HashSet.html
/// [`IntoParallelIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelIterator.html
pub struct IntoParIter<T, S> {
    // The set being consumed; its backing `map` is what actually gets iterated.
    set: HashSet<T, S>,
}
20 
21 impl<T: Send, S: Send> ParallelIterator for IntoParIter<T, S> {
22     type Item = T;
23 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,24     fn drive_unindexed<C>(self, consumer: C) -> C::Result
25     where
26         C: UnindexedConsumer<Self::Item>,
27     {
28         self.set
29             .map
30             .into_par_iter()
31             .map(|(k, _)| k)
32             .drive_unindexed(consumer)
33     }
34 }
35 
/// Parallel draining iterator over entries of a set.
///
/// This iterator is created by the [`par_drain`] method on [`HashSet`].
/// See its documentation for more.
///
/// The set stays mutably borrowed for `'a`; drained elements are yielded by
/// value (`T`).
///
/// [`par_drain`]: /hashbrown/struct.HashSet.html#method.par_drain
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParDrain<'a, T, S> {
    // Exclusive borrow of the set being drained.
    set: &'a mut HashSet<T, S>,
}
46 
47 impl<T: Send, S: Send> ParallelIterator for ParDrain<'_, T, S> {
48     type Item = T;
49 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,50     fn drive_unindexed<C>(self, consumer: C) -> C::Result
51     where
52         C: UnindexedConsumer<Self::Item>,
53     {
54         self.set
55             .map
56             .par_drain()
57             .map(|(k, _)| k)
58             .drive_unindexed(consumer)
59     }
60 }
61 
/// Parallel iterator over shared references to elements in a set.
///
/// This iterator is created by the [`par_iter`] method on [`HashSet`]
/// (provided by the [`IntoParallelRefIterator`] trait).
/// See its documentation for more.
///
/// Elements are yielded as `&'a T`; the set itself is only borrowed.
///
/// [`par_iter`]: /hashbrown/struct.HashSet.html#method.par_iter
/// [`HashSet`]: /hashbrown/struct.HashSet.html
/// [`IntoParallelRefIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelRefIterator.html
pub struct ParIter<'a, T, S> {
    // Shared borrow of the set being iterated.
    set: &'a HashSet<T, S>,
}
74 
75 impl<'a, T: Sync, S: Sync> ParallelIterator for ParIter<'a, T, S> {
76     type Item = &'a T;
77 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,78     fn drive_unindexed<C>(self, consumer: C) -> C::Result
79     where
80         C: UnindexedConsumer<Self::Item>,
81     {
82         self.set.map.par_keys().drive_unindexed(consumer)
83     }
84 }
85 
/// Parallel iterator over shared references to elements in the difference of
/// sets.
///
/// This iterator is created by the [`par_difference`] method on [`HashSet`].
/// See its documentation for more.
///
/// Yields the elements of `a` that are not contained in `b`.
///
/// [`par_difference`]: /hashbrown/struct.HashSet.html#method.par_difference
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParDifference<'a, T, S> {
    // The set whose elements are visited.
    a: &'a HashSet<T, S>,
    // The set whose elements are excluded from the result.
    b: &'a HashSet<T, S>,
}
98 
99 impl<'a, T, S> ParallelIterator for ParDifference<'a, T, S>
100 where
101     T: Eq + Hash + Sync,
102     S: BuildHasher + Sync,
103 {
104     type Item = &'a T;
105 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,106     fn drive_unindexed<C>(self, consumer: C) -> C::Result
107     where
108         C: UnindexedConsumer<Self::Item>,
109     {
110         self.a
111             .into_par_iter()
112             .filter(|&x| !self.b.contains(x))
113             .drive_unindexed(consumer)
114     }
115 }
116 
/// Parallel iterator over shared references to elements in the symmetric
/// difference of sets.
///
/// This iterator is created by the [`par_symmetric_difference`] method on
/// [`HashSet`].
/// See its documentation for more.
///
/// Yields the elements of `a` that are not in `b`, chained with the elements
/// of `b` that are not in `a`.
///
/// [`par_symmetric_difference`]: /hashbrown/struct.HashSet.html#method.par_symmetric_difference
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParSymmetricDifference<'a, T, S> {
    // First operand.
    a: &'a HashSet<T, S>,
    // Second operand.
    b: &'a HashSet<T, S>,
}
130 
131 impl<'a, T, S> ParallelIterator for ParSymmetricDifference<'a, T, S>
132 where
133     T: Eq + Hash + Sync,
134     S: BuildHasher + Sync,
135 {
136     type Item = &'a T;
137 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,138     fn drive_unindexed<C>(self, consumer: C) -> C::Result
139     where
140         C: UnindexedConsumer<Self::Item>,
141     {
142         self.a
143             .par_difference(self.b)
144             .chain(self.b.par_difference(self.a))
145             .drive_unindexed(consumer)
146     }
147 }
148 
/// Parallel iterator over shared references to elements in the intersection of
/// sets.
///
/// This iterator is created by the [`par_intersection`] method on [`HashSet`].
/// See its documentation for more.
///
/// Yields the elements of `a` that are also contained in `b`.
///
/// [`par_intersection`]: /hashbrown/struct.HashSet.html#method.par_intersection
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParIntersection<'a, T, S> {
    // The set whose elements are visited.
    a: &'a HashSet<T, S>,
    // The set used for the membership test.
    b: &'a HashSet<T, S>,
}
161 
162 impl<'a, T, S> ParallelIterator for ParIntersection<'a, T, S>
163 where
164     T: Eq + Hash + Sync,
165     S: BuildHasher + Sync,
166 {
167     type Item = &'a T;
168 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,169     fn drive_unindexed<C>(self, consumer: C) -> C::Result
170     where
171         C: UnindexedConsumer<Self::Item>,
172     {
173         self.a
174             .into_par_iter()
175             .filter(|&x| self.b.contains(x))
176             .drive_unindexed(consumer)
177     }
178 }
179 
/// Parallel iterator over shared references to elements in the union of sets.
///
/// This iterator is created by the [`par_union`] method on [`HashSet`].
/// See its documentation for more.
///
/// Yields every element of `a`, chained with the elements of `b` that are not
/// in `a`.
///
/// [`par_union`]: /hashbrown/struct.HashSet.html#method.par_union
/// [`HashSet`]: /hashbrown/struct.HashSet.html
pub struct ParUnion<'a, T, S> {
    // The set that is visited in full.
    a: &'a HashSet<T, S>,
    // The set that contributes only the elements missing from `a`.
    b: &'a HashSet<T, S>,
}
191 
192 impl<'a, T, S> ParallelIterator for ParUnion<'a, T, S>
193 where
194     T: Eq + Hash + Sync,
195     S: BuildHasher + Sync,
196 {
197     type Item = &'a T;
198 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,199     fn drive_unindexed<C>(self, consumer: C) -> C::Result
200     where
201         C: UnindexedConsumer<Self::Item>,
202     {
203         self.a
204             .into_par_iter()
205             .chain(self.b.par_difference(self.a))
206             .drive_unindexed(consumer)
207     }
208 }
209 
210 impl<T, S> HashSet<T, S>
211 where
212     T: Eq + Hash + Sync,
213     S: BuildHasher + Sync,
214 {
215     /// Visits (potentially in parallel) the values representing the difference,
216     /// i.e. the values that are in `self` but not in `other`.
217     #[cfg_attr(feature = "inline-more", inline)]
par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S>218     pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S> {
219         ParDifference { a: self, b: other }
220     }
221 
222     /// Visits (potentially in parallel) the values representing the symmetric
223     /// difference, i.e. the values that are in `self` or in `other` but not in both.
224     #[cfg_attr(feature = "inline-more", inline)]
par_symmetric_difference<'a>( &'a self, other: &'a Self, ) -> ParSymmetricDifference<'a, T, S>225     pub fn par_symmetric_difference<'a>(
226         &'a self,
227         other: &'a Self,
228     ) -> ParSymmetricDifference<'a, T, S> {
229         ParSymmetricDifference { a: self, b: other }
230     }
231 
232     /// Visits (potentially in parallel) the values representing the
233     /// intersection, i.e. the values that are both in `self` and `other`.
234     #[cfg_attr(feature = "inline-more", inline)]
par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S>235     pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S> {
236         ParIntersection { a: self, b: other }
237     }
238 
239     /// Visits (potentially in parallel) the values representing the union,
240     /// i.e. all the values in `self` or `other`, without duplicates.
241     #[cfg_attr(feature = "inline-more", inline)]
par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S>242     pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S> {
243         ParUnion { a: self, b: other }
244     }
245 
246     /// Returns `true` if `self` has no elements in common with `other`.
247     /// This is equivalent to checking for an empty intersection.
248     ///
249     /// This method runs in a potentially parallel fashion.
par_is_disjoint(&self, other: &Self) -> bool250     pub fn par_is_disjoint(&self, other: &Self) -> bool {
251         self.into_par_iter().all(|x| !other.contains(x))
252     }
253 
254     /// Returns `true` if the set is a subset of another,
255     /// i.e. `other` contains at least all the values in `self`.
256     ///
257     /// This method runs in a potentially parallel fashion.
par_is_subset(&self, other: &Self) -> bool258     pub fn par_is_subset(&self, other: &Self) -> bool {
259         if self.len() <= other.len() {
260             self.into_par_iter().all(|x| other.contains(x))
261         } else {
262             false
263         }
264     }
265 
266     /// Returns `true` if the set is a superset of another,
267     /// i.e. `self` contains at least all the values in `other`.
268     ///
269     /// This method runs in a potentially parallel fashion.
par_is_superset(&self, other: &Self) -> bool270     pub fn par_is_superset(&self, other: &Self) -> bool {
271         other.par_is_subset(self)
272     }
273 
274     /// Returns `true` if the set is equal to another,
275     /// i.e. both sets contain the same values.
276     ///
277     /// This method runs in a potentially parallel fashion.
par_eq(&self, other: &Self) -> bool278     pub fn par_eq(&self, other: &Self) -> bool {
279         self.len() == other.len() && self.par_is_subset(other)
280     }
281 }
282 
283 impl<T, S> HashSet<T, S>
284 where
285     T: Eq + Hash + Send,
286     S: BuildHasher + Send,
287 {
288     /// Consumes (potentially in parallel) all values in an arbitrary order,
289     /// while preserving the set's allocated memory for reuse.
290     #[cfg_attr(feature = "inline-more", inline)]
par_drain(&mut self) -> ParDrain<'_, T, S>291     pub fn par_drain(&mut self) -> ParDrain<'_, T, S> {
292         ParDrain { set: self }
293     }
294 }
295 
296 impl<T: Send, S: Send> IntoParallelIterator for HashSet<T, S> {
297     type Item = T;
298     type Iter = IntoParIter<T, S>;
299 
300     #[cfg_attr(feature = "inline-more", inline)]
into_par_iter(self) -> Self::Iter301     fn into_par_iter(self) -> Self::Iter {
302         IntoParIter { set: self }
303     }
304 }
305 
306 impl<'a, T: Sync, S: Sync> IntoParallelIterator for &'a HashSet<T, S> {
307     type Item = &'a T;
308     type Iter = ParIter<'a, T, S>;
309 
310     #[cfg_attr(feature = "inline-more", inline)]
into_par_iter(self) -> Self::Iter311     fn into_par_iter(self) -> Self::Iter {
312         ParIter { set: self }
313     }
314 }
315 
316 /// Collect values from a parallel iterator into a hashset.
317 impl<T, S> FromParallelIterator<T> for HashSet<T, S>
318 where
319     T: Eq + Hash + Send,
320     S: BuildHasher + Default,
321 {
from_par_iter<P>(par_iter: P) -> Self where P: IntoParallelIterator<Item = T>,322     fn from_par_iter<P>(par_iter: P) -> Self
323     where
324         P: IntoParallelIterator<Item = T>,
325     {
326         let mut set = HashSet::default();
327         set.par_extend(par_iter);
328         set
329     }
330 }
331 
332 /// Extend a hash set with items from a parallel iterator.
333 impl<T, S> ParallelExtend<T> for HashSet<T, S>
334 where
335     T: Eq + Hash + Send,
336     S: BuildHasher,
337 {
par_extend<I>(&mut self, par_iter: I) where I: IntoParallelIterator<Item = T>,338     fn par_extend<I>(&mut self, par_iter: I)
339     where
340         I: IntoParallelIterator<Item = T>,
341     {
342         extend(self, par_iter);
343     }
344 }
345 
346 /// Extend a hash set with copied items from a parallel iterator.
347 impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S>
348 where
349     T: 'a + Copy + Eq + Hash + Sync,
350     S: BuildHasher,
351 {
par_extend<I>(&mut self, par_iter: I) where I: IntoParallelIterator<Item = &'a T>,352     fn par_extend<I>(&mut self, par_iter: I)
353     where
354         I: IntoParallelIterator<Item = &'a T>,
355     {
356         extend(self, par_iter);
357     }
358 }
359 
360 // This is equal to the normal `HashSet` -- no custom advantage.
extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I) where T: Eq + Hash, S: BuildHasher, I: IntoParallelIterator, HashSet<T, S>: Extend<I::Item>,361 fn extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I)
362 where
363     T: Eq + Hash,
364     S: BuildHasher,
365     I: IntoParallelIterator,
366     HashSet<T, S>: Extend<I::Item>,
367 {
368     let (list, len) = super::helpers::collect(par_iter);
369 
370     // Values may be already present or show multiple times in the iterator.
371     // Reserve the entire length if the set is empty.
372     // Otherwise reserve half the length (rounded up), so the set
373     // will only resize twice in the worst case.
374     let reserve = if set.is_empty() { len } else { (len + 1) / 2 };
375     set.reserve(reserve);
376     for vec in list {
377         set.extend(vec);
378     }
379 }
380 
#[cfg(test)]
mod test_par_set {
    use alloc::vec::Vec;
    use core::sync::atomic::{AtomicUsize, Ordering};

    use rayon::prelude::*;

    use crate::hash_set::HashSet;

    /// Inserts `values` into `set` in order, asserting that each one is new.
    fn insert_all(set: &mut HashSet<i32>, values: &[i32]) {
        for &value in values {
            assert!(set.insert(value));
        }
    }

    /// Runs `elems` to completion, asserting that every yielded value is one
    /// of `expected`, and returns how many values were yielded.
    fn count_expected<'a, I>(elems: I, expected: &[i32]) -> usize
    where
        I: ParallelIterator<Item = &'a i32>,
    {
        elems
            .map(|elem| {
                assert!(expected.contains(elem));
                1_usize
            })
            .sum::<usize>()
    }

    #[test]
    fn test_disjoint() {
        let mut xs: HashSet<i32> = HashSet::new();
        let mut ys: HashSet<i32> = HashSet::new();

        // Two empty sets are disjoint.
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        insert_all(&mut xs, &[5]);
        insert_all(&mut ys, &[11]);
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        insert_all(&mut xs, &[7, 19, 4]);
        insert_all(&mut ys, &[2, -11]);
        assert!(xs.par_is_disjoint(&ys));
        assert!(ys.par_is_disjoint(&xs));

        // `7` is now shared, so the sets overlap.
        insert_all(&mut ys, &[7]);
        assert!(!xs.par_is_disjoint(&ys));
        assert!(!ys.par_is_disjoint(&xs));
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = HashSet::new();
        insert_all(&mut a, &[0, 5, 11, 7]);

        let mut b = HashSet::new();
        insert_all(&mut b, &[0, 7, 19, 250, 11, 200]);

        // `5` is missing from `b`, so neither set contains the other.
        assert!(!a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(!b.par_is_superset(&a));

        insert_all(&mut b, &[5]);

        assert!(a.par_is_subset(&b));
        assert!(!a.par_is_superset(&b));
        assert!(!b.par_is_subset(&a));
        assert!(b.par_is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = HashSet::new();
        (0..32).for_each(|i| assert!(a.insert(i)));

        // One bit per visited element; all 32 low bits must end up set.
        let seen_bits = AtomicUsize::new(0);
        a.par_iter().for_each(|k| {
            seen_bits.fetch_or(1 << *k, Ordering::Relaxed);
        });
        assert_eq!(seen_bits.into_inner(), 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        insert_all(&mut a, &[11, 1, 3, 77, 103, 5, -5]);
        insert_all(&mut b, &[2, 11, 77, -9, -42, 5, 3]);

        let expected = [3, 5, 11, 77];
        let yielded = count_expected(a.par_intersection(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        insert_all(&mut a, &[1, 3, 5, 9, 11]);
        insert_all(&mut b, &[3, 9]);

        let expected = [1, 5, 11];
        let yielded = count_expected(a.par_difference(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        insert_all(&mut a, &[1, 3, 5, 9, 11]);
        insert_all(&mut b, &[-2, 3, 9, 14, 22]);

        let expected = [-2, 1, 5, 11, 14, 22];
        let yielded = count_expected(a.par_symmetric_difference(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = HashSet::new();
        let mut b = HashSet::new();

        insert_all(&mut a, &[1, 3, 5, 9, 11, 16, 19, 24]);
        insert_all(&mut b, &[-2, 1, 5, 9, 13, 19]);

        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        let yielded = count_expected(a.par_union(&b), &expected);
        assert_eq!(yielded, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9];

        let set: HashSet<_> = xs.par_iter().cloned().collect();

        assert!(xs.iter().all(|x| set.contains(x)));
    }

    #[test]
    fn test_move_iter() {
        let mut hs = HashSet::new();
        hs.insert('a');
        hs.insert('b');

        // Iteration order is arbitrary, so compare after sorting.
        let mut v: Vec<char> = hs.into_par_iter().collect();
        v.sort_unstable();
        assert_eq!(v, ['a', 'b']);
    }

    #[test]
    fn test_eq() {
        // These constants once happened to expose a bug in insert().
        // I'm keeping them around to prevent a regression.
        let mut s1 = HashSet::new();
        for &value in &[1, 2, 3] {
            s1.insert(value);
        }

        let mut s2 = HashSet::new();
        for &value in &[1, 2] {
            s2.insert(value);
        }

        assert!(!s1.par_eq(&s2));

        s2.insert(3);

        assert!(s1.par_eq(&s2));
    }

    #[test]
    fn test_extend_ref() {
        let mut a = HashSet::new();
        a.insert(1);

        // Extend from a slice of references.
        a.par_extend(&[2, 3, 4][..]);

        assert_eq!(a.len(), 4);
        for value in &[1, 2, 3, 4] {
            assert!(a.contains(value));
        }

        let mut b = HashSet::new();
        b.insert(5);
        b.insert(6);

        // Extend from a borrowed set.
        a.par_extend(&b);

        assert_eq!(a.len(), 6);
        for value in &[1, 2, 3, 4, 5, 6] {
            assert!(a.contains(value));
        }
    }
}
647