1 #![cfg_attr(not(feature = "full"), allow(dead_code))]
2 
3 use crate::loom::sync::atomic::AtomicUsize;
4 use crate::loom::sync::{Arc, Condvar, Mutex};
5 use crate::park::{Park, Unpark};
6 
7 use std::sync::atomic::Ordering::SeqCst;
8 use std::time::Duration;
9 
10 #[derive(Debug)]
11 pub(crate) struct ParkThread {
12     inner: Arc<Inner>,
13 }
14 
/// Error type for parking operations; it carries no payload.
pub(crate) type ParkError = ();
16 
17 /// Unblocks a thread that was blocked by `ParkThread`.
18 #[derive(Clone, Debug)]
19 pub(crate) struct UnparkThread {
20     inner: Arc<Inner>,
21 }
22 
23 #[derive(Debug)]
24 struct Inner {
25     state: AtomicUsize,
26     mutex: Mutex<()>,
27     condvar: Condvar,
28 }
29 
/// No thread is parked and no notification is pending.
const EMPTY: usize = 0;
/// A thread has announced that it is (about to be) waiting on the condvar.
const PARKED: usize = 1;
/// An `unpark` occurred that has not yet been consumed by a park call.
const NOTIFIED: usize = 2;
33 
34 thread_local! {
35     static CURRENT_PARKER: ParkThread = ParkThread::new();
36 }
37 
38 // ==== impl ParkThread ====
39 
40 impl ParkThread {
new() -> Self41     pub(crate) fn new() -> Self {
42         Self {
43             inner: Arc::new(Inner {
44                 state: AtomicUsize::new(EMPTY),
45                 mutex: Mutex::new(()),
46                 condvar: Condvar::new(),
47             }),
48         }
49     }
50 }
51 
52 impl Park for ParkThread {
53     type Unpark = UnparkThread;
54     type Error = ParkError;
55 
unpark(&self) -> Self::Unpark56     fn unpark(&self) -> Self::Unpark {
57         let inner = self.inner.clone();
58         UnparkThread { inner }
59     }
60 
park(&mut self) -> Result<(), Self::Error>61     fn park(&mut self) -> Result<(), Self::Error> {
62         self.inner.park();
63         Ok(())
64     }
65 
park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>66     fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
67         self.inner.park_timeout(duration);
68         Ok(())
69     }
70 
shutdown(&mut self)71     fn shutdown(&mut self) {
72         self.inner.shutdown();
73     }
74 }
75 
76 // ==== impl Inner ====
77 
78 impl Inner {
79     /// Park the current thread for at most `dur`.
park(&self)80     fn park(&self) {
81         // If we were previously notified then we consume this notification and
82         // return quickly.
83         if self
84             .state
85             .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
86             .is_ok()
87         {
88             return;
89         }
90 
91         // Otherwise we need to coordinate going to sleep
92         let mut m = self.mutex.lock();
93 
94         match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
95             Ok(_) => {}
96             Err(NOTIFIED) => {
97                 // We must read here, even though we know it will be `NOTIFIED`.
98                 // This is because `unpark` may have been called again since we read
99                 // `NOTIFIED` in the `compare_exchange` above. We must perform an
100                 // acquire operation that synchronizes with that `unpark` to observe
101                 // any writes it made before the call to unpark. To do that we must
102                 // read from the write it made to `state`.
103                 let old = self.state.swap(EMPTY, SeqCst);
104                 debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");
105 
106                 return;
107             }
108             Err(actual) => panic!("inconsistent park state; actual = {}", actual),
109         }
110 
111         loop {
112             m = self.condvar.wait(m).unwrap();
113 
114             if self
115                 .state
116                 .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
117                 .is_ok()
118             {
119                 // got a notification
120                 return;
121             }
122 
123             // spurious wakeup, go back to sleep
124         }
125     }
126 
park_timeout(&self, dur: Duration)127     fn park_timeout(&self, dur: Duration) {
128         // Like `park` above we have a fast path for an already-notified thread,
129         // and afterwards we start coordinating for a sleep. Return quickly.
130         if self
131             .state
132             .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
133             .is_ok()
134         {
135             return;
136         }
137 
138         if dur == Duration::from_millis(0) {
139             return;
140         }
141 
142         let m = self.mutex.lock();
143 
144         match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
145             Ok(_) => {}
146             Err(NOTIFIED) => {
147                 // We must read again here, see `park`.
148                 let old = self.state.swap(EMPTY, SeqCst);
149                 debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");
150 
151                 return;
152             }
153             Err(actual) => panic!("inconsistent park_timeout state; actual = {}", actual),
154         }
155 
156         // Wait with a timeout, and if we spuriously wake up or otherwise wake up
157         // from a notification, we just want to unconditionally set the state back to
158         // empty, either consuming a notification or un-flagging ourselves as
159         // parked.
160         let (_m, _result) = self.condvar.wait_timeout(m, dur).unwrap();
161 
162         match self.state.swap(EMPTY, SeqCst) {
163             NOTIFIED => {} // got a notification, hurray!
164             PARKED => {}   // no notification, alas
165             n => panic!("inconsistent park_timeout state: {}", n),
166         }
167     }
168 
unpark(&self)169     fn unpark(&self) {
170         // To ensure the unparked thread will observe any writes we made before
171         // this call, we must perform a release operation that `park` can
172         // synchronize with. To do that we must write `NOTIFIED` even if `state`
173         // is already `NOTIFIED`. That is why this must be a swap rather than a
174         // compare-and-swap that returns if it reads `NOTIFIED` on failure.
175         match self.state.swap(NOTIFIED, SeqCst) {
176             EMPTY => return,    // no one was waiting
177             NOTIFIED => return, // already unparked
178             PARKED => {}        // gotta go wake someone up
179             _ => panic!("inconsistent state in unpark"),
180         }
181 
182         // There is a period between when the parked thread sets `state` to
183         // `PARKED` (or last checked `state` in the case of a spurious wake
184         // up) and when it actually waits on `cvar`. If we were to notify
185         // during this period it would be ignored and then when the parked
186         // thread went to sleep it would never wake up. Fortunately, it has
187         // `lock` locked at this stage so we can acquire `lock` to wait until
188         // it is ready to receive the notification.
189         //
190         // Releasing `lock` before the call to `notify_one` means that when the
191         // parked thread wakes it doesn't get woken only to have to wait for us
192         // to release `lock`.
193         drop(self.mutex.lock());
194 
195         self.condvar.notify_one()
196     }
197 
shutdown(&self)198     fn shutdown(&self) {
199         self.condvar.notify_all();
200     }
201 }
202 
203 impl Default for ParkThread {
default() -> Self204     fn default() -> Self {
205         Self::new()
206     }
207 }
208 
209 // ===== impl UnparkThread =====
210 
211 impl Unpark for UnparkThread {
unpark(&self)212     fn unpark(&self) {
213         self.inner.unpark();
214     }
215 }
216 
217 use std::future::Future;
218 use std::marker::PhantomData;
219 use std::mem;
220 use std::rc::Rc;
221 use std::task::{RawWaker, RawWakerVTable, Waker};
222 
/// Blocks the current thread using a condition variable.
///
/// The `PhantomData<Rc<()>>` anchor makes this type neither `Send` nor
/// `Sync`, so a handle stays on the thread that created it.
#[derive(Debug)]
pub(crate) struct CachedParkThread {
    _anchor: PhantomData<Rc<()>>,
}
228 
229 impl CachedParkThread {
230     /// Create a new `ParkThread` handle for the current thread.
231     ///
232     /// This type cannot be moved to other threads, so it should be created on
233     /// the thread that the caller intends to park.
new() -> CachedParkThread234     pub(crate) fn new() -> CachedParkThread {
235         CachedParkThread {
236             _anchor: PhantomData,
237         }
238     }
239 
get_unpark(&self) -> Result<UnparkThread, ParkError>240     pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> {
241         self.with_current(|park_thread| park_thread.unpark())
242     }
243 
244     /// Get a reference to the `ParkThread` handle for this thread.
with_current<F, R>(&self, f: F) -> Result<R, ParkError> where F: FnOnce(&ParkThread) -> R,245     fn with_current<F, R>(&self, f: F) -> Result<R, ParkError>
246     where
247         F: FnOnce(&ParkThread) -> R,
248     {
249         CURRENT_PARKER.try_with(|inner| f(inner)).map_err(|_| ())
250     }
251 
block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError>252     pub(crate) fn block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError> {
253         use std::task::Context;
254         use std::task::Poll::Ready;
255 
256         // `get_unpark()` should not return a Result
257         let waker = self.get_unpark()?.into_waker();
258         let mut cx = Context::from_waker(&waker);
259 
260         pin!(f);
261 
262         loop {
263             if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
264                 return Ok(v);
265             }
266 
267             self.park()?;
268         }
269     }
270 }
271 
272 impl Park for CachedParkThread {
273     type Unpark = UnparkThread;
274     type Error = ParkError;
275 
unpark(&self) -> Self::Unpark276     fn unpark(&self) -> Self::Unpark {
277         self.get_unpark().unwrap()
278     }
279 
park(&mut self) -> Result<(), Self::Error>280     fn park(&mut self) -> Result<(), Self::Error> {
281         self.with_current(|park_thread| park_thread.inner.park())?;
282         Ok(())
283     }
284 
park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>285     fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
286         self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?;
287         Ok(())
288     }
289 
shutdown(&mut self)290     fn shutdown(&mut self) {
291         let _ = self.with_current(|park_thread| park_thread.inner.shutdown());
292     }
293 }
294 
295 impl UnparkThread {
into_waker(self) -> Waker296     pub(crate) fn into_waker(self) -> Waker {
297         unsafe {
298             let raw = unparker_to_raw_waker(self.inner);
299             Waker::from_raw(raw)
300         }
301     }
302 }
303 
304 impl Inner {
305     #[allow(clippy::wrong_self_convention)]
into_raw(this: Arc<Inner>) -> *const ()306     fn into_raw(this: Arc<Inner>) -> *const () {
307         Arc::into_raw(this) as *const ()
308     }
309 
from_raw(ptr: *const ()) -> Arc<Inner>310     unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
311         Arc::from_raw(ptr as *const Inner)
312     }
313 }
314 
unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker315 unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
316     RawWaker::new(
317         Inner::into_raw(unparker),
318         &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
319     )
320 }
321 
clone(raw: *const ()) -> RawWaker322 unsafe fn clone(raw: *const ()) -> RawWaker {
323     let unparker = Inner::from_raw(raw);
324 
325     // Increment the ref count
326     mem::forget(unparker.clone());
327 
328     unparker_to_raw_waker(unparker)
329 }
330 
drop_waker(raw: *const ())331 unsafe fn drop_waker(raw: *const ()) {
332     let _ = Inner::from_raw(raw);
333 }
334 
wake(raw: *const ())335 unsafe fn wake(raw: *const ()) {
336     let unparker = Inner::from_raw(raw);
337     unparker.unpark();
338 }
339 
wake_by_ref(raw: *const ())340 unsafe fn wake_by_ref(raw: *const ()) {
341     let unparker = Inner::from_raw(raw);
342     unparker.unpark();
343 
344     // We don't actually own a reference to the unparker
345     mem::forget(unparker);
346 }
347