1 use crate::task::{ArcWake, waker_ref};
2 use futures_core::future::{FusedFuture, Future};
3 use futures_core::task::{Context, Poll, Waker};
4 use slab::Slab;
5 use std::cell::UnsafeCell;
6 use std::fmt;
7 use std::pin::Pin;
8 use std::sync::atomic::AtomicUsize;
9 use std::sync::atomic::Ordering::SeqCst;
10 use std::sync::{Arc, Mutex};
11 
12 /// Future for the [`shared`](super::FutureExt::shared) method.
13 #[must_use = "futures do nothing unless you `.await` or poll them"]
14 pub struct Shared<Fut: Future> {
15     inner: Option<Arc<Inner<Fut>>>,
16     waker_key: usize,
17 }
18 
19 struct Inner<Fut: Future> {
20     future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
21     notifier: Arc<Notifier>,
22 }
23 
24 struct Notifier {
25     state: AtomicUsize,
26     wakers: Mutex<Option<Slab<Option<Waker>>>>,
27 }
28 
29 // The future itself is polled behind the `Arc`, so it won't be moved
30 // when `Shared` is moved.
31 impl<Fut: Future> Unpin for Shared<Fut> {}
32 
33 impl<Fut: Future> fmt::Debug for Shared<Fut> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result34     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35         f.debug_struct("Shared")
36             .field("inner", &self.inner)
37             .field("waker_key", &self.waker_key)
38             .finish()
39     }
40 }
41 
42 impl<Fut: Future> fmt::Debug for Inner<Fut> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result43     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44         f.debug_struct("Inner").finish()
45     }
46 }
47 
/// Contents of the shared cell: the future while it is still pending, then
/// the value it produced once it has completed.
enum FutureOrOutput<Fut: Future> {
    /// Not yet complete; still to be polled.
    Future(Fut),
    /// Completed; the result waiting to be handed out to the clones.
    Output(Fut::Output),
}
52 
53 unsafe impl<Fut> Send for Inner<Fut>
54 where
55     Fut: Future + Send,
56     Fut::Output: Send + Sync,
57 {}
58 
59 unsafe impl<Fut> Sync for Inner<Fut>
60 where
61     Fut: Future + Send,
62     Fut::Output: Send + Sync,
63 {}
64 
// Values stored in `Notifier::state`.

/// No task is polling the inner future.
const IDLE: usize = 0;
/// Some task is currently polling the inner future.
const POLLING: usize = 1;
/// The inner future's waker fired while it was being polled; the polling task
/// must poll it again before going back to `IDLE`.
const REPOLL: usize = 2;
/// The inner future completed and its output has been stored.
const COMPLETE: usize = 3;
/// The inner future panicked while being polled.
const POISONED: usize = 4;

