//! An asynchronously awaitable `CancellationToken`.
//! The token allows signalling a cancellation request to one or more tasks.
3 
4 use crate::loom::sync::atomic::AtomicUsize;
5 use crate::loom::sync::Mutex;
6 use crate::sync::intrusive_double_linked_list::{LinkedList, ListNode};
7 
8 use core::future::Future;
9 use core::pin::Pin;
10 use core::ptr::NonNull;
11 use core::sync::atomic::Ordering;
12 use core::task::{Context, Poll, Waker};
13 
/// A token which can be used to signal a cancellation request to one or more
/// tasks.
///
/// Tasks can call [`CancellationToken::cancelled()`] in order to
/// obtain a Future which will be resolved when cancellation is requested.
///
/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
///
/// Clones of a token share the same cancellation state: cancelling any clone
/// cancels all of them.
///
/// # Examples
///
/// ```ignore
/// use tokio::select;
/// use tokio::scope::CancellationToken;
///
/// #[tokio::main]
/// async fn main() {
///     let token = CancellationToken::new();
///     let cloned_token = token.clone();
///
///     let join_handle = tokio::spawn(async move {
///         // Wait for either cancellation or a very long time
///         select! {
///             _ = cloned_token.cancelled() => {
///                 // The token was cancelled
///                 5
///             }
///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
///                 99
///             }
///         }
///     });
///
///     tokio::spawn(async move {
///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
///         token.cancel();
///     });
///
///     assert_eq!(5, join_handle.await.unwrap());
/// }
/// ```
pub struct CancellationToken {
    /// Pointer to the heap-allocated, reference-counted state shared by all
    /// clones of this token. Each `CancellationToken` owns one count of the
    /// state's refcount.
    inner: NonNull<CancellationTokenState>,
}
57 
// Safety: The CancellationToken is thread-safe and can be moved between threads,
// since all methods are internally synchronized: the shared state is accessed
// through an atomic state word (refcount + cancellation state) and a `Mutex`
// guarding the waiter list and child list. The raw `NonNull` field is what
// prevents these traits from being derived automatically.
unsafe impl Send for CancellationToken {}
unsafe impl Sync for CancellationToken {}
62 
/// A Future that is resolved once the corresponding [`CancellationToken`]
/// is cancelled.
///
/// Created by [`CancellationToken::cancelled`].
#[must_use = "futures do nothing unless polled"]
pub struct WaitForCancellationFuture<'a> {
    /// The CancellationToken that is associated with this WaitForCancellationFuture.
    /// Set to `None` once the future has resolved.
    cancellation_token: Option<&'a CancellationToken>,
    /// Node for waiting at the cancellation_token. While registered it is
    /// linked into the token's waiter list, which is why the future relies on
    /// being pinned while it is registered.
    wait_node: ListNode<WaitQueueEntry>,
    /// Whether this future is currently registered at the token as a waiter
    is_registered: bool,
}
74 
// Safety: Futures can be sent between threads as long as the underlying
// cancellation_token is thread-safe (Sync), which allows polling, registering
// and unregistering from a different thread.
// NOTE(review): the manual impl is presumably needed because the embedded
// intrusive `ListNode` is not automatically `Send` — confirm against
// `intrusive_double_linked_list`.
unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
79 
80 // ===== impl CancellationToken =====
81 
82 impl core::fmt::Debug for CancellationToken {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result83     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
84         f.debug_struct("CancellationToken")
85             .field("is_cancelled", &self.is_cancelled())
86             .finish()
87     }
88 }
89 
90 impl Clone for CancellationToken {
clone(&self) -> Self91     fn clone(&self) -> Self {
92         // Safety: The state inside a `CancellationToken` is always valid, since
93         // is reference counted
94         let inner = self.state();
95 
96         // Tokens are cloned by increasing their refcount
97         let current_state = inner.snapshot();
98         inner.increment_refcount(current_state);
99 
100         CancellationToken { inner: self.inner }
101     }
102 }
103 
impl Drop for CancellationToken {
    /// Releases this token's reference on the shared `CancellationTokenState`.
    ///
    /// If this was the last `CancellationToken` referring to the state and the
    /// token was derived from a parent, the state is additionally unregistered
    /// from the parent. That also releases the refcount the child held on the
    /// parent.
    fn drop(&mut self) {
        let token_state_pointer = self.inner;

        // Safety: The state inside a `CancellationToken` is always valid, since
        // it is reference counted
        let inner = unsafe { &mut *self.inner.as_ptr() };

        let mut current_state = inner.snapshot();

        // We need to save the parent pointer now, since the state might be
        // released by the next call
        let parent = inner.parent;

        // Drop our own refcount. `decrement_refcount` frees the state once no
        // token reference and no parent reference remain, so `inner` is not
        // touched after this call.
        current_state = inner.decrement_refcount(current_state);

        // If this was the last token reference, unregister from the parent
        if current_state.refcount == 0 {
            if let Some(mut parent) = parent {
                // Safety: Since we still retain a reference on the parent, it must be valid.
                let parent = unsafe { parent.as_mut() };
                parent.unregister_child(token_state_pointer, current_state);
            }
        }
    }
}
131 
132 impl Default for CancellationToken {
default() -> CancellationToken133     fn default() -> CancellationToken {
134         CancellationToken::new()
135     }
136 }
137 
138 impl CancellationToken {
139     /// Creates a new CancellationToken in the non-cancelled state.
new() -> CancellationToken140     pub fn new() -> CancellationToken {
141         let state = Box::new(CancellationTokenState::new(
142             None,
143             StateSnapshot {
144                 cancel_state: CancellationState::NotCancelled,
145                 has_parent_ref: false,
146                 refcount: 1,
147             },
148         ));
149 
150         // Safety: We just created the Box. The pointer is guaranteed to be
151         // not null
152         CancellationToken {
153             inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) },
154         }
155     }
156 
157     /// Returns a reference to the utilized `CancellationTokenState`.
state(&self) -> &CancellationTokenState158     fn state(&self) -> &CancellationTokenState {
159         // Safety: The state inside a `CancellationToken` is always valid, since
160         // is reference counted
161         unsafe { &*self.inner.as_ptr() }
162     }
163 
164     /// Creates a `CancellationToken` which will get cancelled whenever the
165     /// current token gets cancelled.
166     ///
167     /// If the current token is already cancelled, the child token will get
168     /// returned in cancelled state.
169     ///
170     /// # Examples
171     ///
172     /// ```ignore
173     /// use tokio::select;
174     /// use tokio::scope::CancellationToken;
175     ///
176     /// #[tokio::main]
177     /// async fn main() {
178     ///     let token = CancellationToken::new();
179     ///     let child_token = token.child_token();
180     ///
181     ///     let join_handle = tokio::spawn(async move {
182     ///         // Wait for either cancellation or a very long time
183     ///         select! {
184     ///             _ = child_token.cancelled() => {
185     ///                 // The token was cancelled
186     ///                 5
187     ///             }
188     ///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
189     ///                 99
190     ///             }
191     ///         }
192     ///     });
193     ///
194     ///     tokio::spawn(async move {
195     ///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
196     ///         token.cancel();
197     ///     });
198     ///
199     ///     assert_eq!(5, join_handle.await.unwrap());
200     /// }
201     /// ```
child_token(&self) -> CancellationToken202     pub fn child_token(&self) -> CancellationToken {
203         let inner = self.state();
204 
205         // Increment the refcount of this token. It will be referenced by the
206         // child, independent of whether the child is immediately cancelled or
207         // not.
208         let _current_state = inner.increment_refcount(inner.snapshot());
209 
210         let mut unpacked_child_state = StateSnapshot {
211             has_parent_ref: true,
212             refcount: 1,
213             cancel_state: CancellationState::NotCancelled,
214         };
215         let mut child_token_state = Box::new(CancellationTokenState::new(
216             Some(self.inner),
217             unpacked_child_state,
218         ));
219 
220         {
221             let mut guard = inner.synchronized.lock().unwrap();
222             if guard.is_cancelled {
223                 // This task was already cancelled. In this case we should not
224                 // insert the child into the list, since it would never get removed
225                 // from the list.
226                 (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true;
227                 unpacked_child_state.cancel_state = CancellationState::Cancelled;
228                 // Since it's not in the list, the parent doesn't need to retain
229                 // a reference to it.
230                 unpacked_child_state.has_parent_ref = false;
231                 child_token_state
232                     .state
233                     .store(unpacked_child_state.pack(), Ordering::SeqCst);
234             } else {
235                 if let Some(mut first_child) = guard.first_child {
236                     child_token_state.from_parent.next_peer = Some(first_child);
237                     // Safety: We manipulate other child task inside the Mutex
238                     // and retain a parent reference on it. The child token can't
239                     // get invalidated while the Mutex is held.
240                     unsafe {
241                         first_child.as_mut().from_parent.prev_peer =
242                             Some((&mut *child_token_state).into())
243                     };
244                 }
245                 guard.first_child = Some((&mut *child_token_state).into());
246             }
247         };
248 
249         let child_token_ptr = Box::into_raw(child_token_state);
250         // Safety: We just created the pointer from a `Box`
251         CancellationToken {
252             inner: unsafe { NonNull::new_unchecked(child_token_ptr) },
253         }
254     }
255 
256     /// Cancel the [`CancellationToken`] and all child tokens which had been
257     /// derived from it.
258     ///
259     /// This will wake up all tasks which are waiting for cancellation.
cancel(&self)260     pub fn cancel(&self) {
261         self.state().cancel();
262     }
263 
264     /// Returns `true` if the `CancellationToken` had been cancelled
is_cancelled(&self) -> bool265     pub fn is_cancelled(&self) -> bool {
266         self.state().is_cancelled()
267     }
268 
269     /// Returns a `Future` that gets fulfilled when cancellation is requested.
cancelled(&self) -> WaitForCancellationFuture<'_>270     pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
271         WaitForCancellationFuture {
272             cancellation_token: Some(self),
273             wait_node: ListNode::new(WaitQueueEntry::new()),
274             is_registered: false,
275         }
276     }
277 
register( &self, wait_node: &mut ListNode<WaitQueueEntry>, cx: &mut Context<'_>, ) -> Poll<()>278     unsafe fn register(
279         &self,
280         wait_node: &mut ListNode<WaitQueueEntry>,
281         cx: &mut Context<'_>,
282     ) -> Poll<()> {
283         self.state().register(wait_node, cx)
284     }
285 
check_for_cancellation( &self, wait_node: &mut ListNode<WaitQueueEntry>, cx: &mut Context<'_>, ) -> Poll<()>286     fn check_for_cancellation(
287         &self,
288         wait_node: &mut ListNode<WaitQueueEntry>,
289         cx: &mut Context<'_>,
290     ) -> Poll<()> {
291         self.state().check_for_cancellation(wait_node, cx)
292     }
293 
unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>)294     fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
295         self.state().unregister(wait_node)
296     }
297 }
298 
299 // ===== impl WaitForCancellationFuture =====
300 
301 impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result302     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
303         f.debug_struct("WaitForCancellationFuture").finish()
304     }
305 }
306 
307 impl<'a> Future for WaitForCancellationFuture<'a> {
308     type Output = ();
309 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>310     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
311         // Safety: We do not move anything out of `WaitForCancellationFuture`
312         let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) };
313 
314         let cancellation_token = mut_self
315             .cancellation_token
316             .expect("polled WaitForCancellationFuture after completion");
317 
318         let poll_res = if !mut_self.is_registered {
319             // Safety: The `ListNode` is pinned through the Future,
320             // and we will unregister it in `WaitForCancellationFuture::drop`
321             // before the Future is dropped and the memory reference is invalidated.
322             unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) }
323         } else {
324             cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx)
325         };
326 
327         if let Poll::Ready(()) = poll_res {
328             // The cancellation_token was signalled
329             mut_self.cancellation_token = None;
330             // A signalled Token means the Waker won't be enqueued anymore
331             mut_self.is_registered = false;
332             mut_self.wait_node.task = None;
333         } else {
334             // This `Future` and its stored `Waker` stay registered at the
335             // `CancellationToken`
336             mut_self.is_registered = true;
337         }
338 
339         poll_res
340     }
341 }
342 
343 impl<'a> Drop for WaitForCancellationFuture<'a> {
drop(&mut self)344     fn drop(&mut self) {
345         // If this WaitForCancellationFuture has been polled and it was added to the
346         // wait queue at the cancellation_token, it must be removed before dropping.
347         // Otherwise the cancellation_token would access invalid memory.
348         if let Some(token) = self.cancellation_token {
349             if self.is_registered {
350                 token.unregister(&mut self.wait_node);
351             }
352         }
353     }
354 }
355 
/// Tracks how the future has interacted with the [`CancellationToken`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum PollState {
    /// No interaction with the [`CancellationToken`] yet.
    New,
    /// Enqueued in the wait queue of the [`CancellationToken`].
    Waiting,
    /// Polled to completion.
    Done,
}

