use crate::runtime::task::core::{Cell, Core, Header, Trailer};
use crate::runtime::task::state::Snapshot;
use crate::runtime::task::{JoinError, Notified, Schedule, Task};

use std::future::Future;
use std::mem;
use std::panic;
use std::ptr::NonNull;
use std::task::{Poll, Waker};

/// Typed raw task handle
pub(super) struct Harness<T: Future, S: 'static> {
    cell: NonNull<Cell<T, S>>,
}

impl<T, S> Harness<T, S>
where
    T: Future,
    S: 'static,
{
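    /// Constructs a typed `Harness` from an untyped task `Header` pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must point to the header of a task allocated as a `Cell<T, S>`,
    /// since the pointer is cast directly to that type.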
    pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> {
        Harness {
            cell: ptr.cast::<Cell<T, S>>(),
        }
    }

    fn header(&self) -> &Header {
        unsafe { &self.cell.as_ref().header }
    }

    fn trailer(&self) -> &Trailer {
        unsafe { &self.cell.as_ref().trailer }
    }

    fn core(&self) -> &Core<T, S> {
        unsafe { &self.cell.as_ref().core }
    }
}

impl<T, S> Harness<T, S>
where
    T: Future,
    S: Schedule,
{
    /// Polls the inner future.
    ///
    /// All necessary state checks and transitions are performed.
    ///
    /// Panics raised while polling the future are handled.
    pub(super) fn poll(self) {
        // If this is the first time the task is polled, the task will be bound
        // to the scheduler, in which case the task ref count must be
        // incremented.
        let is_not_bound = !self.core().is_bound();

        // Transition the task to the running state.
        //
        // A failure to transition here indicates the task has been cancelled
        // while in the run queue pending execution.
        let snapshot = match self.header().state.transition_to_running(is_not_bound) {
            Ok(snapshot) => snapshot,
            Err(_) => {
                // The task was shutdown while in the run queue. At this point,
                // we just hold a ref counted reference. Drop it here.
                self.drop_reference();
                return;
            }
        };

        if is_not_bound {
            // Ensure the task is bound to a scheduler instance. Since this is
            // the first time polling the task, a scheduler instance is pulled
            // from the local context and assigned to the task.
            //
            // The scheduler maintains ownership of the task and responds to
            // `wake` calls.
            //
            // The task reference count has been incremented.
            //
            // Safety: We have unique access to the task, so we can safely
            // call `bind_scheduler`.
            self.core().bind_scheduler(self.to_task());
        }

        // The transition to `Running` done above ensures that a lock on the
        // future has been obtained. This also ensures the `*mut T` pointer
        // contains the future (as opposed to the output) and is initialized.

        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            struct Guard<'a, T: Future, S: Schedule> {
                core: &'a Core<T, S>,
            }

            impl<T: Future, S: Schedule> Drop for Guard<'_, T, S> {
                fn drop(&mut self) {
                    self.core.drop_future_or_output();
                }
            }

            let guard = Guard { core: self.core() };

            // If the task is cancelled, avoid polling it, instead signalling it
            // is complete.
            if snapshot.is_cancelled() {
                Poll::Ready(Err(JoinError::cancelled2()))
            } else {
                let res = guard.core.poll(self.header());

                // prevent the guard from dropping the future
                mem::forget(guard);

                res.map(Ok)
            }
        }));

        match res {
            Ok(Poll::Ready(out)) => {
                self.complete(out, snapshot.is_join_interested());
            }
            Ok(Poll::Pending) => {
                match self.header().state.transition_to_idle() {
                    Ok(snapshot) => {
                        if snapshot.is_notified() {
                            // Signal yield
                            self.core().yield_now(Notified(self.to_task()));
                            // The ref-count was incremented as part of
                            // `transition_to_idle`.
                            self.drop_reference();
                        }
                    }
                    Err(_) => self.cancel_task(),
                }
            }
            Err(err) => {
                self.complete(Err(JoinError::panic2(err)), snapshot.is_join_interested());
            }
        }
    }

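    /// Deallocates the task's memory. Called once the reference count has
    /// dropped to zero and the task will no longer be accessed.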
    pub(super) fn dealloc(self) {
        // Release the join waker, if there is one.
        self.trailer().waker.with_mut(|_| ());

        // Check causality
        self.core().stage.with_mut(|_| {});
        self.core().scheduler.with_mut(|_| {});

        unsafe {
            drop(Box::from_raw(self.cell.as_ptr()));
        }
    }

    // ===== join handle =====

    /// Read the task output into `dst`.
    pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) {
        // Load a snapshot of the current task state
        let snapshot = self.header().state.load();

        debug_assert!(snapshot.is_join_interested());

        if !snapshot.is_complete() {
            // The waker must be stored in the task struct.
            let res = if snapshot.has_join_waker() {
                // There already is a waker stored in the struct. If it matches
                // the provided waker, then there is no further work to do.
                // Otherwise, the waker must be swapped.
                let will_wake = unsafe {
                    // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE`
                    // may mutate the `waker` field.
                    self.trailer()
                        .waker
                        .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
                };

                if will_wake {
                    // The task is not complete **and** the waker is up to date,
                    // there is nothing further that needs to be done.
                    return;
                }

                // Unset the `JOIN_WAKER` to gain mutable access to the `waker`
                // field, then update the field with the new join waker.
                //
                // This requires two atomic operations, unsetting the bit and
                // then resetting it. If the task transitions to complete
                // concurrently to either one of those operations, then setting
                // the join waker fails and we proceed to reading the task
                // output.
                self.header()
                    .state
                    .unset_waker()
                    .and_then(|snapshot| self.set_join_waker(waker.clone(), snapshot))
            } else {
                self.set_join_waker(waker.clone(), snapshot)
            };

            match res {
                Ok(_) => return,
                Err(snapshot) => {
                    assert!(snapshot.is_complete());
                }
            }
        }

        *dst = Poll::Ready(self.core().take_output());
    }

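    /// Stores `waker` in the trailer and sets the `JOIN_WAKER` bit.
    ///
    /// Returns `Err` with the observed snapshot if the task completed
    /// concurrently, in which case the stored waker is cleared again.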
    fn set_join_waker(&self, waker: Waker, snapshot: Snapshot) -> Result<Snapshot, Snapshot> {
        assert!(snapshot.is_join_interested());
        assert!(!snapshot.has_join_waker());

        // Safety: Only the `JoinHandle` may set the `waker` field. When
        // `JOIN_INTEREST` is **not** set, nothing else will touch the field.
        unsafe {
            self.trailer().waker.with_mut(|ptr| {
                *ptr = Some(waker);
            });
        }

        // Update the `JoinWaker` state accordingly
        let res = self.header().state.set_join_waker();

        // If the state could not be updated, then clear the join waker
        if res.is_err() {
            unsafe {
                self.trailer().waker.with_mut(|ptr| {
                    *ptr = None;
                });
            }
        }

        res
    }

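    /// Slow path for dropping the `JoinHandle`: clears join interest, drops
    /// the output if the task already completed, and releases the handle's
    /// reference to the task.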
    pub(super) fn drop_join_handle_slow(self) {
        // Try to unset `JOIN_INTEREST`. This must be done as a first step in
        // case the task concurrently completed.
        if self.header().state.unset_join_interested().is_err() {
            // It is our responsibility to drop the output. This is critical as
            // the task output may not be `Send` and as such must remain with
            // the scheduler or `JoinHandle`. i.e. if the output remains in the
            // task structure until the task is deallocated, it may be dropped
            // by a Waker on any arbitrary thread.
            self.core().drop_future_or_output();
        }

        // Drop the `JoinHandle` reference, possibly deallocating the task
        self.drop_reference();
    }

    // ===== waker behavior =====

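    /// Wakes the task, consuming the waker's reference to it.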
    pub(super) fn wake_by_val(self) {
        self.wake_by_ref();
        self.drop_reference();
    }

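    /// Wakes the task without consuming a reference. The task is only
    /// scheduled if the transition to the notified state succeeds.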
    pub(super) fn wake_by_ref(&self) {
        if self.header().state.transition_to_notified() {
            self.core().schedule(Notified(self.to_task()));
        }
    }

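    /// Drops one reference to the task, deallocating it if this was the last
    /// reference.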
    pub(super) fn drop_reference(self) {
        if self.header().state.ref_dec() {
            self.dealloc();
        }
    }

    /// Forcibly shutdown the task
    ///
    /// Attempt to transition to `Running` in order to forcibly shutdown the
    /// task. If the task is currently running or in a state of completion, then
    /// there is nothing further to do. When the task completes running, it will
    /// notice the `CANCELLED` bit and finalize the task.
    pub(super) fn shutdown(self) {
        if !self.header().state.transition_to_shutdown() {
            // The task is concurrently running. No further work needed.
            return;
        }

        // By transitioning the lifecycle to `Running`, we have permission to
        // drop the future.
        self.cancel_task();
    }

    // ====== internal ======

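    /// Drops the future (or its output) and completes the join handle with a
    /// cancellation error, or with the panic if dropping the future panicked.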
    fn cancel_task(self) {
        // Drop the future from a panic guard.
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            self.core().drop_future_or_output();
        }));

        if let Err(err) = res {
            // Dropping the future panicked, complete the join
            // handle with the panic to avoid dropping the panic
            // on the ground.
            self.complete(Err(JoinError::panic2(err)), true);
        } else {
            self.complete(Err(JoinError::cancelled2()), true);
        }
    }

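    /// Completes the task: stores the output for the `JoinHandle` if it is
    /// still interested, releases the task from the scheduler, and transitions
    /// to the terminal state, deallocating if no references remain.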
    fn complete(mut self, output: super::Result<T::Output>, is_join_interested: bool) {
        if is_join_interested {
            // Store the output. The future has already been dropped
            //
            // Safety: Mutual exclusion is obtained by having transitioned the task
            // state -> Running
            self.core().store_output(output);

            // Transition to `Complete`, notifying the `JoinHandle` if necessary.
            self.transition_to_complete();
        }

        // The task has completed execution and will no longer be scheduled.
        //
        // Attempts to batch a ref-dec with the state transition below.
        let ref_dec = if self.core().is_bound() {
            if let Some(task) = self.core().release(self.to_task()) {
                mem::forget(task);
                true
            } else {
                false
            }
        } else {
            false
        };

        // This might deallocate
        let snapshot = self
            .header()
            .state
            .transition_to_terminal(!is_join_interested, ref_dec);

        if snapshot.ref_count() == 0 {
            self.dealloc()
        }
    }

    /// Transitions the task's lifecycle to `Complete`. Notifies the
    /// `JoinHandle` if it still has interest in the completion.
    fn transition_to_complete(&mut self) {
        // Transition the task's lifecycle to `Complete` and get a snapshot of
        // the task's state.
        let snapshot = self.header().state.transition_to_complete();

        if !snapshot.is_join_interested() {
            // The `JoinHandle` is not interested in the output of this task. It
            // is our responsibility to drop the output.
            self.core().drop_future_or_output();
        } else if snapshot.has_join_waker() {
            // Notify the join handle. The previous transition obtains the
            // lock on the waker cell.
            self.wake_join();
        }
    }

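    /// Wakes the `JoinHandle` waker stored in the trailer.
    ///
    /// Panics if no waker is stored; callers must have observed
    /// `has_join_waker()` before calling.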
    fn wake_join(&self) {
        self.trailer().waker.with(|ptr| match unsafe { &*ptr } {
            Some(waker) => waker.wake_by_ref(),
            None => panic!("waker missing"),
        });
    }

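    /// Creates a `Task` handle from the raw task header pointer.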
    fn to_task(&self) -> Task<S> {
        unsafe { Task::from_raw(self.header().into()) }
    }
}