1 use rayon::prelude::*;
2 
3 use std::panic;
4 use std::sync::atomic::AtomicUsize;
5 use std::sync::atomic::Ordering;
6 use std::sync::Mutex;
7 
/// Checks that `collect_into_vec` drops exactly the elements it produced.
///
/// The scenario runs twice: once where the mapping closure panics for inputs
/// above `panic_point` (the panic must propagate, and every element that was
/// created must still be dropped exactly once), and once where the pipeline
/// completes normally (every element is collected and then dropped).
#[test]
fn collect_drop_on_unwind() {
    /// Pushes its payload onto the referenced log when dropped.
    struct Recorddrop<'a>(i64, &'a Mutex<Vec<i64>>);

    impl Drop for Recorddrop<'_> {
        fn drop(&mut self) {
            self.1.lock().unwrap().push(self.0);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        let test_vec_len = 1024;
        // When `will_panic` is set, every input greater than this panics.
        let panic_point = 740;

        // `inserts` logs each element as it is created; `drops` logs each drop.
        let mut inserts = Mutex::new(Vec::new());
        let mut drops = Mutex::new(Vec::new());

        let mut a = (0..test_vec_len).collect::<Vec<_>>();
        let b = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter_mut()
                .zip(&b)
                .map(|(&mut a, &b)| {
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    let elt = a + b;
                    inserts.lock().unwrap().push(elt);
                    Recorddrop(elt, &drops)
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // `catch_unwind` would otherwise swallow a failing assertion above (or an
        // unexpected panic), so require that we panicked exactly when asked to.
        assert_eq!(
            outcome.is_err(),
            will_panic,
            "collect panicked unexpectedly or failed to propagate the panic"
        );

        let inserts = inserts.get_mut().unwrap();
        let drops = drops.get_mut().unwrap();
        println!("{:?}", inserts);
        println!("{:?}", drops);

        assert_eq!(inserts.len(), drops.len(), "Incorrect number of drops");
        // Without a panic, every input pair must have produced an element.
        assert!(will_panic || inserts.len() == b.len());
        // sort to normalize order
        inserts.sort();
        drops.sort();
        assert_eq!(inserts, drops, "Incorrect elements were dropped");
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
62 
/// Zero-sized-element counterpart of `collect_drop_on_unwind`.
///
/// A ZST carries no payload to identify it, so creations and drops are
/// tallied in the `INSERTS` / `DROPS` counters, which are reset before each
/// run. The counts must match whether or not the mapping closure panics, and
/// without a panic every input must have produced an element.
#[test]
fn collect_drop_on_unwind_zst() {
    static INSERTS: AtomicUsize = AtomicUsize::new(0);
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    /// Zero-sized element that increments `DROPS` when dropped.
    struct RecorddropZst;

    impl Drop for RecorddropZst {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        INSERTS.store(0, Ordering::SeqCst);
        DROPS.store(0, Ordering::SeqCst);

        let test_vec_len = 1024;
        // When `will_panic` is set, every input greater than this panics.
        let panic_point = 740;

        let a = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter()
                .map(|&a| {
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    INSERTS.fetch_add(1, Ordering::SeqCst);
                    RecorddropZst
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // `catch_unwind` would otherwise swallow a failing assertion above (or an
        // unexpected panic), so require that we panicked exactly when asked to.
        assert_eq!(
            outcome.is_err(),
            will_panic,
            "collect panicked unexpectedly or failed to propagate the panic"
        );

        let inserts = INSERTS.load(Ordering::SeqCst);
        let drops = DROPS.load(Ordering::SeqCst);

        assert_eq!(inserts, drops, "Incorrect number of drops");
        assert!(will_panic || drops == test_vec_len)
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
112