/// Per-future waiting state of a `WaitForCancellationFuture`.
/// Access is synchronized through the mutex in the CancellationToken.
struct WaitQueueEntry {
    /// Waker of the waiting task, if one has been stored.
    task: Option<Waker>,
    /// Current polling state; only updated while the CancellationToken's
    /// Mutex is held.
    state: PollState,
}

impl WaitQueueEntry {
    /// Returns an entry in `PollState::New` that holds no `Waker`.
    fn new() -> WaitQueueEntry {
        Self {
            state: PollState::New,
            task: None,
        }
    }
}
386 
387 struct SynchronizedState {
388     waiters: LinkedList<WaitQueueEntry>,
389     first_child: Option<NonNull<CancellationTokenState>>,
390     is_cancelled: bool,
391 }
392 
393 impl SynchronizedState {
new() -> Self394     fn new() -> Self {
395         Self {
396             waiters: LinkedList::new(),
397             first_child: None,
398             is_cancelled: false,
399         }
400     }
401 }
402 
/// Information embedded in child tokens which is synchronized through the Mutex
/// in their parent.
///
/// Together with the parent's `SynchronizedState::first_child`, these pointers
/// form the doubly linked list of the parent's children.
struct SynchronizedThroughParent {
    /// Next sibling in the parent's child list (`None` for the last child).
    next_peer: Option<NonNull<CancellationTokenState>>,
    /// Previous sibling in the parent's child list (`None` for the head).
    prev_peer: Option<NonNull<CancellationTokenState>>,
}
409 
/// Possible states of a `CancellationToken`.
///
/// The discriminants are exactly the values stored in the two lowest bits of
/// the packed atomic state word.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CancellationState {
    NotCancelled = 0,
    Cancelling = 1,
    Cancelled = 2,
}

