1 use crate::task::{ArcWake, waker_ref}; 2 use futures_core::future::{FusedFuture, Future}; 3 use futures_core::task::{Context, Poll, Waker}; 4 use slab::Slab; 5 use std::cell::UnsafeCell; 6 use std::fmt; 7 use std::pin::Pin; 8 use std::sync::atomic::AtomicUsize; 9 use std::sync::atomic::Ordering::SeqCst; 10 use std::sync::{Arc, Mutex}; 11 12 /// Future for the [`shared`](super::FutureExt::shared) method. 13 #[must_use = "futures do nothing unless you `.await` or poll them"] 14 pub struct Shared<Fut: Future> { 15 inner: Option<Arc<Inner<Fut>>>, 16 waker_key: usize, 17 } 18 19 struct Inner<Fut: Future> { 20 future_or_output: UnsafeCell<FutureOrOutput<Fut>>, 21 notifier: Arc<Notifier>, 22 } 23 24 struct Notifier { 25 state: AtomicUsize, 26 wakers: Mutex<Option<Slab<Option<Waker>>>>, 27 } 28 29 // The future itself is polled behind the `Arc`, so it won't be moved 30 // when `Shared` is moved. 31 impl<Fut: Future> Unpin for Shared<Fut> {} 32 33 impl<Fut: Future> fmt::Debug for Shared<Fut> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 35 f.debug_struct("Shared") 36 .field("inner", &self.inner) 37 .field("waker_key", &self.waker_key) 38 .finish() 39 } 40 } 41 42 impl<Fut: Future> fmt::Debug for Inner<Fut> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 44 f.debug_struct("Inner").finish() 45 } 46 } 47 48 enum FutureOrOutput<Fut: Future> { 49 Future(Fut), 50 Output(Fut::Output), 51 } 52 53 unsafe impl<Fut> Send for Inner<Fut> 54 where 55 Fut: Future + Send, 56 Fut::Output: Send + Sync, 57 {} 58 59 unsafe impl<Fut> Sync for Inner<Fut> 60 where 61 Fut: Future + Send, 62 Fut::Output: Send + Sync, 63 {} 64 65 const IDLE: usize = 0; 66 const POLLING: usize = 1; 67 const REPOLL: usize = 2; 68 const COMPLETE: usize = 3; 69 const POISONED: usize = 4; 70 71 const NULL_WAKER_KEY: usize = usize::max_value(); 72 73 impl<Fut: Future> Shared<Fut> { new(future: Fut) -> Shared<Fut>74 pub(super) fn new(future: Fut) -> Shared<Fut> { 75 let inner = Inner { 76 future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)), 77 notifier: Arc::new(Notifier { 78 state: AtomicUsize::new(IDLE), 79 wakers: Mutex::new(Some(Slab::new())), 80 }), 81 }; 82 83 Shared { 84 inner: Some(Arc::new(inner)), 85 waker_key: NULL_WAKER_KEY, 86 } 87 } 88 } 89 90 impl<Fut> Shared<Fut> 91 where 92 Fut: Future, 93 Fut::Output: Clone, 94 { 95 /// Returns [`Some`] containing a reference to this [`Shared`]'s output if 96 /// it has already been computed by a clone or [`None`] if it hasn't been 97 /// computed yet or this [`Shared`] already returned its output from 98 /// [`poll`](Future::poll). peek(&self) -> Option<&Fut::Output>99 pub fn peek(&self) -> Option<&Fut::Output> { 100 if let Some(inner) = self.inner.as_ref() { 101 match inner.notifier.state.load(SeqCst) { 102 COMPLETE => unsafe { return Some(inner.output()) }, 103 POISONED => panic!("inner future panicked during poll"), 104 _ => {} 105 } 106 } 107 None 108 } 109 110 /// Registers the current task to receive a wakeup when `Inner` is awoken. set_waker(&mut self, cx: &mut Context<'_>)111 fn set_waker(&mut self, cx: &mut Context<'_>) { 112 // Acquire the lock first before checking COMPLETE to ensure there 113 // isn't a race. 114 let mut wakers_guard = if let Some(inner) = self.inner.as_ref() { 115 inner.notifier.wakers.lock().unwrap() 116 } else { 117 return; 118 }; 119 120 let wakers = if let Some(wakers) = wakers_guard.as_mut() { 121 wakers 122 } else { 123 return; 124 }; 125 126 if self.waker_key == NULL_WAKER_KEY { 127 self.waker_key = wakers.insert(Some(cx.waker().clone())); 128 } else { 129 let waker_slot = &mut wakers[self.waker_key]; 130 let needs_replacement = if let Some(_old_waker) = waker_slot { 131 // If there's still an unwoken waker in the slot, only replace 132 // if the current one wouldn't wake the same task. 133 // TODO: This API is currently not available, so replace always 134 // !waker.will_wake_nonlocal(old_waker) 135 true 136 } else { 137 true 138 }; 139 if needs_replacement { 140 *waker_slot = Some(cx.waker().clone()); 141 } 142 } 143 debug_assert!(self.waker_key != NULL_WAKER_KEY); 144 } 145 146 /// Safety: callers must first ensure that `self.inner.state` 147 /// is `COMPLETE` take_or_clone_output(&mut self) -> Fut::Output148 unsafe fn take_or_clone_output(&mut self) -> Fut::Output { 149 let inner = self.inner.take().unwrap(); 150 151 match Arc::try_unwrap(inner) { 152 Ok(inner) => match inner.future_or_output.into_inner() { 153 FutureOrOutput::Output(item) => item, 154 FutureOrOutput::Future(_) => unreachable!(), 155 }, 156 Err(inner) => inner.output().clone(), 157 } 158 } 159 } 160 161 impl<Fut> Inner<Fut> 162 where 163 Fut: Future, 164 Fut::Output: Clone, 165 { 166 /// Safety: callers must first ensure that `self.inner.state` 167 /// is `COMPLETE` output(&self) -> &Fut::Output168 unsafe fn output(&self) -> &Fut::Output { 169 match &*self.future_or_output.get() { 170 FutureOrOutput::Output(ref item) => &item, 171 FutureOrOutput::Future(_) => unreachable!(), 172 } 173 } 174 } 175 176 impl<Fut> FusedFuture for Shared<Fut> 177 where 178 Fut: Future, 179 Fut::Output: Clone, 180 { is_terminated(&self) -> bool181 fn is_terminated(&self) -> bool { 182 self.inner.is_none() 183 } 184 } 185 186 impl<Fut> Future for Shared<Fut> 187 where 188 Fut: Future, 189 Fut::Output: Clone, 190 { 191 type Output = Fut::Output; 192 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>193 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 194 let this = &mut *self; 195 196 this.set_waker(cx); 197 198 let inner = if let Some(inner) = this.inner.as_ref() { 199 inner 200 } else { 201 panic!("Shared future polled again after completion"); 202 }; 203 204 match inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) { 205 IDLE => { 206 // Lock acquired, fall through 207 } 208 POLLING | REPOLL => { 209 // Another task is currently polling, at this point we just want 210 // to ensure that the waker for this task is registered 211 212 return Poll::Pending; 213 } 214 COMPLETE => { 215 // Safety: We're in the COMPLETE state 216 return unsafe { Poll::Ready(this.take_or_clone_output()) }; 217 } 218 POISONED => panic!("inner future panicked during poll"), 219 _ => unreachable!(), 220 } 221 222 let waker = waker_ref(&inner.notifier); 223 let mut cx = Context::from_waker(&waker); 224 225 struct Reset<'a>(&'a AtomicUsize); 226 227 impl Drop for Reset<'_> { 228 fn drop(&mut self) { 229 use std::thread; 230 231 if thread::panicking() { 232 self.0.store(POISONED, SeqCst); 233 } 234 } 235 } 236 237 let _reset = Reset(&inner.notifier.state); 238 239 let output = loop { 240 let future = unsafe { 241 match &mut *inner.future_or_output.get() { 242 FutureOrOutput::Future(fut) => Pin::new_unchecked(fut), 243 _ => unreachable!(), 244 } 245 }; 246 247 let poll = future.poll(&mut cx); 248 249 match poll { 250 Poll::Pending => { 251 let state = &inner.notifier.state; 252 match state.compare_and_swap(POLLING, IDLE, SeqCst) { 253 POLLING => { 254 // Success 255 return Poll::Pending; 256 } 257 REPOLL => { 258 // Was woken since: Gotta poll again! 259 let prev = state.swap(POLLING, SeqCst); 260 assert_eq!(prev, REPOLL); 261 } 262 _ => unreachable!(), 263 } 264 } 265 Poll::Ready(output) => break output, 266 } 267 }; 268 269 unsafe { 270 *inner.future_or_output.get() = 271 FutureOrOutput::Output(output); 272 } 273 274 inner.notifier.state.store(COMPLETE, SeqCst); 275 276 // Wake all tasks and drop the slab 277 let mut wakers_guard = inner.notifier.wakers.lock().unwrap(); 278 let wakers = &mut wakers_guard.take().unwrap(); 279 for (_key, opt_waker) in wakers { 280 if let Some(waker) = opt_waker.take() { 281 waker.wake(); 282 } 283 } 284 285 drop(_reset); // Make borrow checker happy 286 drop(wakers_guard); 287 288 // Safety: We're in the COMPLETE state 289 unsafe { Poll::Ready(this.take_or_clone_output()) } 290 } 291 } 292 293 impl<Fut> Clone for Shared<Fut> 294 where 295 Fut: Future, 296 { clone(&self) -> Self297 fn clone(&self) -> Self { 298 Shared { 299 inner: self.inner.clone(), 300 waker_key: NULL_WAKER_KEY, 301 } 302 } 303 } 304 305 impl<Fut> Drop for Shared<Fut> 306 where 307 Fut: Future, 308 { drop(&mut self)309 fn drop(&mut self) { 310 if self.waker_key != NULL_WAKER_KEY { 311 if let Some(ref inner) = self.inner { 312 if let Ok(mut wakers) = inner.notifier.wakers.lock() { 313 if let Some(wakers) = wakers.as_mut() { 314 wakers.remove(self.waker_key); 315 } 316 } 317 } 318 } 319 } 320 } 321 322 impl ArcWake for Notifier { wake_by_ref(arc_self: &Arc<Self>)323 fn wake_by_ref(arc_self: &Arc<Self>) { 324 arc_self.state.compare_and_swap(POLLING, REPOLL, SeqCst); 325 326 let wakers = &mut *arc_self.wakers.lock().unwrap(); 327 if let Some(wakers) = wakers.as_mut() { 328 for (_key, opt_waker) in wakers { 329 if let Some(waker) = opt_waker.take() { 330 waker.wake(); 331 } 332 } 333 } 334 } 335 } 336