1 use std::cell::{Cell, RefCell};
2 use std::fmt;
3 use std::marker::PhantomData;
4 
/// Per-thread record of whether the current thread is inside a runtime
/// context, and with which blocking permission.
#[derive(Debug, Clone, Copy)]
pub(crate) enum EnterContext {
    /// The thread is inside a runtime context. `allow_blocking` holds the flag
    /// passed to `enter`/`try_enter`; `disallow_blocking` may temporarily
    /// clear it.
    #[cfg_attr(not(feature = "rt"), allow(dead_code))]
    Entered { allow_blocking: bool },
    /// The thread is not inside a runtime context.
    NotEntered,
}
13 
14 impl EnterContext {
is_entered(self) -> bool15     pub(crate) fn is_entered(self) -> bool {
16         matches!(self, EnterContext::Entered { .. })
17     }
18 }
19 
20 thread_local!(static ENTERED: Cell<EnterContext> = Cell::new(EnterContext::NotEntered));
21 
/// Guard representing an entered executor context on the current thread.
///
/// The `PhantomData<RefCell<()>>` marker stores no data (the struct is
/// zero-sized). NOTE(review): it presumably exists to make `Enter` `!Sync`,
/// as `RefCell` is — confirm intent.
pub(crate) struct Enter {
    _p: PhantomData<RefCell<()>>,
}
26 
27 cfg_rt! {
28     use crate::park::thread::ParkError;
29 
30     use std::time::Duration;
31 
32     /// Marks the current thread as being within the dynamic extent of an
33     /// executor.
34     pub(crate) fn enter(allow_blocking: bool) -> Enter {
35         if let Some(enter) = try_enter(allow_blocking) {
36             return enter;
37         }
38 
39         panic!(
40             "Cannot start a runtime from within a runtime. This happens \
41             because a function (like `block_on`) attempted to block the \
42             current thread while the thread is being used to drive \
43             asynchronous tasks."
44         );
45     }
46 
47     /// Tries to enter a runtime context, returns `None` if already in a runtime
48     /// context.
49     pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> {
50         ENTERED.with(|c| {
51             if c.get().is_entered() {
52                 None
53             } else {
54                 c.set(EnterContext::Entered { allow_blocking });
55                 Some(Enter { _p: PhantomData })
56             }
57         })
58     }
59 }
60 
61 // Forces the current "entered" state to be cleared while the closure
62 // is executed.
63 //
64 // # Warning
65 //
66 // This is hidden for a reason. Do not use without fully understanding
67 // executors. Misusing can easily cause your program to deadlock.
68 cfg_rt_multi_thread! {
69     pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R {
70         // Reset in case the closure panics
71         struct Reset(EnterContext);
72         impl Drop for Reset {
73             fn drop(&mut self) {
74                 ENTERED.with(|c| {
75                     assert!(!c.get().is_entered(), "closure claimed permanent executor");
76                     c.set(self.0);
77                 });
78             }
79         }
80 
81         let was = ENTERED.with(|c| {
82             let e = c.get();
83             assert!(e.is_entered(), "asked to exit when not entered");
84             c.set(EnterContext::NotEntered);
85             e
86         });
87 
88         let _reset = Reset(was);
89         // dropping _reset after f() will reset ENTERED
90         f()
91     }
92 }
93 
94 cfg_rt! {
95     /// Disallows blocking in the current runtime context until the guard is dropped.
96     pub(crate) fn disallow_blocking() -> DisallowBlockingGuard {
97         let reset = ENTERED.with(|c| {
98             if let EnterContext::Entered {
99                 allow_blocking: true,
100             } = c.get()
101             {
102                 c.set(EnterContext::Entered {
103                     allow_blocking: false,
104                 });
105                 true
106             } else {
107                 false
108             }
109         });
110         DisallowBlockingGuard(reset)
111     }
112 
113     pub(crate) struct DisallowBlockingGuard(bool);
114     impl Drop for DisallowBlockingGuard {
115         fn drop(&mut self) {
116             if self.0 {
117                 // XXX: Do we want some kind of assertion here, or is "best effort" okay?
118                 ENTERED.with(|c| {
119                     if let EnterContext::Entered {
120                         allow_blocking: false,
121                     } = c.get()
122                     {
123                         c.set(EnterContext::Entered {
124                             allow_blocking: true,
125                         });
126                     }
127                 })
128             }
129         }
130     }
131 }
132 
133 cfg_rt_multi_thread! {
134     /// Returns true if in a runtime context.
135     pub(crate) fn context() -> EnterContext {
136         ENTERED.with(|c| c.get())
137     }
138 }
139 
140 cfg_rt! {
141     impl Enter {
142         /// Blocks the thread on the specified future, returning the value with
143         /// which that future completes.
144         pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, ParkError>
145         where
146             F: std::future::Future,
147         {
148             use crate::park::thread::CachedParkThread;
149 
150             let mut park = CachedParkThread::new();
151             park.block_on(f)
152         }
153 
154         /// Blocks the thread on the specified future for **at most** `timeout`
155         ///
156         /// If the future completes before `timeout`, the result is returned. If
157         /// `timeout` elapses, then `Err` is returned.
158         pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ParkError>
159         where
160             F: std::future::Future,
161         {
162             use crate::park::Park;
163             use crate::park::thread::CachedParkThread;
164             use std::task::Context;
165             use std::task::Poll::Ready;
166             use std::time::Instant;
167 
168             let mut park = CachedParkThread::new();
169             let waker = park.get_unpark()?.into_waker();
170             let mut cx = Context::from_waker(&waker);
171 
172             pin!(f);
173             let when = Instant::now() + timeout;
174 
175             loop {
176                 if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
177                     return Ok(v);
178                 }
179 
180                 let now = Instant::now();
181 
182                 if now >= when {
183                     return Err(());
184                 }
185 
186                 park.park_timeout(when - now)?;
187             }
188         }
189     }
190 }
191 
192 impl fmt::Debug for Enter {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result193     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194         f.debug_struct("Enter").finish()
195     }
196 }
197 
198 impl Drop for Enter {
drop(&mut self)199     fn drop(&mut self) {
200         ENTERED.with(|c| {
201             assert!(c.get().is_entered());
202             c.set(EnterContext::NotEntered);
203         });
204     }
205 }
206