impl CancellationState {
    /// Encodes the state as its two-bit discriminant.
    fn pack(self) -> usize {
        match self {
            CancellationState::NotCancelled => 0,
            CancellationState::Cancelling => 1,
            CancellationState::Cancelled => 2,
        }
    }

    /// Decodes a value produced by [`CancellationState::pack`].
    ///
    /// Panics (via `unreachable!`) for any other value.
    fn unpack(value: usize) -> Self {
        match value {
            0 => Self::NotCancelled,
            1 => Self::Cancelling,
            2 => Self::Cancelled,
            _ => unreachable!("Invalid value"),
        }
    }
}

/// Decoded form of the atomic state word of a `CancellationTokenState`.
///
/// Packed layout: bits 0-1 hold the `CancellationState`, bit 2 holds the
/// `has_parent_ref` flag, and all higher bits hold the refcount.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct StateSnapshot {
    /// The number of references to this particular CancellationToken state.
    /// Each `CancellationToken` handle holds one of these references, and
    /// every child token's state holds one on its parent.
    refcount: usize,
    /// Whether the state is still referenced by its parent's child list and
    /// therefore must not be freed yet.
    has_parent_ref: bool,
    /// Whether (and how far) the token has been cancelled.
    cancel_state: CancellationState,
}

impl StateSnapshot {
    /// Packs the snapshot into a `usize`.
    fn pack(self) -> usize {
        let parent_flag = if self.has_parent_ref { 0b100 } else { 0 };
        (self.refcount << 3) | parent_flag | self.cancel_state.pack()
    }

