1 use pin_project_lite::pin_project; 2 use std::cell::RefCell; 3 use std::error::Error; 4 use std::future::Future; 5 use std::marker::PhantomPinned; 6 use std::pin::Pin; 7 use std::task::{Context, Poll}; 8 use std::{fmt, thread}; 9 10 /// Declares a new task-local key of type [`tokio::task::LocalKey`]. 11 /// 12 /// # Syntax 13 /// 14 /// The macro wraps any number of static declarations and makes them local to the current task. 15 /// Publicity and attributes for each static is preserved. For example: 16 /// 17 /// # Examples 18 /// 19 /// ``` 20 /// # use tokio::task_local; 21 /// task_local! { 22 /// pub static ONE: u32; 23 /// 24 /// #[allow(unused)] 25 /// static TWO: f32; 26 /// } 27 /// # fn main() {} 28 /// ``` 29 /// 30 /// See [LocalKey documentation][`tokio::task::LocalKey`] for more 31 /// information. 32 /// 33 /// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey 34 #[macro_export] 35 #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] 36 macro_rules! task_local { 37 // empty (base case for the recursion) 38 () => {}; 39 40 ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => { 41 $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); 42 $crate::task_local!($($rest)*); 43 }; 44 45 ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => { 46 $crate::__task_local_inner!($(#[$attr])* $vis $name, $t); 47 } 48 } 49 50 #[doc(hidden)] 51 #[macro_export] 52 macro_rules! __task_local_inner { 53 ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => { 54 $vis static $name: $crate::task::LocalKey<$t> = { 55 std::thread_local! { 56 static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None); 57 } 58 59 $crate::task::LocalKey { inner: __KEY } 60 }; 61 }; 62 } 63 64 /// A key for task-local data. 65 /// 66 /// This type is generated by the `task_local!` macro. 67 /// 68 /// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will 69 /// _not_ lazily initialize the value on first access. Instead, the 70 /// value is first initialized when the future containing 71 /// the task-local is first polled by a futures executor, like Tokio. 72 /// 73 /// # Examples 74 /// 75 /// ``` 76 /// # async fn dox() { 77 /// tokio::task_local! { 78 /// static NUMBER: u32; 79 /// } 80 /// 81 /// NUMBER.scope(1, async move { 82 /// assert_eq!(NUMBER.get(), 1); 83 /// }).await; 84 /// 85 /// NUMBER.scope(2, async move { 86 /// assert_eq!(NUMBER.get(), 2); 87 /// 88 /// NUMBER.scope(3, async move { 89 /// assert_eq!(NUMBER.get(), 3); 90 /// }).await; 91 /// }).await; 92 /// # } 93 /// ``` 94 /// [`std::thread::LocalKey`]: struct@std::thread::LocalKey 95 #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] 96 pub struct LocalKey<T: 'static> { 97 #[doc(hidden)] 98 pub inner: thread::LocalKey<RefCell<Option<T>>>, 99 } 100 101 impl<T: 'static> LocalKey<T> { 102 /// Sets a value `T` as the task-local value for the future `F`. 103 /// 104 /// On completion of `scope`, the task-local will be dropped. 105 /// 106 /// ### Examples 107 /// 108 /// ``` 109 /// # async fn dox() { 110 /// tokio::task_local! { 111 /// static NUMBER: u32; 112 /// } 113 /// 114 /// NUMBER.scope(1, async move { 115 /// println!("task local value: {}", NUMBER.get()); 116 /// }).await; 117 /// # } 118 /// ``` scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> where F: Future,119 pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> 120 where 121 F: Future, 122 { 123 TaskLocalFuture { 124 local: self, 125 slot: Some(value), 126 future: f, 127 _pinned: PhantomPinned, 128 } 129 } 130 131 /// Sets a value `T` as the task-local value for the closure `F`. 132 /// 133 /// On completion of `scope`, the task-local will be dropped. 134 /// 135 /// ### Examples 136 /// 137 /// ``` 138 /// # async fn dox() { 139 /// tokio::task_local! { 140 /// static NUMBER: u32; 141 /// } 142 /// 143 /// NUMBER.sync_scope(1, || { 144 /// println!("task local value: {}", NUMBER.get()); 145 /// }); 146 /// # } 147 /// ``` sync_scope<F, R>(&'static self, value: T, f: F) -> R where F: FnOnce() -> R,148 pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R 149 where 150 F: FnOnce() -> R, 151 { 152 let scope = TaskLocalFuture { 153 local: self, 154 slot: Some(value), 155 future: (), 156 _pinned: PhantomPinned, 157 }; 158 crate::pin!(scope); 159 scope.with_task(|_| f()) 160 } 161 162 /// Accesses the current task-local and runs the provided closure. 163 /// 164 /// # Panics 165 /// 166 /// This function will panic if not called within the context 167 /// of a future containing a task-local with the corresponding key. with<F, R>(&'static self, f: F) -> R where F: FnOnce(&T) -> R,168 pub fn with<F, R>(&'static self, f: F) -> R 169 where 170 F: FnOnce(&T) -> R, 171 { 172 self.try_with(f).expect( 173 "cannot access a Task Local Storage value \ 174 without setting it via `LocalKey::set`", 175 ) 176 } 177 178 /// Accesses the current task-local and runs the provided closure. 179 /// 180 /// If the task-local with the associated key is not present, this 181 /// method will return an `AccessError`. For a panicking variant, 182 /// see `with`. try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> where F: FnOnce(&T) -> R,183 pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError> 184 where 185 F: FnOnce(&T) -> R, 186 { 187 self.inner.with(|v| { 188 if let Some(val) = v.borrow().as_ref() { 189 Ok(f(val)) 190 } else { 191 Err(AccessError { _private: () }) 192 } 193 }) 194 } 195 } 196 197 impl<T: Copy + 'static> LocalKey<T> { 198 /// Returns a copy of the task-local value 199 /// if the task-local value implements `Copy`. get(&'static self) -> T200 pub fn get(&'static self) -> T { 201 self.with(|v| *v) 202 } 203 } 204 205 impl<T: 'static> fmt::Debug for LocalKey<T> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 207 f.pad("LocalKey { .. }") 208 } 209 } 210 211 pin_project! { 212 /// A future that sets a value `T` of a task local for the future `F` during 213 /// its execution. 214 /// 215 /// The value of the task-local must be `'static` and will be dropped on the 216 /// completion of the future. 217 /// 218 /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). 219 /// 220 /// ### Examples 221 /// 222 /// ``` 223 /// # async fn dox() { 224 /// tokio::task_local! { 225 /// static NUMBER: u32; 226 /// } 227 /// 228 /// NUMBER.scope(1, async move { 229 /// println!("task local value: {}", NUMBER.get()); 230 /// }).await; 231 /// # } 232 /// ``` 233 pub struct TaskLocalFuture<T, F> 234 where 235 T: 'static 236 { 237 local: &'static LocalKey<T>, 238 slot: Option<T>, 239 #[pin] 240 future: F, 241 #[pin] 242 _pinned: PhantomPinned, 243 } 244 } 245 246 impl<T: 'static, F> TaskLocalFuture<T, F> { with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R247 fn with_task<F2: FnOnce(Pin<&mut F>) -> R, R>(self: Pin<&mut Self>, f: F2) -> R { 248 struct Guard<'a, T: 'static> { 249 local: &'static LocalKey<T>, 250 slot: &'a mut Option<T>, 251 prev: Option<T>, 252 } 253 254 impl<T> Drop for Guard<'_, T> { 255 fn drop(&mut self) { 256 let value = self.local.inner.with(|c| c.replace(self.prev.take())); 257 *self.slot = value; 258 } 259 } 260 261 let mut project = self.project(); 262 let val = project.slot.take(); 263 264 let prev = project.local.inner.with(|c| c.replace(val)); 265 266 let _guard = Guard { 267 prev, 268 slot: &mut project.slot, 269 local: *project.local, 270 }; 271 272 f(project.future) 273 } 274 } 275 276 impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { 277 type Output = F::Output; 278 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>279 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 280 self.with_task(|f| f.poll(cx)) 281 } 282 } 283 284 /// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with). 285 #[derive(Clone, Copy, Eq, PartialEq)] 286 pub struct AccessError { 287 _private: (), 288 } 289 290 impl fmt::Debug for AccessError { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 292 f.debug_struct("AccessError").finish() 293 } 294 } 295 296 impl fmt::Display for AccessError { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 298 fmt::Display::fmt("task-local value not set", f) 299 } 300 } 301 302 impl Error for AccessError {} 303