1 //! Rayon extensions for `HashSet`. 2 3 use crate::hash_set::HashSet; 4 use core::hash::{BuildHasher, Hash}; 5 use rayon::iter::plumbing::UnindexedConsumer; 6 use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}; 7 8 /// Parallel iterator over elements of a consumed set. 9 /// 10 /// This iterator is created by the [`into_par_iter`] method on [`HashSet`] 11 /// (provided by the [`IntoParallelIterator`] trait). 12 /// See its documentation for more. 13 /// 14 /// [`into_par_iter`]: /hashbrown/struct.HashSet.html#method.into_par_iter 15 /// [`HashSet`]: /hashbrown/struct.HashSet.html 16 /// [`IntoParallelIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelIterator.html 17 pub struct IntoParIter<T, S> { 18 set: HashSet<T, S>, 19 } 20 21 impl<T: Send, S: Send> ParallelIterator for IntoParIter<T, S> { 22 type Item = T; 23 24 fn drive_unindexed<C>(self, consumer: C) -> C::Result 25 where 26 C: UnindexedConsumer<Self::Item>, 27 { 28 self.set 29 .map 30 .into_par_iter() 31 .map(|(k, _)| k) 32 .drive_unindexed(consumer) 33 } 34 } 35 36 /// Parallel draining iterator over entries of a set. 37 /// 38 /// This iterator is created by the [`par_drain`] method on [`HashSet`]. 39 /// See its documentation for more. 40 /// 41 /// [`par_drain`]: /hashbrown/struct.HashSet.html#method.par_drain 42 /// [`HashSet`]: /hashbrown/struct.HashSet.html 43 pub struct ParDrain<'a, T, S> { 44 set: &'a mut HashSet<T, S>, 45 } 46 47 impl<T: Send, S: Send> ParallelIterator for ParDrain<'_, T, S> { 48 type Item = T; 49 50 fn drive_unindexed<C>(self, consumer: C) -> C::Result 51 where 52 C: UnindexedConsumer<Self::Item>, 53 { 54 self.set 55 .map 56 .par_drain() 57 .map(|(k, _)| k) 58 .drive_unindexed(consumer) 59 } 60 } 61 62 /// Parallel iterator over shared references to elements in a set. 63 /// 64 /// This iterator is created by the [`par_iter`] method on [`HashSet`] 65 /// (provided by the [`IntoParallelRefIterator`] trait). 66 /// See its documentation for more. 67 /// 68 /// [`par_iter`]: /hashbrown/struct.HashSet.html#method.par_iter 69 /// [`HashSet`]: /hashbrown/struct.HashSet.html 70 /// [`IntoParallelRefIterator`]: https://docs.rs/rayon/1.0/rayon/iter/trait.IntoParallelRefIterator.html 71 pub struct ParIter<'a, T, S> { 72 set: &'a HashSet<T, S>, 73 } 74 75 impl<'a, T: Sync, S: Sync> ParallelIterator for ParIter<'a, T, S> { 76 type Item = &'a T; 77 78 fn drive_unindexed<C>(self, consumer: C) -> C::Result 79 where 80 C: UnindexedConsumer<Self::Item>, 81 { 82 self.set.map.par_keys().drive_unindexed(consumer) 83 } 84 } 85 86 /// Parallel iterator over shared references to elements in the difference of 87 /// sets. 88 /// 89 /// This iterator is created by the [`par_difference`] method on [`HashSet`]. 90 /// See its documentation for more. 91 /// 92 /// [`par_difference`]: /hashbrown/struct.HashSet.html#method.par_difference 93 /// [`HashSet`]: /hashbrown/struct.HashSet.html 94 pub struct ParDifference<'a, T, S> { 95 a: &'a HashSet<T, S>, 96 b: &'a HashSet<T, S>, 97 } 98 99 impl<'a, T, S> ParallelIterator for ParDifference<'a, T, S> 100 where 101 T: Eq + Hash + Sync, 102 S: BuildHasher + Sync, 103 { 104 type Item = &'a T; 105 106 fn drive_unindexed<C>(self, consumer: C) -> C::Result 107 where 108 C: UnindexedConsumer<Self::Item>, 109 { 110 self.a 111 .into_par_iter() 112 .filter(|&x| !self.b.contains(x)) 113 .drive_unindexed(consumer) 114 } 115 } 116 117 /// Parallel iterator over shared references to elements in the symmetric 118 /// difference of sets. 119 /// 120 /// This iterator is created by the [`par_symmetric_difference`] method on 121 /// [`HashSet`]. 122 /// See its documentation for more. 123 /// 124 /// [`par_symmetric_difference`]: /hashbrown/struct.HashSet.html#method.par_symmetric_difference 125 /// [`HashSet`]: /hashbrown/struct.HashSet.html 126 pub struct ParSymmetricDifference<'a, T, S> { 127 a: &'a HashSet<T, S>, 128 b: &'a HashSet<T, S>, 129 } 130 131 impl<'a, T, S> ParallelIterator for ParSymmetricDifference<'a, T, S> 132 where 133 T: Eq + Hash + Sync, 134 S: BuildHasher + Sync, 135 { 136 type Item = &'a T; 137 138 fn drive_unindexed<C>(self, consumer: C) -> C::Result 139 where 140 C: UnindexedConsumer<Self::Item>, 141 { 142 self.a 143 .par_difference(self.b) 144 .chain(self.b.par_difference(self.a)) 145 .drive_unindexed(consumer) 146 } 147 } 148 149 /// Parallel iterator over shared references to elements in the intersection of 150 /// sets. 151 /// 152 /// This iterator is created by the [`par_intersection`] method on [`HashSet`]. 153 /// See its documentation for more. 154 /// 155 /// [`par_intersection`]: /hashbrown/struct.HashSet.html#method.par_intersection 156 /// [`HashSet`]: /hashbrown/struct.HashSet.html 157 pub struct ParIntersection<'a, T, S> { 158 a: &'a HashSet<T, S>, 159 b: &'a HashSet<T, S>, 160 } 161 162 impl<'a, T, S> ParallelIterator for ParIntersection<'a, T, S> 163 where 164 T: Eq + Hash + Sync, 165 S: BuildHasher + Sync, 166 { 167 type Item = &'a T; 168 169 fn drive_unindexed<C>(self, consumer: C) -> C::Result 170 where 171 C: UnindexedConsumer<Self::Item>, 172 { 173 self.a 174 .into_par_iter() 175 .filter(|&x| self.b.contains(x)) 176 .drive_unindexed(consumer) 177 } 178 } 179 180 /// Parallel iterator over shared references to elements in the union of sets. 181 /// 182 /// This iterator is created by the [`par_union`] method on [`HashSet`]. 183 /// See its documentation for more. 184 /// 185 /// [`par_union`]: /hashbrown/struct.HashSet.html#method.par_union 186 /// [`HashSet`]: /hashbrown/struct.HashSet.html 187 pub struct ParUnion<'a, T, S> { 188 a: &'a HashSet<T, S>, 189 b: &'a HashSet<T, S>, 190 } 191 192 impl<'a, T, S> ParallelIterator for ParUnion<'a, T, S> 193 where 194 T: Eq + Hash + Sync, 195 S: BuildHasher + Sync, 196 { 197 type Item = &'a T; 198 199 fn drive_unindexed<C>(self, consumer: C) -> C::Result 200 where 201 C: UnindexedConsumer<Self::Item>, 202 { 203 self.a 204 .into_par_iter() 205 .chain(self.b.par_difference(self.a)) 206 .drive_unindexed(consumer) 207 } 208 } 209 210 impl<T, S> HashSet<T, S> 211 where 212 T: Eq + Hash + Sync, 213 S: BuildHasher + Sync, 214 { 215 /// Visits (potentially in parallel) the values representing the difference, 216 /// i.e. the values that are in `self` but not in `other`. 217 #[cfg_attr(feature = "inline-more", inline)] 218 pub fn par_difference<'a>(&'a self, other: &'a Self) -> ParDifference<'a, T, S> { 219 ParDifference { a: self, b: other } 220 } 221 222 /// Visits (potentially in parallel) the values representing the symmetric 223 /// difference, i.e. the values that are in `self` or in `other` but not in both. 224 #[cfg_attr(feature = "inline-more", inline)] 225 pub fn par_symmetric_difference<'a>( 226 &'a self, 227 other: &'a Self, 228 ) -> ParSymmetricDifference<'a, T, S> { 229 ParSymmetricDifference { a: self, b: other } 230 } 231 232 /// Visits (potentially in parallel) the values representing the 233 /// intersection, i.e. the values that are both in `self` and `other`. 234 #[cfg_attr(feature = "inline-more", inline)] 235 pub fn par_intersection<'a>(&'a self, other: &'a Self) -> ParIntersection<'a, T, S> { 236 ParIntersection { a: self, b: other } 237 } 238 239 /// Visits (potentially in parallel) the values representing the union, 240 /// i.e. all the values in `self` or `other`, without duplicates. 241 #[cfg_attr(feature = "inline-more", inline)] 242 pub fn par_union<'a>(&'a self, other: &'a Self) -> ParUnion<'a, T, S> { 243 ParUnion { a: self, b: other } 244 } 245 246 /// Returns `true` if `self` has no elements in common with `other`. 247 /// This is equivalent to checking for an empty intersection. 248 /// 249 /// This method runs in a potentially parallel fashion. 250 pub fn par_is_disjoint(&self, other: &Self) -> bool { 251 self.into_par_iter().all(|x| !other.contains(x)) 252 } 253 254 /// Returns `true` if the set is a subset of another, 255 /// i.e. `other` contains at least all the values in `self`. 256 /// 257 /// This method runs in a potentially parallel fashion. 258 pub fn par_is_subset(&self, other: &Self) -> bool { 259 if self.len() <= other.len() { 260 self.into_par_iter().all(|x| other.contains(x)) 261 } else { 262 false 263 } 264 } 265 266 /// Returns `true` if the set is a superset of another, 267 /// i.e. `self` contains at least all the values in `other`. 268 /// 269 /// This method runs in a potentially parallel fashion. 270 pub fn par_is_superset(&self, other: &Self) -> bool { 271 other.par_is_subset(self) 272 } 273 274 /// Returns `true` if the set is equal to another, 275 /// i.e. both sets contain the same values. 276 /// 277 /// This method runs in a potentially parallel fashion. 278 pub fn par_eq(&self, other: &Self) -> bool { 279 self.len() == other.len() && self.par_is_subset(other) 280 } 281 } 282 283 impl<T, S> HashSet<T, S> 284 where 285 T: Eq + Hash + Send, 286 S: BuildHasher + Send, 287 { 288 /// Consumes (potentially in parallel) all values in an arbitrary order, 289 /// while preserving the set's allocated memory for reuse. 290 #[cfg_attr(feature = "inline-more", inline)] 291 pub fn par_drain(&mut self) -> ParDrain<'_, T, S> { 292 ParDrain { set: self } 293 } 294 } 295 296 impl<T: Send, S: Send> IntoParallelIterator for HashSet<T, S> { 297 type Item = T; 298 type Iter = IntoParIter<T, S>; 299 300 #[cfg_attr(feature = "inline-more", inline)] 301 fn into_par_iter(self) -> Self::Iter { 302 IntoParIter { set: self } 303 } 304 } 305 306 impl<'a, T: Sync, S: Sync> IntoParallelIterator for &'a HashSet<T, S> { 307 type Item = &'a T; 308 type Iter = ParIter<'a, T, S>; 309 310 #[cfg_attr(feature = "inline-more", inline)] 311 fn into_par_iter(self) -> Self::Iter { 312 ParIter { set: self } 313 } 314 } 315 316 /// Collect values from a parallel iterator into a hashset. 317 impl<T, S> FromParallelIterator<T> for HashSet<T, S> 318 where 319 T: Eq + Hash + Send, 320 S: BuildHasher + Default, 321 { 322 fn from_par_iter<P>(par_iter: P) -> Self 323 where 324 P: IntoParallelIterator<Item = T>, 325 { 326 let mut set = HashSet::default(); 327 set.par_extend(par_iter); 328 set 329 } 330 } 331 332 /// Extend a hash set with items from a parallel iterator. 333 impl<T, S> ParallelExtend<T> for HashSet<T, S> 334 where 335 T: Eq + Hash + Send, 336 S: BuildHasher, 337 { 338 fn par_extend<I>(&mut self, par_iter: I) 339 where 340 I: IntoParallelIterator<Item = T>, 341 { 342 extend(self, par_iter); 343 } 344 } 345 346 /// Extend a hash set with copied items from a parallel iterator. 347 impl<'a, T, S> ParallelExtend<&'a T> for HashSet<T, S> 348 where 349 T: 'a + Copy + Eq + Hash + Sync, 350 S: BuildHasher, 351 { 352 fn par_extend<I>(&mut self, par_iter: I) 353 where 354 I: IntoParallelIterator<Item = &'a T>, 355 { 356 extend(self, par_iter); 357 } 358 } 359 360 // This is equal to the normal `HashSet` -- no custom advantage. 361 fn extend<T, S, I>(set: &mut HashSet<T, S>, par_iter: I) 362 where 363 T: Eq + Hash, 364 S: BuildHasher, 365 I: IntoParallelIterator, 366 HashSet<T, S>: Extend<I::Item>, 367 { 368 let (list, len) = super::helpers::collect(par_iter); 369 370 // Values may be already present or show multiple times in the iterator. 371 // Reserve the entire length if the set is empty. 372 // Otherwise reserve half the length (rounded up), so the set 373 // will only resize twice in the worst case. 374 let reserve = if set.is_empty() { len } else { (len + 1) / 2 }; 375 set.reserve(reserve); 376 for vec in list { 377 set.extend(vec); 378 } 379 } 380 381 #[cfg(test)] 382 mod test_par_set { 383 use alloc::vec::Vec; 384 use core::sync::atomic::{AtomicUsize, Ordering}; 385 386 use rayon::prelude::*; 387 388 use crate::hash_set::HashSet; 389 390 #[test] 391 fn test_disjoint() { 392 let mut xs = HashSet::new(); 393 let mut ys = HashSet::new(); 394 assert!(xs.par_is_disjoint(&ys)); 395 assert!(ys.par_is_disjoint(&xs)); 396 assert!(xs.insert(5)); 397 assert!(ys.insert(11)); 398 assert!(xs.par_is_disjoint(&ys)); 399 assert!(ys.par_is_disjoint(&xs)); 400 assert!(xs.insert(7)); 401 assert!(xs.insert(19)); 402 assert!(xs.insert(4)); 403 assert!(ys.insert(2)); 404 assert!(ys.insert(-11)); 405 assert!(xs.par_is_disjoint(&ys)); 406 assert!(ys.par_is_disjoint(&xs)); 407 assert!(ys.insert(7)); 408 assert!(!xs.par_is_disjoint(&ys)); 409 assert!(!ys.par_is_disjoint(&xs)); 410 } 411 412 #[test] 413 fn test_subset_and_superset() { 414 let mut a = HashSet::new(); 415 assert!(a.insert(0)); 416 assert!(a.insert(5)); 417 assert!(a.insert(11)); 418 assert!(a.insert(7)); 419 420 let mut b = HashSet::new(); 421 assert!(b.insert(0)); 422 assert!(b.insert(7)); 423 assert!(b.insert(19)); 424 assert!(b.insert(250)); 425 assert!(b.insert(11)); 426 assert!(b.insert(200)); 427 428 assert!(!a.par_is_subset(&b)); 429 assert!(!a.par_is_superset(&b)); 430 assert!(!b.par_is_subset(&a)); 431 assert!(!b.par_is_superset(&a)); 432 433 assert!(b.insert(5)); 434 435 assert!(a.par_is_subset(&b)); 436 assert!(!a.par_is_superset(&b)); 437 assert!(!b.par_is_subset(&a)); 438 assert!(b.par_is_superset(&a)); 439 } 440 441 #[test] 442 fn test_iterate() { 443 let mut a = HashSet::new(); 444 for i in 0..32 { 445 assert!(a.insert(i)); 446 } 447 let observed = AtomicUsize::new(0); 448 a.par_iter().for_each(|k| { 449 observed.fetch_or(1 << *k, Ordering::Relaxed); 450 }); 451 assert_eq!(observed.into_inner(), 0xFFFF_FFFF); 452 } 453 454 #[test] 455 fn test_intersection() { 456 let mut a = HashSet::new(); 457 let mut b = HashSet::new(); 458 459 assert!(a.insert(11)); 460 assert!(a.insert(1)); 461 assert!(a.insert(3)); 462 assert!(a.insert(77)); 463 assert!(a.insert(103)); 464 assert!(a.insert(5)); 465 assert!(a.insert(-5)); 466 467 assert!(b.insert(2)); 468 assert!(b.insert(11)); 469 assert!(b.insert(77)); 470 assert!(b.insert(-9)); 471 assert!(b.insert(-42)); 472 assert!(b.insert(5)); 473 assert!(b.insert(3)); 474 475 let expected = [3, 5, 11, 77]; 476 let i = a 477 .par_intersection(&b) 478 .map(|x| { 479 assert!(expected.contains(x)); 480 1 481 }) 482 .sum::<usize>(); 483 assert_eq!(i, expected.len()); 484 } 485 486 #[test] 487 fn test_difference() { 488 let mut a = HashSet::new(); 489 let mut b = HashSet::new(); 490 491 assert!(a.insert(1)); 492 assert!(a.insert(3)); 493 assert!(a.insert(5)); 494 assert!(a.insert(9)); 495 assert!(a.insert(11)); 496 497 assert!(b.insert(3)); 498 assert!(b.insert(9)); 499 500 let expected = [1, 5, 11]; 501 let i = a 502 .par_difference(&b) 503 .map(|x| { 504 assert!(expected.contains(x)); 505 1 506 }) 507 .sum::<usize>(); 508 assert_eq!(i, expected.len()); 509 } 510 511 #[test] 512 fn test_symmetric_difference() { 513 let mut a = HashSet::new(); 514 let mut b = HashSet::new(); 515 516 assert!(a.insert(1)); 517 assert!(a.insert(3)); 518 assert!(a.insert(5)); 519 assert!(a.insert(9)); 520 assert!(a.insert(11)); 521 522 assert!(b.insert(-2)); 523 assert!(b.insert(3)); 524 assert!(b.insert(9)); 525 assert!(b.insert(14)); 526 assert!(b.insert(22)); 527 528 let expected = [-2, 1, 5, 11, 14, 22]; 529 let i = a 530 .par_symmetric_difference(&b) 531 .map(|x| { 532 assert!(expected.contains(x)); 533 1 534 }) 535 .sum::<usize>(); 536 assert_eq!(i, expected.len()); 537 } 538 539 #[test] 540 fn test_union() { 541 let mut a = HashSet::new(); 542 let mut b = HashSet::new(); 543 544 assert!(a.insert(1)); 545 assert!(a.insert(3)); 546 assert!(a.insert(5)); 547 assert!(a.insert(9)); 548 assert!(a.insert(11)); 549 assert!(a.insert(16)); 550 assert!(a.insert(19)); 551 assert!(a.insert(24)); 552 553 assert!(b.insert(-2)); 554 assert!(b.insert(1)); 555 assert!(b.insert(5)); 556 assert!(b.insert(9)); 557 assert!(b.insert(13)); 558 assert!(b.insert(19)); 559 560 let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]; 561 let i = a 562 .par_union(&b) 563 .map(|x| { 564 assert!(expected.contains(x)); 565 1 566 }) 567 .sum::<usize>(); 568 assert_eq!(i, expected.len()); 569 } 570 571 #[test] 572 fn test_from_iter() { 573 let xs = [1, 2, 3, 4, 5, 6, 7, 8, 9]; 574 575 let set: HashSet<_> = xs.par_iter().cloned().collect(); 576 577 for x in &xs { 578 assert!(set.contains(x)); 579 } 580 } 581 582 #[test] 583 fn test_move_iter() { 584 let hs = { 585 let mut hs = HashSet::new(); 586 587 hs.insert('a'); 588 hs.insert('b'); 589 590 hs 591 }; 592 593 let v = hs.into_par_iter().collect::<Vec<char>>(); 594 assert!(v == ['a', 'b'] || v == ['b', 'a']); 595 } 596 597 #[test] 598 fn test_eq() { 599 // These constants once happened to expose a bug in insert(). 600 // I'm keeping them around to prevent a regression. 601 let mut s1 = HashSet::new(); 602 603 s1.insert(1); 604 s1.insert(2); 605 s1.insert(3); 606 607 let mut s2 = HashSet::new(); 608 609 s2.insert(1); 610 s2.insert(2); 611 612 assert!(!s1.par_eq(&s2)); 613 614 s2.insert(3); 615 616 assert!(s1.par_eq(&s2)); 617 } 618 619 #[test] 620 fn test_extend_ref() { 621 let mut a = HashSet::new(); 622 a.insert(1); 623 624 a.par_extend(&[2, 3, 4][..]); 625 626 assert_eq!(a.len(), 4); 627 assert!(a.contains(&1)); 628 assert!(a.contains(&2)); 629 assert!(a.contains(&3)); 630 assert!(a.contains(&4)); 631 632 let mut b = HashSet::new(); 633 b.insert(5); 634 b.insert(6); 635 636 a.par_extend(&b); 637 638 assert_eq!(a.len(), 6); 639 assert!(a.contains(&1)); 640 assert!(a.contains(&2)); 641 assert!(a.contains(&3)); 642 assert!(a.contains(&4)); 643 assert!(a.contains(&5)); 644 assert!(a.contains(&6)); 645 } 646 } 647