    /// Unpacks a snapshot from a `usize` produced by [`StateSnapshot::pack`].
    fn unpack(value: usize) -> Self {
        Self {
            refcount: value >> 3,
            has_parent_ref: (value & 0b100) != 0,
            cancel_state: CancellationState::unpack(value & 0b011),
        }
    }

    /// Whether anything still references this `CancellationTokenState`:
    /// either a token handle or the parent's child list.
    fn has_refs(&self) -> bool {
        self.has_parent_ref || self.refcount > 0
    }
}
471 
/// The maximum permitted number of references to a CancellationToken.
///
/// The packed `StateSnapshot` keeps its flags in the lowest 3 bits. Limiting
/// the refcount to the remaining 29 bits keeps the packed value within 32 bits.
const MAX_REFS: u32 = std::u32::MAX >> 3;
475 
/// Heap-allocated internal state shared by all clones of a `CancellationToken`.
///
/// It is freed once neither a `CancellationToken` handle nor the parent's
/// child list references it (see `StateSnapshot::has_refs`).
struct CancellationTokenState {
    /// Packed `StateSnapshot`: refcount, parent-reference flag and
    /// cancellation state.
    state: AtomicUsize,
    /// State of the token this one was derived from via `child_token`, if
    /// any. The child holds one refcount on it.
    parent: Option<NonNull<CancellationTokenState>>,
    /// Sibling links in the parent's child list; synchronized through the
    /// parent's Mutex.
    from_parent: SynchronizedThroughParent,
    /// Waiters, children and the cancellation flag, guarded by this Mutex.
    synchronized: Mutex<SynchronizedState>,
}
483 
484 impl CancellationTokenState {
new( parent: Option<NonNull<CancellationTokenState>>, state: StateSnapshot, ) -> CancellationTokenState485     fn new(
486         parent: Option<NonNull<CancellationTokenState>>,
487         state: StateSnapshot,
488     ) -> CancellationTokenState {
489         CancellationTokenState {
490             parent,
491             from_parent: SynchronizedThroughParent {
492                 prev_peer: None,
493                 next_peer: None,
494             },
495             state: AtomicUsize::new(state.pack()),
496             synchronized: Mutex::new(SynchronizedState::new()),
497         }
498     }
499 
500     /// Returns a snapshot of the current atomic state of the token
snapshot(&self) -> StateSnapshot501     fn snapshot(&self) -> StateSnapshot {
502         StateSnapshot::unpack(self.state.load(Ordering::SeqCst))
503     }
504 
atomic_update_state<F>(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot where F: Fn(StateSnapshot) -> StateSnapshot,505     fn atomic_update_state<F>(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot
506     where
507         F: Fn(StateSnapshot) -> StateSnapshot,
508     {
509         let mut current_packed_state = current_state.pack();
510         loop {
511             let next_state = func(current_state);
512             match self.state.compare_exchange(
513                 current_packed_state,
514                 next_state.pack(),
515                 Ordering::SeqCst,
516                 Ordering::SeqCst,
517             ) {
518                 Ok(_) => {
519                     return next_state;
520                 }
521                 Err(actual) => {
522                     current_packed_state = actual;
523                     current_state = StateSnapshot::unpack(actual);
524                 }
525             }
526         }
527     }
528 
increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot529     fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
530         self.atomic_update_state(current_state, |mut state: StateSnapshot| {
531             if state.refcount >= MAX_REFS as usize {
532                 eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded");
533                 std::process::abort();
534             }
535             state.refcount += 1;
536             state
537         })
538     }
539 
    /// Atomically removes one reference and returns the resulting snapshot.
    ///
    /// If afterwards neither a token reference nor the parent reference
    /// remains (`!has_refs()`), the state is freed here. The caller must then
    /// not access `self` again.
    // NOTE(review): this deallocates the memory behind `&self` while that
    // reference is still a live function argument. Under the Stacked Borrows
    // model checked by Miri this is likely reported as undefined behavior;
    // operating on a raw pointer instead of `&self` would avoid it — confirm
    // with Miri.
    fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
        let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
            state.refcount -= 1;
            state
        });

        // Drop the State if it is not referenced anymore
        if !current_state.has_refs() {
            // Safety: `CancellationTokenState` is always stored in refcounted
            // Boxes
            let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
        }

        current_state
    }
