1 #![cfg_attr(not(feature = "full"), allow(dead_code))]
2
3 use crate::loom::sync::atomic::AtomicUsize;
4 use crate::loom::sync::{Arc, Condvar, Mutex};
5 use crate::park::{Park, Unpark};
6
7 use std::sync::atomic::Ordering::SeqCst;
8 use std::time::Duration;
9
/// Blocks the calling thread (via `Park::park` / `Park::park_timeout`) until
/// it is woken through an `UnparkThread` handle obtained from `Park::unpark`.
#[derive(Debug)]
pub(crate) struct ParkThread {
    /// Park state shared with every `UnparkThread` handle created from this parker.
    inner: Arc<Inner>,
}
14
/// Error type for parking operations. `ParkThread` never fails; the only
/// error `CachedParkThread` reports is that the thread-local parker could not
/// be accessed.
pub(crate) type ParkError = ();
16
/// Unblocks a thread that was blocked by `ParkThread`.
#[derive(Clone, Debug)]
pub(crate) struct UnparkThread {
    /// The same shared state held by the `ParkThread` this handle came from.
    inner: Arc<Inner>,
}
22
/// State shared between a `ParkThread` and its `UnparkThread` handles.
#[derive(Debug)]
struct Inner {
    /// Current park state: one of `EMPTY`, `PARKED` or `NOTIFIED`.
    state: AtomicUsize,
    /// Guards the `EMPTY -> PARKED` transition and the condvar wait; `unpark`
    /// acquires it briefly so it cannot notify before the parked thread is
    /// actually waiting.
    mutex: Mutex<()>,
    /// Signalled by `unpark` (one waiter) and `shutdown` (all waiters).
    condvar: Condvar,
}
29
/// No notification is pending and no thread is parked.
const EMPTY: usize = 0;
/// A thread has committed to blocking on `condvar` (or is already blocked).
const PARKED: usize = 1;
/// An `unpark` happened that has not yet been consumed by a park call.
const NOTIFIED: usize = 2;
33
thread_local! {
    /// The current thread's parker, created on first access; `CachedParkThread`
    /// forwards all of its operations here.
    static CURRENT_PARKER: ParkThread = ParkThread::new();
}
37
38 // ==== impl ParkThread ====
39
40 impl ParkThread {
new() -> Self41 pub(crate) fn new() -> Self {
42 Self {
43 inner: Arc::new(Inner {
44 state: AtomicUsize::new(EMPTY),
45 mutex: Mutex::new(()),
46 condvar: Condvar::new(),
47 }),
48 }
49 }
50 }
51
52 impl Park for ParkThread {
53 type Unpark = UnparkThread;
54 type Error = ParkError;
55
unpark(&self) -> Self::Unpark56 fn unpark(&self) -> Self::Unpark {
57 let inner = self.inner.clone();
58 UnparkThread { inner }
59 }
60
park(&mut self) -> Result<(), Self::Error>61 fn park(&mut self) -> Result<(), Self::Error> {
62 self.inner.park();
63 Ok(())
64 }
65
park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>66 fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
67 self.inner.park_timeout(duration);
68 Ok(())
69 }
70
shutdown(&mut self)71 fn shutdown(&mut self) {
72 self.inner.shutdown();
73 }
74 }
75
76 // ==== impl Inner ====
77
78 impl Inner {
79 /// Park the current thread for at most `dur`.
park(&self)80 fn park(&self) {
81 // If we were previously notified then we consume this notification and
82 // return quickly.
83 if self
84 .state
85 .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
86 .is_ok()
87 {
88 return;
89 }
90
91 // Otherwise we need to coordinate going to sleep
92 let mut m = self.mutex.lock();
93
94 match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
95 Ok(_) => {}
96 Err(NOTIFIED) => {
97 // We must read here, even though we know it will be `NOTIFIED`.
98 // This is because `unpark` may have been called again since we read
99 // `NOTIFIED` in the `compare_exchange` above. We must perform an
100 // acquire operation that synchronizes with that `unpark` to observe
101 // any writes it made before the call to unpark. To do that we must
102 // read from the write it made to `state`.
103 let old = self.state.swap(EMPTY, SeqCst);
104 debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");
105
106 return;
107 }
108 Err(actual) => panic!("inconsistent park state; actual = {}", actual),
109 }
110
111 loop {
112 m = self.condvar.wait(m).unwrap();
113
114 if self
115 .state
116 .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
117 .is_ok()
118 {
119 // got a notification
120 return;
121 }
122
123 // spurious wakeup, go back to sleep
124 }
125 }
126
park_timeout(&self, dur: Duration)127 fn park_timeout(&self, dur: Duration) {
128 // Like `park` above we have a fast path for an already-notified thread,
129 // and afterwards we start coordinating for a sleep. Return quickly.
130 if self
131 .state
132 .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst)
133 .is_ok()
134 {
135 return;
136 }
137
138 if dur == Duration::from_millis(0) {
139 return;
140 }
141
142 let m = self.mutex.lock();
143
144 match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
145 Ok(_) => {}
146 Err(NOTIFIED) => {
147 // We must read again here, see `park`.
148 let old = self.state.swap(EMPTY, SeqCst);
149 debug_assert_eq!(old, NOTIFIED, "park state changed unexpectedly");
150
151 return;
152 }
153 Err(actual) => panic!("inconsistent park_timeout state; actual = {}", actual),
154 }
155
156 // Wait with a timeout, and if we spuriously wake up or otherwise wake up
157 // from a notification, we just want to unconditionally set the state back to
158 // empty, either consuming a notification or un-flagging ourselves as
159 // parked.
160 let (_m, _result) = self.condvar.wait_timeout(m, dur).unwrap();
161
162 match self.state.swap(EMPTY, SeqCst) {
163 NOTIFIED => {} // got a notification, hurray!
164 PARKED => {} // no notification, alas
165 n => panic!("inconsistent park_timeout state: {}", n),
166 }
167 }
168
unpark(&self)169 fn unpark(&self) {
170 // To ensure the unparked thread will observe any writes we made before
171 // this call, we must perform a release operation that `park` can
172 // synchronize with. To do that we must write `NOTIFIED` even if `state`
173 // is already `NOTIFIED`. That is why this must be a swap rather than a
174 // compare-and-swap that returns if it reads `NOTIFIED` on failure.
175 match self.state.swap(NOTIFIED, SeqCst) {
176 EMPTY => return, // no one was waiting
177 NOTIFIED => return, // already unparked
178 PARKED => {} // gotta go wake someone up
179 _ => panic!("inconsistent state in unpark"),
180 }
181
182 // There is a period between when the parked thread sets `state` to
183 // `PARKED` (or last checked `state` in the case of a spurious wake
184 // up) and when it actually waits on `cvar`. If we were to notify
185 // during this period it would be ignored and then when the parked
186 // thread went to sleep it would never wake up. Fortunately, it has
187 // `lock` locked at this stage so we can acquire `lock` to wait until
188 // it is ready to receive the notification.
189 //
190 // Releasing `lock` before the call to `notify_one` means that when the
191 // parked thread wakes it doesn't get woken only to have to wait for us
192 // to release `lock`.
193 drop(self.mutex.lock());
194
195 self.condvar.notify_one()
196 }
197
shutdown(&self)198 fn shutdown(&self) {
199 self.condvar.notify_all();
200 }
201 }
202
203 impl Default for ParkThread {
default() -> Self204 fn default() -> Self {
205 Self::new()
206 }
207 }
208
209 // ===== impl UnparkThread =====
210
211 impl Unpark for UnparkThread {
unpark(&self)212 fn unpark(&self) {
213 self.inner.unpark();
214 }
215 }
216
217 use std::future::Future;
218 use std::marker::PhantomData;
219 use std::mem;
220 use std::rc::Rc;
221 use std::task::{RawWaker, RawWakerVTable, Waker};
222
/// Blocks the current thread using a condition variable.
///
/// All operations are forwarded to the calling thread's `CURRENT_PARKER`.
#[derive(Debug)]
pub(crate) struct CachedParkThread {
    // `Rc` is neither `Send` nor `Sync`, so this marker keeps the handle on
    // the thread that created it.
    _anchor: PhantomData<Rc<()>>,
}
228
229 impl CachedParkThread {
230 /// Create a new `ParkThread` handle for the current thread.
231 ///
232 /// This type cannot be moved to other threads, so it should be created on
233 /// the thread that the caller intends to park.
new() -> CachedParkThread234 pub(crate) fn new() -> CachedParkThread {
235 CachedParkThread {
236 _anchor: PhantomData,
237 }
238 }
239
get_unpark(&self) -> Result<UnparkThread, ParkError>240 pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> {
241 self.with_current(|park_thread| park_thread.unpark())
242 }
243
244 /// Get a reference to the `ParkThread` handle for this thread.
with_current<F, R>(&self, f: F) -> Result<R, ParkError> where F: FnOnce(&ParkThread) -> R,245 fn with_current<F, R>(&self, f: F) -> Result<R, ParkError>
246 where
247 F: FnOnce(&ParkThread) -> R,
248 {
249 CURRENT_PARKER.try_with(|inner| f(inner)).map_err(|_| ())
250 }
251
block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError>252 pub(crate) fn block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError> {
253 use std::task::Context;
254 use std::task::Poll::Ready;
255
256 // `get_unpark()` should not return a Result
257 let waker = self.get_unpark()?.into_waker();
258 let mut cx = Context::from_waker(&waker);
259
260 pin!(f);
261
262 loop {
263 if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
264 return Ok(v);
265 }
266
267 self.park()?;
268 }
269 }
270 }
271
272 impl Park for CachedParkThread {
273 type Unpark = UnparkThread;
274 type Error = ParkError;
275
unpark(&self) -> Self::Unpark276 fn unpark(&self) -> Self::Unpark {
277 self.get_unpark().unwrap()
278 }
279
park(&mut self) -> Result<(), Self::Error>280 fn park(&mut self) -> Result<(), Self::Error> {
281 self.with_current(|park_thread| park_thread.inner.park())?;
282 Ok(())
283 }
284
park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>285 fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
286 self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?;
287 Ok(())
288 }
289
shutdown(&mut self)290 fn shutdown(&mut self) {
291 let _ = self.with_current(|park_thread| park_thread.inner.shutdown());
292 }
293 }
294
295 impl UnparkThread {
into_waker(self) -> Waker296 pub(crate) fn into_waker(self) -> Waker {
297 unsafe {
298 let raw = unparker_to_raw_waker(self.inner);
299 Waker::from_raw(raw)
300 }
301 }
302 }
303
304 impl Inner {
305 #[allow(clippy::wrong_self_convention)]
into_raw(this: Arc<Inner>) -> *const ()306 fn into_raw(this: Arc<Inner>) -> *const () {
307 Arc::into_raw(this) as *const ()
308 }
309
from_raw(ptr: *const ()) -> Arc<Inner>310 unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
311 Arc::from_raw(ptr as *const Inner)
312 }
313 }
314
unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker315 unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
316 RawWaker::new(
317 Inner::into_raw(unparker),
318 &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
319 )
320 }
321
clone(raw: *const ()) -> RawWaker322 unsafe fn clone(raw: *const ()) -> RawWaker {
323 let unparker = Inner::from_raw(raw);
324
325 // Increment the ref count
326 mem::forget(unparker.clone());
327
328 unparker_to_raw_waker(unparker)
329 }
330
drop_waker(raw: *const ())331 unsafe fn drop_waker(raw: *const ()) {
332 let _ = Inner::from_raw(raw);
333 }
334
wake(raw: *const ())335 unsafe fn wake(raw: *const ()) {
336 let unparker = Inner::from_raw(raw);
337 unparker.unpark();
338 }
339
wake_by_ref(raw: *const ())340 unsafe fn wake_by_ref(raw: *const ()) {
341 let unparker = Inner::from_raw(raw);
342 unparker.unpark();
343
344 // We don't actually own a reference to the unparker
345 mem::forget(unparker);
346 }
347