1 use crate::loom::sync::atomic::AtomicUsize;
2 use crate::loom::sync::{Arc, Condvar, Mutex};
3 use crate::park::{Park, Unpark};
4 
5 use std::sync::atomic::Ordering::SeqCst;
6 use std::time::Duration;
7 
/// Parks the calling thread on a condition variable until it is unparked.
///
/// The parking state lives behind an `Arc` that is shared with every
/// `UnparkThread` handle obtained from `Park::unpark`.
#[derive(Debug)]
pub(crate) struct ParkThread {
    inner: Arc<Inner>,
}
12 
/// Error type for the parkers in this module.
///
/// It is `()`, so it carries no detail. `ParkThread` never produces it;
/// `CachedParkThread` returns it when the thread-local parker cannot be accessed.
pub(crate) type ParkError = ();
14 
/// Unblocks a thread that was blocked by `ParkThread`.
///
/// All clones share the same parking state. Unparking through any of them
/// wakes the parked thread, or, if it is not currently parked, makes its next
/// park call return immediately.
#[derive(Clone, Debug)]
pub(crate) struct UnparkThread {
    inner: Arc<Inner>,
}
20 
/// Parking state shared between a `ParkThread` and its `UnparkThread` handles.
#[derive(Debug)]
struct Inner {
    /// One of `EMPTY`, `PARKED` or `NOTIFIED`.
    state: AtomicUsize,
    /// Held by the parking thread while it moves to `PARKED` and while it waits
    /// on `condvar`; `unpark` briefly acquires it so its notification cannot be
    /// lost between those two steps.
    mutex: Mutex<()>,
    /// Signalled by `unpark` to wake a thread blocked in `park`/`park_timeout`.
    condvar: Condvar,
}
27 
/// No thread is parked and no notification is pending.
const EMPTY: usize = 0;
/// A thread has committed to blocking on the condition variable.
const PARKED: usize = 1;
/// `unpark` was called; the next (or current) park consumes this and returns.
const NOTIFIED: usize = 2;
31 
// The calling thread's own parker, created on first access from that thread.
// `CachedParkThread` reaches it through `LocalKey::try_with`.
thread_local! {
    static CURRENT_PARKER: ParkThread = ParkThread::new();
}
35 
36 // ==== impl ParkThread ====
37 
38 impl ParkThread {
new() -> Self39     pub(crate) fn new() -> Self {
40         Self {
41             inner: Arc::new(Inner {
42                 state: AtomicUsize::new(EMPTY),
43                 mutex: Mutex::new(()),
44                 condvar: Condvar::new(),
45             }),
46         }
47     }
48 }
49 
50 impl Park for ParkThread {
51     type Unpark = UnparkThread;
52     type Error = ParkError;
53 
unpark(&self) -> Self::Unpark54     fn unpark(&self) -> Self::Unpark {
55         let inner = self.inner.clone();
56         UnparkThread { inner }
57     }
58 
park(&mut self) -> Result<(), Self::Error>59     fn park(&mut self) -> Result<(), Self::Error> {
60         self.inner.park();
61         Ok(())
62     }
63 
park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>64     fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
65         self.inner.park_timeout(duration);
66         Ok(())
67     }
68 }
69 
70 // ==== impl Inner ====
71 
impl Inner {
    /// Blocks the current thread until a notification is received.
    ///
    /// A pending notification (`state == NOTIFIED`) is consumed and the call
    /// returns immediately. Otherwise the state is moved from `EMPTY` to
    /// `PARKED` under `mutex` and the thread waits on `condvar` until it can
    /// swap `NOTIFIED` back to `EMPTY`; spurious wakeups loop back into the wait.
    ///
    /// # Panics
    ///
    /// Panics if `state` is neither `EMPTY` nor `NOTIFIED` when parking begins,
    /// or if locking `mutex` / waiting on `condvar` returns an error.
    fn park(&self) {
        // If we were previously notified then we consume this notification and
        // return quickly.
        if self
            .state
            .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
            .is_ok()
        {
            return;
        }

        // Otherwise we need to coordinate going to sleep. The guard is held
        // across the `EMPTY -> PARKED` transition so that `unpark` (which
        // acquires `mutex` before notifying) cannot signal before we wait.
        let mut m = self.mutex.lock().unwrap();

        match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
            Ok(_) => {}
            Err(NOTIFIED) => {
                // We must read here, even though we know it will be `NOTIFIED`.
                // This is because `unpark` may have been called again since we read
                // `NOTIFIED` in the `compare_exchange` above. We must perform an
                // acquire operation that synchronizes with that `unpark` to observe
                // any writes it made before the call to unpark. To do that we must
                // read from the write it made to `state`.
                let old = self.state.swap(EMPTY, SeqCst);
                debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");

                return;
            }
            Err(actual) => panic!("inconsistent park state; actual = {}", actual),
        }