555 
    /// Clears the `has_parent_ref` flag, meaning the parent's child list no
    /// longer references this state. Returns the resulting snapshot.
    ///
    /// If no token reference remains either, the state is freed here. The
    /// caller must then not access `self` again.
    // NOTE(review): like `decrement_refcount`, this frees the memory behind a
    // live `&self`; likely flagged by Miri under Stacked Borrows — confirm.
    fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot {
        let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
            state.has_parent_ref = false;
            state
        });

        // Drop the State if it is not referenced anymore
        if !current_state.has_refs() {
            // Safety: `CancellationTokenState` is always stored in refcounted
            // Boxes
            let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
        }

        current_state
    }
571 
    /// Unregisters a child from the parent token (`self`).
    ///
    /// Called when the last `CancellationToken` handle of the child is
    /// dropped. `current_child_state` is the child's snapshot taken right
    /// after that final decrement.
    ///
    /// The child token's state is not exactly known at this point in time:
    /// - If the parent token is cancelled, the parent has already detached the
    ///   child from its list, and the child might therefore already have been
    ///   freed.
    /// - If the parent token is not cancelled, the child is still in the list
    ///   and valid.
    ///
    /// Finally, this releases the reference the child held on the parent,
    /// which may free `self`.
    fn unregister_child(
        &mut self,
        mut child_state: NonNull<CancellationTokenState>,
        current_child_state: StateSnapshot,
    ) {
        let removed_child = {
            // Remove the child token from the parent's linked list
            let mut guard = self.synchronized.lock().unwrap();
            if !guard.is_cancelled {
                // Safety: Since the token was not cancelled, the child must
                // still be in the list and valid.
                let mut child_state = unsafe { child_state.as_mut() };
                debug_assert!(child_state.snapshot().has_parent_ref);

                if guard.first_child == Some(child_state.into()) {
                    guard.first_child = child_state.from_parent.next_peer;
                }
                // Safety: If peers weren't valid anymore, they would try
                // to remove themselves from the list. This would require locking
                // the Mutex that we currently own.
                unsafe {
                    if let Some(mut prev_peer) = child_state.from_parent.prev_peer {
                        prev_peer.as_mut().from_parent.next_peer =
                            child_state.from_parent.next_peer;
                    }
                    if let Some(mut next_peer) = child_state.from_parent.next_peer {
                        next_peer.as_mut().from_parent.prev_peer =
                            child_state.from_parent.prev_peer;
                    }
                }
                child_state.from_parent.prev_peer = None;
                child_state.from_parent.next_peer = None;

                // The child is no longer referenced by the parent, since we were able
                // to remove its reference from the parent's list.
                true
            } else {
                // Do not touch the linked list anymore. If the parent is cancelled
                // it will move all children outside of the Mutex and manipulate
                // the pointers there. Manipulating the pointers here too could
                // lead to races. Therefore leave them as is and let the
                // parent deal with it. The parent will make sure to retain a
                // reference to this state as long as it manipulates the list
                // pointers. Therefore the pointers are not dangling.
                false
            }
        };

        if removed_child {
            // If the token removed itself from the parent's list, it can reset
            // the parent ref status. If it isn't able to do so, because the
            // parent removed it from the list, there is no need to do this.
            // The parent ref acts as another reference count. Therefore
            // removing this reference can free the object.
            // Safety: The token was in the list. This means the parent wasn't
            // cancelled before, and the token must still be alive.
            unsafe { child_state.as_mut().remove_parent_ref(current_child_state) };
        }

        // Decrement the refcount on the parent and free it if necessary
        self.decrement_refcount(self.snapshot());
    }
