1 use std::fmt;
2 use std::sync::{Arc, Condvar, Mutex};
3 
4 /// Enables threads to synchronize the beginning or end of some computation.
5 ///
6 /// # Wait groups vs barriers
7 ///
8 /// `WaitGroup` is very similar to [`Barrier`], but there are a few differences:
9 ///
10 /// * `Barrier` needs to know the number of threads at construction, while `WaitGroup` is cloned to
11 ///   register more threads.
12 ///
13 /// * A `Barrier` can be reused even after all threads have synchronized, while a `WaitGroup`
14 ///   synchronizes threads only once.
15 ///
16 /// * All threads wait for others to reach the `Barrier`. With `WaitGroup`, each thread can choose
17 ///   to either wait for other threads or to continue without blocking.
18 ///
19 /// # Examples
20 ///
21 /// ```
22 /// use crossbeam_utils::sync::WaitGroup;
23 /// use std::thread;
24 ///
25 /// // Create a new wait group.
26 /// let wg = WaitGroup::new();
27 ///
28 /// for _ in 0..4 {
29 ///     // Create another reference to the wait group.
30 ///     let wg = wg.clone();
31 ///
32 ///     thread::spawn(move || {
33 ///         // Do some work.
34 ///
35 ///         // Drop the reference to the wait group.
36 ///         drop(wg);
37 ///     });
38 /// }
39 ///
40 /// // Block until all threads have finished their work.
41 /// wg.wait();
42 /// ```
43 ///
44 /// [`Barrier`]: https://doc.rust-lang.org/std/sync/struct.Barrier.html
45 pub struct WaitGroup {
46     inner: Arc<Inner>,
47 }
48 
/// State shared by every handle of a single `WaitGroup`.
struct Inner {
    /// Number of `WaitGroup` handles that have not been dropped yet.
    count: Mutex<usize>,
    /// Notified when `count` reaches zero, waking threads blocked in `WaitGroup::wait`.
    cvar: Condvar,
}
54 
55 impl WaitGroup {
56     /// Creates a new wait group and returns the single reference to it.
57     ///
58     /// # Examples
59     ///
60     /// ```
61     /// use crossbeam_utils::sync::WaitGroup;
62     ///
63     /// let wg = WaitGroup::new();
64     /// ```
new() -> WaitGroup65     pub fn new() -> WaitGroup {
66         WaitGroup {
67             inner: Arc::new(Inner {
68                 cvar: Condvar::new(),
69                 count: Mutex::new(1),
70             }),
71         }
72     }
73 
74     /// Drops this reference and waits until all other references are dropped.
75     ///
76     /// # Examples
77     ///
78     /// ```
79     /// use crossbeam_utils::sync::WaitGroup;
80     /// use std::thread;
81     ///
82     /// let wg = WaitGroup::new();
83     ///
84     /// thread::spawn({
85     ///     let wg = wg.clone();
86     ///     move || {
87     ///         // Block until both threads have reached `wait()`.
88     ///         wg.wait();
89     ///     }
90     /// });
91     ///
92     /// // Block until both threads have reached `wait()`.
93     /// wg.wait();
94     /// ```
wait(self)95     pub fn wait(self) {
96         if *self.inner.count.lock().unwrap() == 1 {
97             return;
98         }
99 
100         let inner = self.inner.clone();
101         drop(self);
102 
103         let mut count = inner.count.lock().unwrap();
104         while *count > 0 {
105             count = inner.cvar.wait(count).unwrap();
106         }
107     }
108 }
109 
110 impl Drop for WaitGroup {
drop(&mut self)111     fn drop(&mut self) {
112         let mut count = self.inner.count.lock().unwrap();
113         *count -= 1;
114 
115         if *count == 0 {
116             self.inner.cvar.notify_all();
117         }
118     }
119 }
120 
121 impl Clone for WaitGroup {
clone(&self) -> WaitGroup122     fn clone(&self) -> WaitGroup {
123         let mut count = self.inner.count.lock().unwrap();
124         *count += 1;
125 
126         WaitGroup {
127             inner: self.inner.clone(),
128         }
129     }
130 }
131 
132 impl fmt::Debug for WaitGroup {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result133     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
134         let count: &usize = &*self.inner.count.lock().unwrap();
135         f.debug_struct("WaitGroup").field("count", count).finish()
136     }
137 }
138