1 use pin_project_lite::pin_project;
2 use std::cell::RefCell;
3 use std::error::Error;
4 use std::future::Future;
5 use std::marker::PhantomPinned;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 use std::{fmt, thread};
9 
/// Declares a new task-local key of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro wraps any number of static declarations and makes them local to the current task.
/// Declarations are separated by `;` (the final `;` is optional). The visibility and
/// attributes written on each static are forwarded along with it.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [LocalKey documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // empty (base case for the recursion)
    () => {};

    // One `static` declaration terminated by `;`: emit it, then recurse on the
    // remaining tokens.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
        $crate::task_local!($($rest)*);
    };

    // A final `static` declaration written without a trailing `;`.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
    }
}
49 
// Implementation detail of `task_local!`: expands one declaration into a
// `static` `LocalKey<$t>` backed by a private `thread_local!` slot that starts
// out empty (`None`).
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        // Forward the caller's attributes (doc comments, `#[allow(..)]`,
        // `#[cfg(..)]`, ...) to the generated static. Previously they were
        // matched but silently discarded.
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            // `Option`/`None` are fully qualified so that a caller-side item
            // with the same name cannot change the meaning of the expansion.
            std::thread_local! {
                static __KEY: std::cell::RefCell<std::option::Option<$t>> =
                    std::cell::RefCell::new(std::option::Option::None);
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
63 
64 /// A key for task-local data.
65 ///
66 /// This type is generated by the `task_local!` macro.
67 ///
68 /// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will
69 /// _not_ lazily initialize the value on first access. Instead, the
70 /// value is first initialized when the future containing
71 /// the task-local is first polled by a futures executor, like Tokio.
72 ///
73 /// # Examples
74 ///
75 /// ```
76 /// # async fn dox() {
77 /// tokio::task_local! {
78 ///     static NUMBER: u32;
79 /// }
80 ///
81 /// NUMBER.scope(1, async move {
82 ///     assert_eq!(NUMBER.get(), 1);
83 /// }).await;
84 ///
85 /// NUMBER.scope(2, async move {
86 ///     assert_eq!(NUMBER.get(), 2);
87 ///
88 ///     NUMBER.scope(3, async move {
89 ///         assert_eq!(NUMBER.get(), 3);
90 ///     }).await;
91 /// }).await;
92 /// # }
93 /// ```
94 /// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
95 #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
96 pub struct LocalKey<T: 'static> {
97     #[doc(hidden)]
98     pub inner: thread::LocalKey<RefCell<Option<T>>>,
99 }
100 
101 impl<T: 'static> LocalKey<T> {
102     /// Sets a value `T` as the task-local value for the future `F`.
103     ///
104     /// On completion of `scope`, the task-local will be dropped.
105     ///
106     /// ### Examples
107     ///
108     /// ```
109     /// # async fn dox() {
110     /// tokio::task_local! {
111     ///     static NUMBER: u32;
112     /// }
113     ///
114     /// NUMBER.scope(1, async move {
115     ///     println!("task local value: {}", NUMBER.get());
116     /// }).await;
117     /// # }
118     /// ```
scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> where F: Future,119     pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
120     where
121         F: Future,
122     {
123         TaskLocalFuture {
124             local: self,
125             slot: Some(value),
126             future: f,
127             _pinned: PhantomPinned,
128         }
129     }
130 
131     /// Sets a value `T` as the task-local value for the closure `F`.
132     ///
133     /// On completion of `scope`, the task-local will be dropped.
134     ///
135     /// ### Examples
136     ///
137     /// ```
138     /// # async fn dox() {
139     /// tokio::task_local! {
140     ///     static NUMBER: u32;
141     /// }
142     ///
143     /// NUMBER.sync_scope(1, || {
144     ///     println!("task local value: {}", NUMBER.get());
145     /// });
146     /// # }
147     /// ```
sync_scope<F, R>(&'static self, value: T, f: F) -> R where F: FnOnce() -> R,148     pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
149     where
150         F: FnOnce() -> R,
151     {
152         let scope = TaskLocalFuture {
153             local: self,
154             slot: Some(value),
155             future: (),
156             _pinned: PhantomPinned,
157         };
158         crate::pin!(scope);
159         scope.with_task(|_| f())
160     }
161 
162     /// Accesses the current task-local and runs the provided closure.
163     ///
164     /// # Panics
165     ///
166     /// This function will panic if not called within the context
167     /// of a future containing a task-local with the corresponding key.
with<F, R>(&'static self, f: F) -> R where F: FnOnce(&T) -> R,168     pub fn with<F, R>(&'static self, f: F) -> R
169     where
170         F: FnOnce(&T) -> R,
171     {
172         self.try_with(f).expect(
173             "cannot access a Task Local Storage value \
174              without setting it via `LocalKey::set`",
175         )
176     }
177 
178     /// Accesses the current task-local and runs the provided closure.
179     ///
180     /// If the task-local with the associated key is not present, this
181     /// method will return an `AccessError`. For a panicking variant,
182     /// see `with`.
try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> where F: FnOnce(&T) -> R,183     pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
184     where
185         F: FnOnce(&T) -> R,
186     {
187         self.inner.with(|v| {
188             if let Some(val) = v.borrow().as_ref() {
189                 Ok(f(val))
190             } else {
191                 Err(AccessError { _private: () })
192             }
193         })
194     }
195 }
196 
197 impl<T: Copy + 'static> LocalKey<T> {
198     /// Returns a copy of the task-local value
199     /// if the task-local value implements `Copy`.
get(&'static self) -> T200     pub fn get(&'static self) -> T {
201         self.with(|v| *v)
202     }
203 }
204 
205 impl<T: 'static> fmt::Debug for LocalKey<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result206     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207         f.pad("LocalKey { .. }")
208     }
209 }
210 
pin_project! {
    /// A future that sets a value `T` of a task local for the future `F` during
    /// its execution.
    ///
    /// The value of the task-local must be `'static`. It is owned by this
    /// future and installed into the task-local slot only while the inner
    /// future is being polled. It is dropped when this future is dropped.
    ///
    /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
    ///
    /// ### Examples
    ///
    /// ```
    /// # async fn dox() {
    /// tokio::task_local! {
    ///     static NUMBER: u32;
    /// }
    ///
    /// NUMBER.scope(1, async move {
    ///     println!("task local value: {}", NUMBER.get());
    /// }).await;
    /// # }
    /// ```
    pub struct TaskLocalFuture<T, F>
    where
        T: 'static
    {
        // Key whose thread-local slot receives `slot` while `future` is polled.
        local: &'static LocalKey<T>,
        // Holds the scoped value whenever it is not installed in the
        // thread-local slot.
        slot: Option<T>,
        #[pin]
        future: F,
        // Makes this type `!Unpin` regardless of whether `F` is `Unpin`.
        // NOTE(review): there is no custom drop logic, so `future` is dropped
        // without the task-local value installed. Its destructor cannot read
        // the value. Confirm this is intended.
        #[pin]
        _pinned: PhantomPinned,
    }
}
245 
impl<T: 'static, F> TaskLocalFuture<T, F> {
    /// Runs `f` on the pinned inner future with this scope's value installed
    /// in the task-local slot, then moves the value back into `self.slot`.
    ///
    /// Whatever the slot held before (e.g. the value of an enclosing scope for
    /// the same key) is saved and restored afterwards. The restore happens in a
    /// drop guard, so it also runs if `f` panics.
    fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R {
        // On drop: puts `prev` back into the thread-local slot and moves our
        // value from the slot back into `slot`.
        struct Guard<'a, T: 'static> {
            local: &'static LocalKey<T>,
            slot: &'a mut Option<T>,
            prev: Option<T>,
        }

        impl<T> Drop for Guard<'_, T> {
            fn drop(&mut self) {
                // Swap the saved outer value back in and reclaim ours.
                let value = self.local.inner.with(|c| c.replace(self.prev.take()));
                *self.slot = value;
            }
        }

        let mut project = self.project();
        let val = project.slot.take();

        // Install our value and remember what was there before. `RefCell::replace`
        // panics if the slot is currently borrowed, e.g. when this runs inside a
        // `with`/`try_with` closure on the same key.
        let prev = project.local.inner.with(|c| c.replace(val));

        // Created only after the install succeeded, so `prev` is always the
        // value to restore.
        let _guard = Guard {
            prev,
            slot: &mut project.slot,
            local: *project.local,
        };

        f(project.future)
    }
}
275 
276 impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
277     type Output = F::Output;
278 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>279     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
280         self.with_task(|f| f.poll(cx))
281     }
282 }
283 
/// The error returned by [`LocalKey::try_with`](method@LocalKey::try_with)
/// when no task-local value is currently available for the key.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct AccessError {
    // Private field: the error can only be constructed inside this module.
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render as a bare name, hiding the private field.
        out.write_str("AccessError")
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/fill/precision flags, matching `str`'s `Display`.
        out.pad("task-local value not set")
    }
}

impl Error for AccessError {}
303