639 
    /// Cancels this token: wakes all registered waiters and recursively
    /// cancels every child token.
    ///
    /// Only the caller that moves the state from `NotCancelled` to
    /// `Cancelling` performs this work. Any other caller returns immediately,
    /// possibly before the first caller has finished; the state is `Cancelled`
    /// only once the whole procedure completes.
    fn cancel(&self) {
        // Move the state of the CancellationToken from `NotCancelled` to `Cancelling`
        let mut current_state = self.snapshot();

        let state_after_cancellation = loop {
            if current_state.cancel_state != CancellationState::NotCancelled {
                // Another task already initiated the cancellation
                return;
            }

            let mut next_state = current_state;
            next_state.cancel_state = CancellationState::Cancelling;
            match self.state.compare_exchange(
                current_state.pack(),
                next_state.pack(),
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break next_state,
                Err(actual) => current_state = StateSnapshot::unpack(actual),
            }
        };

        // This task cancelled the token

        // Take the task list out of the Token
        // We do not want to cancel child token inside this lock. If one of the
        // child tasks would have additional child tokens, we would recursively
        // take locks.

        // Doing this action has an impact if the child token is dropped concurrently:
        // It will try to deregister itself from the parent task, but can not find
        // itself in the task list anymore. Therefore it needs to assume the parent
        // has extracted the list and will process it. It may not modify the list.
        // This is OK from a memory safety perspective, since the parent still
        // retains a reference to the child task until it finished iterating over
        // it.

        let mut first_child = {
            let mut guard = self.synchronized.lock().unwrap();
            // Save the cancellation also inside the Mutex
            // This allows child tokens which want to detach themselves to detect
            // that this is no longer required since the parent cleared the list.
            guard.is_cancelled = true;

            // Wakeup all waiters
            // This happens inside the lock to make cancellation reliable
            // If we would access waiters outside of the lock, the pointers
            // may no longer be valid.
            // Typically this shouldn't be an issue, since waking a task should
            // only move it from the blocked into the ready state and not have
            // further side effects.

            // Use a reverse iterator, so that the oldest waiter gets
            // scheduled first
            guard.waiters.reverse_drain(|waiter| {
                // We are not allowed to move the `Waker` out of the list node.
                // The `Future` relies on the fact that the old `Waker` stays there
                // as long as the `Future` has not completed in order to perform
                // the `will_wake()` check.
                // Therefore `wake_by_ref` is used instead of `wake()`
                if let Some(handle) = &mut waiter.task {
                    handle.wake_by_ref();
                }
                // Mark the waiter to have been removed from the list.
                waiter.state = PollState::Done;
            });

            guard.first_child.take()
        };

        while let Some(mut child) = first_child {
            // Safety: We know this is a valid pointer since it is in our child pointer
            // list. It can't have been freed in between, since we retain a reference
            // to each child.
            let mut_child = unsafe { child.as_mut() };

            // Get the next child and clean up list pointers
            first_child = mut_child.from_parent.next_peer;
            mut_child.from_parent.prev_peer = None;
            mut_child.from_parent.next_peer = None;

            // Cancel the child task
            mut_child.cancel();

            // Drop the parent reference. This `CancellationToken` is not interested
            // in interacting with the child anymore.
            // This is ONLY allowed once we promised not to touch the state anymore
            // after this interaction, since it may free the child.
            mut_child.remove_parent_ref(mut_child.snapshot());
        }

        // The cancellation has completed
        // At this point in time tasks which registered a wait node can be sure
        // that this wait node already had been dequeued from the list without
        // needing to inspect the list.
        self.atomic_update_state(state_after_cancellation, |mut state| {
            state.cancel_state = CancellationState::Cancelled;
            state
        });
    }
