1 use crate::task::AtomicWaker; 2 use alloc::sync::Arc; 3 use core::fmt; 4 use core::pin::Pin; 5 use core::sync::atomic::{AtomicBool, Ordering}; 6 use futures_core::future::Future; 7 use futures_core::task::{Context, Poll}; 8 use futures_core::Stream; 9 use pin_project_lite::pin_project; 10 11 pin_project! { 12 /// A future/stream which can be remotely short-circuited using an `AbortHandle`. 13 #[derive(Debug, Clone)] 14 #[must_use = "futures/streams do nothing unless you poll them"] 15 pub struct Abortable<T> { 16 #[pin] 17 task: T, 18 inner: Arc<AbortInner>, 19 } 20 } 21 22 impl<T> Abortable<T> { 23 /// Creates a new `Abortable` future/stream using an existing `AbortRegistration`. 24 /// `AbortRegistration`s can be acquired through `AbortHandle::new`. 25 /// 26 /// When `abort` is called on the handle tied to `reg` or if `abort` has 27 /// already been called, the future/stream will complete immediately without making 28 /// any further progress. 29 /// 30 /// # Examples: 31 /// 32 /// Usage with futures: 33 /// 34 /// ``` 35 /// # futures::executor::block_on(async { 36 /// use futures::future::{Abortable, AbortHandle, Aborted}; 37 /// 38 /// let (abort_handle, abort_registration) = AbortHandle::new_pair(); 39 /// let future = Abortable::new(async { 2 }, abort_registration); 40 /// abort_handle.abort(); 41 /// assert_eq!(future.await, Err(Aborted)); 42 /// # }); 43 /// ``` 44 /// 45 /// Usage with streams: 46 /// 47 /// ``` 48 /// # futures::executor::block_on(async { 49 /// # use futures::future::{Abortable, AbortHandle}; 50 /// # use futures::stream::{self, StreamExt}; 51 /// 52 /// let (abort_handle, abort_registration) = AbortHandle::new_pair(); 53 /// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration); 54 /// abort_handle.abort(); 55 /// assert_eq!(stream.next().await, None); 56 /// # }); 57 /// ``` new(task: T, reg: AbortRegistration) -> Self58 pub fn new(task: T, reg: AbortRegistration) -> Self { 59 Self { task, inner: reg.inner } 60 } 61 62 /// Checks whether the task has been aborted. Note that all this 63 /// method indicates is whether [`AbortHandle::abort`] was *called*. 64 /// This means that it will return `true` even if: 65 /// * `abort` was called after the task had completed. 66 /// * `abort` was called while the task was being polled - the task may still be running and 67 /// will not be stopped until `poll` returns. is_aborted(&self) -> bool68 pub fn is_aborted(&self) -> bool { 69 self.inner.aborted.load(Ordering::Relaxed) 70 } 71 } 72 73 /// A registration handle for an `Abortable` task. 74 /// Values of this type can be acquired from `AbortHandle::new` and are used 75 /// in calls to `Abortable::new`. 76 #[derive(Debug)] 77 pub struct AbortRegistration { 78 inner: Arc<AbortInner>, 79 } 80 81 /// A handle to an `Abortable` task. 82 #[derive(Debug, Clone)] 83 pub struct AbortHandle { 84 inner: Arc<AbortInner>, 85 } 86 87 impl AbortHandle { 88 /// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used 89 /// to abort a running future or stream. 90 /// 91 /// This function is usually paired with a call to [`Abortable::new`]. new_pair() -> (Self, AbortRegistration)92 pub fn new_pair() -> (Self, AbortRegistration) { 93 let inner = 94 Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) }); 95 96 (Self { inner: inner.clone() }, AbortRegistration { inner }) 97 } 98 } 99 100 // Inner type storing the waker to awaken and a bool indicating that it 101 // should be aborted. 102 #[derive(Debug)] 103 struct AbortInner { 104 waker: AtomicWaker, 105 aborted: AtomicBool, 106 } 107 108 /// Indicator that the `Abortable` task was aborted. 109 #[derive(Copy, Clone, Debug, Eq, PartialEq)] 110 pub struct Aborted; 111 112 impl fmt::Display for Aborted { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 114 write!(f, "`Abortable` future has been aborted") 115 } 116 } 117 118 #[cfg(feature = "std")] 119 impl std::error::Error for Aborted {} 120 121 impl<T> Abortable<T> { try_poll<I>( mut self: Pin<&mut Self>, cx: &mut Context<'_>, poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>, ) -> Poll<Result<I, Aborted>>122 fn try_poll<I>( 123 mut self: Pin<&mut Self>, 124 cx: &mut Context<'_>, 125 poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>, 126 ) -> Poll<Result<I, Aborted>> { 127 // Check if the task has been aborted 128 if self.is_aborted() { 129 return Poll::Ready(Err(Aborted)); 130 } 131 132 // attempt to complete the task 133 if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) { 134 return Poll::Ready(Ok(x)); 135 } 136 137 // Register to receive a wakeup if the task is aborted in the future 138 self.inner.waker.register(cx.waker()); 139 140 // Check to see if the task was aborted between the first check and 141 // registration. 142 // Checking with `is_aborted` which uses `Relaxed` is sufficient because 143 // `register` introduces an `AcqRel` barrier. 144 if self.is_aborted() { 145 return Poll::Ready(Err(Aborted)); 146 } 147 148 Poll::Pending 149 } 150 } 151 152 impl<Fut> Future for Abortable<Fut> 153 where 154 Fut: Future, 155 { 156 type Output = Result<Fut::Output, Aborted>; 157 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>158 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 159 self.try_poll(cx, |fut, cx| fut.poll(cx)) 160 } 161 } 162 163 impl<St> Stream for Abortable<St> 164 where 165 St: Stream, 166 { 167 type Item = St::Item; 168 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>169 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 170 self.try_poll(cx, |stream, cx| stream.poll_next(cx)).map(Result::ok).map(Option::flatten) 171 } 172 } 173 174 impl AbortHandle { 175 /// Abort the `Abortable` stream/future associated with this handle. 176 /// 177 /// Notifies the Abortable task associated with this handle that it 178 /// should abort. Note that if the task is currently being polled on 179 /// another thread, it will not immediately stop running. Instead, it will 180 /// continue to run until its poll method returns. abort(&self)181 pub fn abort(&self) { 182 self.inner.aborted.store(true, Ordering::Relaxed); 183 self.inner.waker.wake(); 184 } 185 } 186