1 use rayon::prelude::*;
2
3 use std::panic;
4 use std::sync::atomic::AtomicUsize;
5 use std::sync::atomic::Ordering;
6 use std::sync::Mutex;
7
/// Verifies that `collect_into_vec` drops every element the parallel `map`
/// produced, both when the pipeline completes and when it unwinds part-way.
///
/// Each produced value is logged in `inserts` when created and in `drops` when
/// dropped; after the run the two logs must contain the same elements.
#[test]
fn collect_drop_on_unwind() {
    /// Pushes its payload (`.0`) onto the shared log (`.1`) when dropped.
    struct RecordDrop<'a>(i64, &'a Mutex<Vec<i64>>);

    impl<'a> Drop for RecordDrop<'a> {
        fn drop(&mut self) {
            self.1.lock().unwrap().push(self.0);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        let test_vec_len = 1024;
        let panic_point = 740;

        let mut inserts = Mutex::new(Vec::new());
        let mut drops = Mutex::new(Vec::new());

        let mut a = (0..test_vec_len).collect::<Vec<_>>();
        let b = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter_mut()
                .zip(&b)
                .map(|(&mut a, &b)| {
                    // Every input above `panic_point` aborts the pipeline when
                    // this run is meant to unwind.
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    let elt = a + b;
                    inserts.lock().unwrap().push(elt);
                    RecordDrop(elt, &drops)
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // Do not let `catch_unwind` swallow a failure: an unexpected panic
        // (e.g. the length assertion above) is re-raised with its original
        // payload, and a missing expected panic fails the test.
        match outcome {
            Err(payload) if !will_panic => panic::resume_unwind(payload),
            Ok(()) if will_panic => panic!("expected the parallel collect to panic"),
            _ => {}
        }

        let inserts = inserts.get_mut().unwrap();
        let drops = drops.get_mut().unwrap();
        println!("{:?}", inserts);
        println!("{:?}", drops);

        assert_eq!(inserts.len(), drops.len(), "Incorrect number of drops");
        // sort to normalize order
        inserts.sort();
        drops.sort();
        assert_eq!(inserts, drops, "Incorrect elements were dropped");
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
62
/// Zero-sized-type variant of the drop-on-unwind check for `collect_into_vec`.
///
/// Because a ZST carries no payload, creations and drops are counted with
/// function-local atomic counters instead of being logged per element.
#[test]
fn collect_drop_on_unwind_zst() {
    static INSERTS: AtomicUsize = AtomicUsize::new(0);
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    /// Zero-sized marker that bumps `DROPS` when dropped.
    struct RecordDropZst;

    impl Drop for RecordDropZst {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        // Reset the counters so each run is measured independently.
        INSERTS.store(0, Ordering::SeqCst);
        DROPS.store(0, Ordering::SeqCst);

        let test_vec_len = 1024;
        let panic_point = 740;

        let a = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter()
                .map(|&a| {
                    // Every input above `panic_point` aborts the pipeline when
                    // this run is meant to unwind.
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    INSERTS.fetch_add(1, Ordering::SeqCst);
                    RecordDropZst
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // Do not let `catch_unwind` swallow a failure: an unexpected panic
        // (e.g. the length assertion above) is re-raised with its original
        // payload, and a missing expected panic fails the test.
        match outcome {
            Err(payload) if !will_panic => panic::resume_unwind(payload),
            Ok(()) if will_panic => panic!("expected the parallel collect to panic"),
            _ => {}
        }

        let inserts = INSERTS.load(Ordering::SeqCst);
        let drops = DROPS.load(Ordering::SeqCst);

        assert_eq!(inserts, drops, "Incorrect number of drops");
        // Without a panic every input must have produced (and dropped) a value.
        assert!(
            will_panic || drops == test_vec_len,
            "Not every element was produced and dropped"
        );
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
112