741 
742     /// Returns `true` if the `CancellationToken` had been cancelled
is_cancelled(&self) -> bool743     fn is_cancelled(&self) -> bool {
744         let current_state = self.snapshot();
745         current_state.cancel_state != CancellationState::NotCancelled
746     }
747 
748     /// Registers a waiting task at the `CancellationToken`.
749     /// Safety: This method is only safe as long as the waiting waiting task
750     /// will properly unregister the wait node before it gets moved.
register( &self, wait_node: &mut ListNode<WaitQueueEntry>, cx: &mut Context<'_>, ) -> Poll<()>751     unsafe fn register(
752         &self,
753         wait_node: &mut ListNode<WaitQueueEntry>,
754         cx: &mut Context<'_>,
755     ) -> Poll<()> {
756         debug_assert_eq!(PollState::New, wait_node.state);
757         let current_state = self.snapshot();
758 
759         // Perform an optimistic cancellation check before. This is not strictly
760         // necessary since we also check for cancellation in the Mutex, but
761         // reduces the necessary work to be performed for tasks which already
762         // had been cancelled.
763         if current_state.cancel_state != CancellationState::NotCancelled {
764             return Poll::Ready(());
765         }
766 
767         // So far the token is not cancelled. However it could be cancelld before
768         // we get the chance to store the `Waker`. Therfore we need to check
769         // for cancellation again inside the mutex.
770         let mut guard = self.synchronized.lock().unwrap();
771         if guard.is_cancelled {
772             // Cancellation was signalled
773             wait_node.state = PollState::Done;
774             Poll::Ready(())
775         } else {
776             // Added the task to the wait queue
777             wait_node.task = Some(cx.waker().clone());
778             wait_node.state = PollState::Waiting;
779             guard.waiters.add_front(wait_node);
780             Poll::Pending
781         }
782     }
783 
check_for_cancellation( &self, wait_node: &mut ListNode<WaitQueueEntry>, cx: &mut Context<'_>, ) -> Poll<()>784     fn check_for_cancellation(
785         &self,
786         wait_node: &mut ListNode<WaitQueueEntry>,
787         cx: &mut Context<'_>,
788     ) -> Poll<()> {
789         debug_assert!(
790             wait_node.task.is_some(),
791             "Method can only be called after task had been registered"
792         );
793 
794         let current_state = self.snapshot();
795 
796         if current_state.cancel_state != CancellationState::NotCancelled {
797             // If the cancellation had been fully completed we know that our `Waker`
798             // is no longer registered at the `CancellationToken`.
799             // Otherwise the cancel call may or may not yet have iterated
800             // through the waiters list and removed the wait nodes.
801             // If it hasn't yet, we need to remove it. Otherwise an attempt to
802             // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture`
803             // getting dropped before the cancellation had interacted with it.
804             if current_state.cancel_state != CancellationState::Cancelled {
805                 self.unregister(wait_node);
806             }
807             Poll::Ready(())
808         } else {
809             // Check if we need to swap the `Waker`. This will make the check more
810             // expensive, since the `Waker` is synchronized through the Mutex.
811             // If we don't need to perform a `Waker` update, an atomic check for
812             // cancellation is sufficient.
813             let need_waker_update = wait_node
814                 .task
815                 .as_ref()
816                 .map(|waker| waker.will_wake(cx.waker()))
817                 .unwrap_or(true);
818 
819             if need_waker_update {
820                 let guard = self.synchronized.lock().unwrap();
821                 if guard.is_cancelled {
822                     // Cancellation was signalled. Since this cancellation signal
823                     // is set inside the Mutex, the old waiter must already have
824                     // been removed from the waiting list
825                     debug_assert_eq!(PollState::Done, wait_node.state);
826                     wait_node.task = None;
827                     Poll::Ready(())
828                 } else {
829                     // The WaitForCancellationFuture is already in the queue.
830                     // The CancellationToken can't have been cancelled,
831                     // since this would change the is_cancelled flag inside the mutex.
832                     // Therefore we just have to update the Waker. A follow-up
833                     // cancellation will always use the new waker.
834                     wait_node.task = Some(cx.waker().clone());
835                     Poll::Pending
836                 }
837             } else {
838                 // Do nothing. If the token gets cancelled, this task will get
839                 // woken again and can fetch the cancellation.
840                 Poll::Pending
841             }
842         }
843     }
844 
unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>)845     fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
846         debug_assert!(
847             wait_node.task.is_some(),
848             "waiter can not be active without task"
849         );
850 
851         let mut guard = self.synchronized.lock().unwrap();
852         // WaitForCancellationFuture only needs to get removed if it has been added to
853         // the wait queue of the CancellationToken.
854         // This has happened in the PollState::Waiting case.
855         if let PollState::Waiting = wait_node.state {
856             // Safety: Due to the state, we know that the node must be part
857             // of the waiter list
858             if !unsafe { guard.waiters.remove(wait_node) } {
859                 // Panic if the address isn't found. This can only happen if the contract was
860                 // violated, e.g. the WaitQueueEntry got moved after the initial poll.
861                 panic!("Future could not be removed from wait queue");
862             }
863             wait_node.state = PollState::Done;
864         }
865         wait_node.task = None;
866     }
867 }
868