use crate::future::Future;
use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Scheduler, Trailer};
use crate::runtime::task::state::Snapshot;
use crate::runtime::task::waker::waker_ref;
use crate::runtime::task::{JoinError, Notified, Schedule, Task};

use std::mem;
use std::panic;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};

/// Typed raw task handle
pub(super) struct Harness<T: Future, S: 'static> {
    cell: NonNull<Cell<T, S>>,
}

impl<T, S> Harness<T, S>
where
    T: Future,
    S: 'static,
{
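    /// Creates a typed harness from a type-erased raw task header pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must point at the `Header` of a live `Cell<T, S>` allocation
    /// with matching `T` and `S`, as the pointer is cast back to the full
    /// cell type.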
    pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Harness<T, S> {
        Harness {
            cell: ptr.cast::<Cell<T, S>>(),
        }
    }

    fn header(&self) -> &Header {
        unsafe { &self.cell.as_ref().header }
    }

    fn trailer(&self) -> &Trailer {
        unsafe { &self.cell.as_ref().trailer }
    }

    fn core(&self) -> &Core<T, S> {
        unsafe { &self.cell.as_ref().core }
    }

    fn scheduler_view(&self) -> SchedulerView<'_, S> {
        SchedulerView {
            header: self.header(),
            scheduler: &self.core().scheduler,
        }
    }
}

impl<T, S> Harness<T, S>
where
    T: Future,
    S: Schedule,
{
    /// Polls the inner future.
    ///
    /// All necessary state checks and transitions are performed.
    ///
    /// Panics raised while polling the future are handled.
    pub(super) fn poll(self) {
        match self.poll_inner() {
            PollFuture::Notified => {
                // Signal yield
                self.core().scheduler.yield_now(Notified(self.to_task()));
                // The ref-count was incremented as part of
                // `transition_to_idle`.
                self.drop_reference();
            }
            PollFuture::DropReference => {
                self.drop_reference();
            }
            PollFuture::Complete(out, is_join_interested) => {
                self.complete(out, is_join_interested);
            }
            PollFuture::None => (),
        }
    }

    fn poll_inner(&self) -> PollFuture<T::Output> {
        let snapshot = match self.scheduler_view().transition_to_running() {
            TransitionToRunning::Ok(snapshot) => snapshot,
            TransitionToRunning::DropReference => return PollFuture::DropReference,
        };

        // The transition to `Running` done above ensures that a lock on the
        // future has been obtained. This also ensures the `*mut T` pointer
        // contains the future (as opposed to the output) and is initialized.

        let waker_ref = waker_ref::<T, S>(self.header());
        let cx = Context::from_waker(&*waker_ref);
        poll_future(self.header(), &self.core().stage, snapshot, cx)
    }

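    /// Deallocates the task cell.
    ///
    /// The caller must hold the final reference to the task; the cell is
    /// boxed back up and dropped here.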
    pub(super) fn dealloc(self) {
        // Release the join waker, if there is one.
        self.trailer().waker.with_mut(drop);

        // Check causality
        self.core().stage.with_mut(drop);
        self.core().scheduler.with_mut(drop);

        unsafe {
            drop(Box::from_raw(self.cell.as_ptr()));
        }
    }

    // ===== join handle =====

    /// Read the task output into `dst`.
    pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) {
        if can_read_output(self.header(), self.trailer(), waker) {
            *dst = Poll::Ready(self.core().stage.take_output());
        }
    }

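    /// Releases the `JoinHandle`'s interest in the task.
    ///
    /// Unsets `JOIN_INTEREST`; if the task has already completed, the output
    /// is dropped here. Finally the `JoinHandle`'s reference is released.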
    pub(super) fn drop_join_handle_slow(self) {
        // Try to unset `JOIN_INTEREST`. This must be done as a first step in
        // case the task concurrently completed.
        if self.header().state.unset_join_interested().is_err() {
            // It is our responsibility to drop the output. This is critical as
            // the task output may not be `Send` and as such must remain with
            // the scheduler or `JoinHandle`. That is, if the output remained in
            // the task structure until the task is deallocated, it could be
            // dropped by a `Waker` on an arbitrary thread.
            self.core().stage.drop_future_or_output();
        }

        // Drop the `JoinHandle` reference, possibly deallocating the task
        self.drop_reference();
    }

    // ===== waker behavior =====

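    /// Wakes the task by value, consuming the waker's reference to the task.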
    pub(super) fn wake_by_val(self) {
        self.wake_by_ref();
        self.drop_reference();
    }

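    /// Wakes the task without consuming a reference: if the transition to
    /// `Notified` succeeds, the task is handed to the scheduler.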
    pub(super) fn wake_by_ref(&self) {
        if self.header().state.transition_to_notified() {
            self.core().scheduler.schedule(Notified(self.to_task()));
        }
    }

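    /// Drops one reference to the task, deallocating it if this was the last
    /// one.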
    pub(super) fn drop_reference(self) {
        if self.header().state.ref_dec() {
            self.dealloc();
        }
    }

    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) fn id(&self) -> Option<&tracing::Id> {
        self.header().id.as_ref()
    }

    /// Forcibly shut down the task.
    ///
    /// Attempts to transition to `Running` in order to forcibly shut down the
    /// task. If the task is currently running or has already completed, there
    /// is nothing further to do. When the task finishes running, it will
    /// notice the `CANCELLED` bit and finalize itself.
    pub(super) fn shutdown(self) {
        if !self.header().state.transition_to_shutdown() {
            // The task is concurrently running. No further work needed.
            return;
        }

        // By transitioning the lifecycle to `Running`, we have permission to
        // drop the future.
        let err = cancel_task(&self.core().stage);
        self.complete(Err(err), true)
    }

    /// Remotely abort the task.
    ///
    /// This is similar to `shutdown` except that it asks the runtime to
    /// perform the shutdown. This is necessary to avoid the shutdown happening
    /// on the wrong thread for non-Send tasks.
    pub(super) fn remote_abort(self) {
        if self.header().state.transition_to_notified_and_cancel() {
            self.core().scheduler.schedule(Notified(self.to_task()));
        }
    }

    // ====== internal ======

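    /// Completes the task: stores the output and notifies the `JoinHandle`
    /// if it is still interested, then transitions the task to its terminal
    /// state, deallocating it if the last reference was just released.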
    fn complete(self, output: super::Result<T::Output>, is_join_interested: bool) {
        if is_join_interested {
            // Store the output. The future has already been dropped
            //
            // Safety: Mutual exclusion is obtained by having transitioned the task
            // state -> Running
            let stage = &self.core().stage;
            stage.store_output(output);

            // Transition to `Complete`, notifying the `JoinHandle` if necessary.
            transition_to_complete(self.header(), stage, &self.trailer());
        }

        // The task has completed execution and will no longer be scheduled.
        //
        // Attempts to batch a ref-dec with the state transition below.

        if self
            .scheduler_view()
            .transition_to_terminal(is_join_interested)
        {
            self.dealloc()
        }
    }

    fn to_task(&self) -> Task<S> {
        self.scheduler_view().to_task()
    }
}

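/// Result of attempting to transition the task to `Running`: either the
/// transition succeeded and polling may proceed, or the task was cancelled
/// while queued and the caller should drop its reference.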
enum TransitionToRunning {
    Ok(Snapshot),
    DropReference,
}

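/// Borrowed view over the parts of the task needed for scheduler-related
/// state transitions: the header and the scheduler field.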
struct SchedulerView<'a, S> {
    header: &'a Header,
    scheduler: &'a Scheduler<S>,
}

impl<'a, S> SchedulerView<'a, S>
where
    S: Schedule,
{
    fn to_task(&self) -> Task<S> {
        // Safety: the header is from the same struct containing the scheduler
        // `S`, so the cast is safe.
        unsafe { Task::from_raw(self.header.into()) }
    }

    /// Returns true if the task should be deallocated.
    fn transition_to_terminal(&self, is_join_interested: bool) -> bool {
        let ref_dec = if self.scheduler.is_bound() {
            if let Some(task) = self.scheduler.release(self.to_task()) {
                mem::forget(task);
                true
            } else {
                false
            }
        } else {
            false
        };

        // This might deallocate
        let snapshot = self
            .header
            .state
            .transition_to_terminal(!is_join_interested, ref_dec);

        snapshot.ref_count() == 0
    }

    fn transition_to_running(&self) -> TransitionToRunning {
        // If this is the first time the task is polled, the task will be bound
        // to the scheduler, in which case the task ref count must be
        // incremented.
        let is_not_bound = !self.scheduler.is_bound();

        // Transition the task to the running state.
        //
        // A failure to transition here indicates the task has been cancelled
        // while in the run queue pending execution.
        let snapshot = match self.header.state.transition_to_running(is_not_bound) {
            Ok(snapshot) => snapshot,
            Err(_) => {
                // The task was shut down while in the run queue. At this point,
                // we only hold a ref-counted reference to the task, which we do
                // not have access to here, so return `DropReference` and let
                // the caller drop it.
                return TransitionToRunning::DropReference;
            }
        };

        if is_not_bound {
            // Ensure the task is bound to a scheduler instance. Since this is
            // the first time polling the task, a scheduler instance is pulled
            // from the local context and assigned to the task.
            //
            // The scheduler maintains ownership of the task and responds to
            // `wake` calls.
            //
            // The task reference count has been incremented.
            //
            // Safety: we have unique access to the task, so we can safely
            // call `bind_scheduler`.
            self.scheduler.bind_scheduler(self.to_task());
        }
        TransitionToRunning::Ok(snapshot)
    }
}

/// Transitions the task's lifecycle to `Complete`. Notifies the
/// `JoinHandle` if it still has interest in the completion.
fn transition_to_complete<T>(header: &Header, stage: &CoreStage<T>, trailer: &Trailer)
where
    T: Future,
{
    // Transition the task's lifecycle to `Complete` and get a snapshot of
    // the task's state.
    let snapshot = header.state.transition_to_complete();

    if !snapshot.is_join_interested() {
        // The `JoinHandle` is not interested in the output of this task. It
        // is our responsibility to drop the output.
        stage.drop_future_or_output();
    } else if snapshot.has_join_waker() {
        // Notify the join handle. The previous transition obtains the
        // lock on the waker cell.
        trailer.wake_join();
    }
}

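/// Returns true if the `JoinHandle` may read the task output.
///
/// If the task is not yet complete, the provided waker is stored (or swapped
/// in) as the join waker and `false` is returned.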
fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {
    // Load a snapshot of the current task state
    let snapshot = header.state.load();

    debug_assert!(snapshot.is_join_interested());

    if !snapshot.is_complete() {
        // The waker must be stored in the task struct.
        let res = if snapshot.has_join_waker() {
            // There already is a waker stored in the struct. If it matches
            // the provided waker, then there is no further work to do.
            // Otherwise, the waker must be swapped.
            let will_wake = unsafe {
                // Safety: when `JOIN_INTEREST` is set, only `JOIN_HANDLE`
                // may mutate the `waker` field.
                trailer.will_wake(waker)
            };

            if will_wake {
                // The task is not complete **and** the waker is up to date;
                // there is nothing further that needs to be done.
                return false;
            }

            // Unset the `JOIN_WAKER` bit to gain mutable access to the `waker`
            // field, then update the field with the new join waker.
            //
            // This requires two atomic operations, unsetting the bit and
            // then resetting it. If the task transitions to complete
            // concurrently to either one of those operations, then setting
            // the join waker fails and we proceed to reading the task
            // output.
            header
                .state
                .unset_waker()
                .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot))
        } else {
            set_join_waker(header, trailer, waker.clone(), snapshot)
        };

        match res {
            Ok(_) => return false,
            Err(snapshot) => {
                assert!(snapshot.is_complete());
            }
        }
    }
    true
}

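/// Stores `waker` in the trailer and sets the `JOIN_WAKER` bit.
///
/// If the task completes concurrently and the bit cannot be set, the waker is
/// cleared again and the error snapshot is returned.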
fn set_join_waker(
    header: &Header,
    trailer: &Trailer,
    waker: Waker,
    snapshot: Snapshot,
) -> Result<Snapshot, Snapshot> {
    assert!(snapshot.is_join_interested());
    assert!(!snapshot.has_join_waker());

    // Safety: Only the `JoinHandle` may set the `waker` field. When
    // `JOIN_INTEREST` is **not** set, nothing else will touch the field.
    unsafe {
        trailer.set_waker(Some(waker));
    }

    // Update the `JoinWaker` state accordingly
    let res = header.state.set_join_waker();

    // If the state could not be updated, then clear the join waker
    if res.is_err() {
        unsafe {
            trailer.set_waker(None);
        }
    }

    res
}

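/// The follow-up action the caller must take after polling the future.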
enum PollFuture<T> {
    Complete(Result<T, JoinError>, bool),
    DropReference,
    Notified,
    None,
}

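/// Drops the stored future (or output) and produces the `JoinError` the task
/// should complete with: `cancelled`, or `panic` if the drop itself panicked.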
fn cancel_task<T: Future>(stage: &CoreStage<T>) -> JoinError {
    // Drop the future, catching any panic raised while dropping it.
    let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        stage.drop_future_or_output();
    }));

    if let Err(err) = res {
        // Dropping the future panicked, complete the join
        // handle with the panic to avoid dropping the panic
        // on the ground.
        JoinError::panic(err)
    } else {
        JoinError::cancelled()
    }
}

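/// Polls the future once, handling cancellation and panics.
///
/// Returns the follow-up action for the caller: complete the task (with its
/// output, a cancellation error, or a panic error), yield a notified task
/// back to the scheduler, or do nothing.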
fn poll_future<T: Future>(
    header: &Header,
    core: &CoreStage<T>,
    snapshot: Snapshot,
    cx: Context<'_>,
) -> PollFuture<T::Output> {
    if snapshot.is_cancelled() {
        PollFuture::Complete(Err(cancel_task(core)), snapshot.is_join_interested())
    } else {
        let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            struct Guard<'a, T: Future> {
                core: &'a CoreStage<T>,
            }

            impl<T: Future> Drop for Guard<'_, T> {
                fn drop(&mut self) {
                    self.core.drop_future_or_output();
                }
            }

            // If polling the future panics, this guard drops the future so it
            // is not left behind in the task.
            let guard = Guard { core };

            let res = guard.core.poll(cx);

            // Prevent the guard from dropping the future
            mem::forget(guard);

            res
        }));
        match res {
            Ok(Poll::Pending) => match header.state.transition_to_idle() {
                Ok(snapshot) => {
                    if snapshot.is_notified() {
                        PollFuture::Notified
                    } else {
                        PollFuture::None
                    }
                }
                Err(_) => PollFuture::Complete(Err(cancel_task(core)), true),
            },
            Ok(Poll::Ready(ok)) => PollFuture::Complete(Ok(ok), snapshot.is_join_interested()),
            Err(err) => {
                PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested())
            }
        }
    }
}