/// Sentinel `waker_key` meaning "no slot registered in the waker slab".
const NULL_WAKER_KEY: usize = usize::max_value();
72 
73 impl<Fut: Future> Shared<Fut> {
new(future: Fut) -> Shared<Fut>74     pub(super) fn new(future: Fut) -> Shared<Fut> {
75         let inner = Inner {
76             future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
77             notifier: Arc::new(Notifier {
78                 state: AtomicUsize::new(IDLE),
79                 wakers: Mutex::new(Some(Slab::new())),
80             }),
81         };
82 
83         Shared {
84             inner: Some(Arc::new(inner)),
85             waker_key: NULL_WAKER_KEY,
86         }
87     }
88 }
89 
90 impl<Fut> Shared<Fut>
91 where
92     Fut: Future,
93     Fut::Output: Clone,
94 {
95     /// Returns [`Some`] containing a reference to this [`Shared`]'s output if
96     /// it has already been computed by a clone or [`None`] if it hasn't been
97     /// computed yet or this [`Shared`] already returned its output from
98     /// [`poll`](Future::poll).
peek(&self) -> Option<&Fut::Output>99     pub fn peek(&self) -> Option<&Fut::Output> {
100         if let Some(inner) = self.inner.as_ref() {
101             match inner.notifier.state.load(SeqCst) {
102                 COMPLETE => unsafe { return Some(inner.output()) },
103                 POISONED => panic!("inner future panicked during poll"),
104                 _ => {}
105             }
106         }
107         None
108     }
109 
110     /// Registers the current task to receive a wakeup when `Inner` is awoken.
set_waker(&mut self, cx: &mut Context<'_>)111     fn set_waker(&mut self, cx: &mut Context<'_>) {
112         // Acquire the lock first before checking COMPLETE to ensure there
113         // isn't a race.
114         let mut wakers_guard = if let Some(inner) = self.inner.as_ref() {
115             inner.notifier.wakers.lock().unwrap()
116         } else {
117             return;
118         };
119 
120         let wakers = if let Some(wakers) = wakers_guard.as_mut() {
121             wakers
122         } else {
123             return;
124         };
125 
126         if self.waker_key == NULL_WAKER_KEY {
127             self.waker_key = wakers.insert(Some(cx.waker().clone()));
128         } else {
129             let waker_slot = &mut wakers[self.waker_key];
130             let needs_replacement = if let Some(_old_waker) = waker_slot {
131                 // If there's still an unwoken waker in the slot, only replace
132                 // if the current one wouldn't wake the same task.
133                 // TODO: This API is currently not available, so replace always
134                 // !waker.will_wake_nonlocal(old_waker)
135                 true
136             } else {
137                 true
138             };
139             if needs_replacement {
140                 *waker_slot = Some(cx.waker().clone());
141             }
142         }
143         debug_assert!(self.waker_key != NULL_WAKER_KEY);
144     }
145 
146     /// Safety: callers must first ensure that `self.inner.state`
147     /// is `COMPLETE`
take_or_clone_output(&mut self) -> Fut::Output148     unsafe fn take_or_clone_output(&mut self) -> Fut::Output {
149         let inner = self.inner.take().unwrap();
150 
151         match Arc::try_unwrap(inner) {
152             Ok(inner) => match inner.future_or_output.into_inner() {
153                 FutureOrOutput::Output(item) => item,
154                 FutureOrOutput::Future(_) => unreachable!(),
155             },
156             Err(inner) => inner.output().clone(),
157         }
158     }
159 }
160 
161 impl<Fut> Inner<Fut>
162 where
163     Fut: Future,
164     Fut::Output: Clone,
165 {
166     /// Safety: callers must first ensure that `self.inner.state`
167     /// is `COMPLETE`
output(&self) -> &Fut::Output168     unsafe fn output(&self) -> &Fut::Output {
169         match &*self.future_or_output.get() {
170             FutureOrOutput::Output(ref item) => &item,
171             FutureOrOutput::Future(_) => unreachable!(),
172         }
173     }
174 }
175 
176 impl<Fut> FusedFuture for Shared<Fut>
177 where
178     Fut: Future,
179     Fut::Output: Clone,
180 {
is_terminated(&self) -> bool181     fn is_terminated(&self) -> bool {
182         self.inner.is_none()
183     }
184 }
185 
186 impl<Fut> Future for Shared<Fut>
187 where
188     Fut: Future,
189     Fut::Output: Clone,
190 {
191     type Output = Fut::Output;
192 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>193     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
194         let this = &mut *self;
195 
196         this.set_waker(cx);
197 
198         let inner = if let Some(inner) = this.inner.as_ref() {
199             inner
200         } else {
201             panic!("Shared future polled again after completion");
202         };
203 
204         match inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) {
205             IDLE => {
206                 // Lock acquired, fall through
207             }
208             POLLING | REPOLL => {
209                 // Another task is currently polling, at this point we just want
210                 // to ensure that the waker for this task is registered
211 
212                 return Poll::Pending;
213             }
214             COMPLETE => {
215                 // Safety: We're in the COMPLETE state
216                 return unsafe { Poll::Ready(this.take_or_clone_output()) };
217             }
218             POISONED => panic!("inner future panicked during poll"),
219             _ => unreachable!(),
220         }
221 
222         let waker = waker_ref(&inner.notifier);
223         let mut cx = Context::from_waker(&waker);
224 
225         struct Reset<'a>(&'a AtomicUsize);
226 
227         impl Drop for Reset<'_> {
228             fn drop(&mut self) {
229                 use std::thread;
230 
231                 if thread::panicking() {
232                     self.0.store(POISONED, SeqCst);
233                 }
234             }
235         }
236 
237         let _reset = Reset(&inner.notifier.state);
238 
239         let output = loop {
240             let future = unsafe {
241                 match &mut *inner.future_or_output.get() {
242                     FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
243                     _ => unreachable!(),
244                 }
245             };
246 
247             let poll = future.poll(&mut cx);
248 
249             match poll {
250                 Poll::Pending => {
251                     let state = &inner.notifier.state;
252                     match state.compare_and_swap(POLLING, IDLE, SeqCst) {
253                         POLLING => {
254                             // Success
255                             return Poll::Pending;
256                         }
257                         REPOLL => {
258                             // Was woken since: Gotta poll again!
259                             let prev = state.swap(POLLING, SeqCst);
260                             assert_eq!(prev, REPOLL);
261                         }
262                         _ => unreachable!(),
263                     }
264                 }
265                 Poll::Ready(output) => break output,
266             }
267         };
268 
269         unsafe {
270             *inner.future_or_output.get() =
271                 FutureOrOutput::Output(output);
272         }
273 
274         inner.notifier.state.store(COMPLETE, SeqCst);
275 
276         // Wake all tasks and drop the slab
277         let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
278         let wakers = &mut wakers_guard.take().unwrap();
279         for (_key, opt_waker) in wakers {
280             if let Some(waker) = opt_waker.take() {
281                 waker.wake();
282             }
283         }
284 
285         drop(_reset); // Make borrow checker happy
286         drop(wakers_guard);
287 
288         // Safety: We're in the COMPLETE state
289         unsafe { Poll::Ready(this.take_or_clone_output()) }
290     }
291 }
292 
293 impl<Fut> Clone for Shared<Fut>
294 where
295     Fut: Future,
296 {
clone(&self) -> Self297     fn clone(&self) -> Self {
298         Shared {
299             inner: self.inner.clone(),
300             waker_key: NULL_WAKER_KEY,
301         }
302     }
303 }
304 
305 impl<Fut> Drop for Shared<Fut>
306 where
307     Fut: Future,
308 {
drop(&mut self)309     fn drop(&mut self) {
310         if self.waker_key != NULL_WAKER_KEY {
311             if let Some(ref inner) = self.inner {
312                 if let Ok(mut wakers) = inner.notifier.wakers.lock() {
313                     if let Some(wakers) = wakers.as_mut() {
314                         wakers.remove(self.waker_key);
315                     }
316                 }
317             }
318         }
319     }
320 }
321 
322 impl ArcWake for Notifier {
wake_by_ref(arc_self: &Arc<Self>)323     fn wake_by_ref(arc_self: &Arc<Self>) {
324         arc_self.state.compare_and_swap(POLLING, REPOLL, SeqCst);
325 
326         let wakers = &mut *arc_self.wakers.lock().unwrap();
327         if let Some(wakers) = wakers.as_mut() {
328             for (_key, opt_waker) in wakers {
329                 if let Some(waker) = opt_waker.take() {
330                     waker.wake();
331                 }
332             }
333         }
334     }
335 }
336