        loop {
            // `wait` releases `mutex` while blocked and re-acquires it on wakeup.
            m = self.condvar.wait(m).unwrap();

            if self
                .state
                .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
                .is_ok()
            {
                // got a notification
                return;
            }

            // spurious wakeup, go back to sleep
        }
    }

    /// Blocks the current thread until notified or until `dur` has elapsed.
    ///
    /// A pending notification is consumed and the call returns immediately;
    /// a zero `dur` returns without blocking. Unlike `park`, spurious wakeups
    /// are not retried: after a single `wait_timeout` the state is reset to
    /// `EMPTY` (consuming a notification if one arrived), so this may return
    /// before `dur` has elapsed even without a notification.
    ///
    /// # Panics
    ///
    /// Panics if `state` is in an unexpected state before or after waiting,
    /// or if locking `mutex` / waiting on `condvar` returns an error.
    fn park_timeout(&self, dur: Duration) {
        // Like `park` above we have a fast path for an already-notified thread,
        // and afterwards we start coordinating for a sleep. Return quickly.
        if self
            .state
            .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
            .is_ok()
        {
            return;
        }

        // A zero timeout never blocks; with no pending notification there is
        // nothing more to do.
        if dur == Duration::from_millis(0) {
            return;
        }

        let m = self.mutex.lock().unwrap();

        match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
            Ok(_) => {}
            Err(NOTIFIED) => {
                // We must read again here, see `park`.
                let old = self.state.swap(EMPTY, SeqCst);
                debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");

                return;
            }
            Err(actual) => panic!("inconsistent park_timeout state; actual = {}", actual),
        }

        // Wait with a timeout, and if we spuriously wake up or otherwise wake up
        // from a notification, we just want to unconditionally set the state back to
        // empty, either consuming a notification or un-flagging ourselves as
        // parked. The returned guard (`_m`) keeps `mutex` held until this
        // function returns.
        let (_m, _result) = self.condvar.wait_timeout(m, dur).unwrap();

        match self.state.swap(EMPTY, SeqCst) {
            NOTIFIED => {} // got a notification, hurray!
            PARKED => {}   // no notification, alas
            n => panic!("inconsistent park_timeout state: {}", n),
        }
    }

    /// Stores a notification and, if a thread is parked, wakes it.
    ///
    /// If no thread is parked, the notification stays in `state` so that the
    /// next `park`/`park_timeout` call consumes it and returns immediately.
    ///
    /// # Panics
    ///
    /// Panics if `state` holds an unknown value, or if locking `mutex`
    /// returns an error.
    fn unpark(&self) {
        // To ensure the unparked thread will observe any writes we made before
        // this call, we must perform a release operation that `park` can
        // synchronize with. To do that we must write `NOTIFIED` even if `state`
        // is already `NOTIFIED`. That is why this must be a swap rather than a
        // compare-and-swap that returns if it reads `NOTIFIED` on failure.
        match self.state.swap(NOTIFIED, SeqCst) {
            EMPTY => return,    // no one was waiting
            NOTIFIED => return, // already unparked
            PARKED => {}        // gotta go wake someone up
            _ => panic!("inconsistent state in unpark"),
        }

        // There is a period between when the parked thread sets `state` to
        // `PARKED` (or last checked `state` in the case of a spurious wake
        // up) and when it actually waits on `condvar`. If we were to notify
        // during this period it would be ignored and then when the parked
        // thread went to sleep it would never wake up. Fortunately, it has
        // `mutex` locked at this stage so we can acquire `mutex` to wait until
        // it is ready to receive the notification.
        //
        // Releasing `mutex` before the call to `notify_one` means that when the
        // parked thread wakes it doesn't get woken only to have to wait for us
        // to release `mutex`.
        drop(self.mutex.lock().unwrap());

        self.condvar.notify_one()
    }
}
192 
193 impl Default for ParkThread {
default() -> Self194     fn default() -> Self {
195         Self::new()
196     }
197 }
198 
199 // ===== impl UnparkThread =====
200 
201 impl Unpark for UnparkThread {
unpark(&self)202     fn unpark(&self) {
203         self.inner.unpark();
204     }
205 }
206 
207 cfg_blocking_impl! {
208     use std::marker::PhantomData;
209     use std::rc::Rc;
210 
211     use std::mem;
212     use std::task::{RawWaker, RawWakerVTable, Waker};
213 
    /// Blocks the current thread using a condition variable.
    ///
    /// This is a zero-sized handle to the calling thread's thread-local
    /// `ParkThread` (`CURRENT_PARKER`). The `PhantomData<Rc<()>>` marker makes
    /// the type neither `Send` nor `Sync`, so it stays on the thread that
    /// created it.
    #[derive(Debug)]
    pub(crate) struct CachedParkThread {
        _anchor: PhantomData<Rc<()>>,
    }
219 
220     impl CachedParkThread {
221         /// Create a new `ParkThread` handle for the current thread.
222         ///
223         /// This type cannot be moved to other threads, so it should be created on
224         /// the thread that the caller intends to park.
225         pub(crate) fn new() -> CachedParkThread {
226             CachedParkThread {
227                 _anchor: PhantomData,
228             }
229         }
230 
231         pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> {
232             self.with_current(|park_thread| park_thread.unpark())
233         }
234 
235         /// Get a reference to the `ParkThread` handle for this thread.
236         fn with_current<F, R>(&self, f: F) -> Result<R, ParkError>
237         where
238             F: FnOnce(&ParkThread) -> R,
239         {
240             CURRENT_PARKER.try_with(|inner| f(inner))
241                 .map_err(|_| ())
242         }
243     }
244 
245     impl Park for CachedParkThread {
246         type Unpark = UnparkThread;
247         type Error = ParkError;
248 
249         fn unpark(&self) -> Self::Unpark {
250             self.get_unpark().unwrap()
251         }
252 
253         fn park(&mut self) -> Result<(), Self::Error> {
254             self.with_current(|park_thread| park_thread.inner.park())?;
255             Ok(())
256         }
257 
258         fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
259             self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?;
260             Ok(())
261         }
262     }
263 
264 
265     impl UnparkThread {
266         pub(crate) fn into_waker(self) -> Waker {
267             unsafe {
268                 let raw = unparker_to_raw_waker(self.inner);
269                 Waker::from_raw(raw)
270             }
271         }
272     }
273 
274     impl Inner {
275         #[allow(clippy::wrong_self_convention)]
276         fn into_raw(this: Arc<Inner>) -> *const () {
277             Arc::into_raw(this) as *const ()
278         }
279 
280         unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
281             Arc::from_raw(ptr as *const Inner)
282         }
283     }
284 
285     unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
286         RawWaker::new(
287             Inner::into_raw(unparker),
288             &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
289         )
290     }
291 
292     unsafe fn clone(raw: *const ()) -> RawWaker {
293         let unparker = Inner::from_raw(raw);
294 
295         // Increment the ref count
296         mem::forget(unparker.clone());
297 
298         unparker_to_raw_waker(unparker)
299     }
300 
301     unsafe fn drop_waker(raw: *const ()) {
302         let _ = Inner::from_raw(raw);
303     }
304 
305     unsafe fn wake(raw: *const ()) {
306         let unparker = Inner::from_raw(raw);
307         unparker.unpark();
308     }
309 
310     unsafe fn wake_by_ref(raw: *const ()) {
311         let unparker = Inner::from_raw(raw);
312         unparker.unpark();
313 
314         // We don't actually own a reference to the unparker
315         mem::forget(unparker);
316     }
317 }
318