1 //! This module has containers for storing the tasks spawned on a scheduler. The 2 //! `OwnedTasks` container is thread-safe but can only store tasks that 3 //! implement Send. The `LocalOwnedTasks` container is not thread safe, but can 4 //! store non-Send tasks. 5 //! 6 //! The collections can be closed to prevent adding new tasks during shutdown of 7 //! the scheduler with the collection. 8 9 use crate::future::Future; 10 use crate::loom::cell::UnsafeCell; 11 use crate::loom::sync::Mutex; 12 use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; 13 use crate::util::linked_list::{Link, LinkedList}; 14 15 use std::marker::PhantomData; 16 17 // The id from the module below is used to verify whether a given task is stored 18 // in this OwnedTasks, or some other task. The counter starts at one so we can 19 // use zero for tasks not owned by any list. 20 // 21 // The safety checks in this file can technically be violated if the counter is 22 // overflown, but the checks are not supposed to ever fail unless there is a 23 // bug in Tokio, so we accept that certain bugs would not be caught if the two 24 // mixed up runtimes happen to have the same id. 25 26 cfg_has_atomic_u64! { 27 use std::sync::atomic::{AtomicU64, Ordering}; 28 29 static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); 30 31 fn get_next_id() -> u64 { 32 loop { 33 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); 34 if id != 0 { 35 return id; 36 } 37 } 38 } 39 } 40 41 cfg_not_has_atomic_u64! { 42 use std::sync::atomic::{AtomicU32, Ordering}; 43 44 static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); 45 46 fn get_next_id() -> u64 { 47 loop { 48 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); 49 if id != 0 { 50 return u64::from(id); 51 } 52 } 53 } 54 } 55 56 pub(crate) struct OwnedTasks<S: 'static> { 57 inner: Mutex<OwnedTasksInner<S>>, 58 id: u64, 59 } 60 pub(crate) struct LocalOwnedTasks<S: 'static> { 61 inner: UnsafeCell<OwnedTasksInner<S>>, 62 id: u64, 63 _not_send_or_sync: PhantomData<*const ()>, 64 } 65 struct OwnedTasksInner<S: 'static> { 66 list: LinkedList<Task<S>, <Task<S> as Link>::Target>, 67 closed: bool, 68 } 69 70 impl<S: 'static> OwnedTasks<S> { new() -> Self71 pub(crate) fn new() -> Self { 72 Self { 73 inner: Mutex::new(OwnedTasksInner { 74 list: LinkedList::new(), 75 closed: false, 76 }), 77 id: get_next_id(), 78 } 79 } 80 81 /// Binds the provided task to this OwnedTasks instance. This fails if the 82 /// OwnedTasks has been closed. bind<T>( &self, task: T, scheduler: S, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static,83 pub(crate) fn bind<T>( 84 &self, 85 task: T, 86 scheduler: S, 87 ) -> (JoinHandle<T::Output>, Option<Notified<S>>) 88 where 89 S: Schedule, 90 T: Future + Send + 'static, 91 T::Output: Send + 'static, 92 { 93 let (task, notified, join) = super::new_task(task, scheduler); 94 95 unsafe { 96 // safety: We just created the task, so we have exclusive access 97 // to the field. 98 task.header().set_owner_id(self.id); 99 } 100 101 let mut lock = self.inner.lock(); 102 if lock.closed { 103 drop(lock); 104 drop(notified); 105 task.shutdown(); 106 (join, None) 107 } else { 108 lock.list.push_front(task); 109 (join, Some(notified)) 110 } 111 } 112 113 /// Asserts that the given task is owned by this OwnedTasks and convert it to 114 /// a LocalNotified, giving the thread permission to poll this task. 115 #[inline] assert_owner(&self, task: Notified<S>) -> LocalNotified<S>116 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { 117 assert_eq!(task.header().get_owner_id(), self.id); 118 119 // safety: All tasks bound to this OwnedTasks are Send, so it is safe 120 // to poll it on this thread no matter what thread we are on. 121 LocalNotified { 122 task: task.0, 123 _not_send: PhantomData, 124 } 125 } 126 127 /// Shuts down all tasks in the collection. This call also closes the 128 /// collection, preventing new items from being added. close_and_shutdown_all(&self) where S: Schedule,129 pub(crate) fn close_and_shutdown_all(&self) 130 where 131 S: Schedule, 132 { 133 // The first iteration of the loop was unrolled so it can set the 134 // closed bool. 135 let first_task = { 136 let mut lock = self.inner.lock(); 137 lock.closed = true; 138 lock.list.pop_back() 139 }; 140 match first_task { 141 Some(task) => task.shutdown(), 142 None => return, 143 } 144 145 loop { 146 let task = match self.inner.lock().list.pop_back() { 147 Some(task) => task, 148 None => return, 149 }; 150 151 task.shutdown(); 152 } 153 } 154 remove(&self, task: &Task<S>) -> Option<Task<S>>155 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { 156 let task_id = task.header().get_owner_id(); 157 if task_id == 0 { 158 // The task is unowned. 159 return None; 160 } 161 162 assert_eq!(task_id, self.id); 163 164 // safety: We just checked that the provided task is not in some other 165 // linked list. 166 unsafe { self.inner.lock().list.remove(task.header().into()) } 167 } 168 is_empty(&self) -> bool169 pub(crate) fn is_empty(&self) -> bool { 170 self.inner.lock().list.is_empty() 171 } 172 } 173 174 impl<S: 'static> LocalOwnedTasks<S> { new() -> Self175 pub(crate) fn new() -> Self { 176 Self { 177 inner: UnsafeCell::new(OwnedTasksInner { 178 list: LinkedList::new(), 179 closed: false, 180 }), 181 id: get_next_id(), 182 _not_send_or_sync: PhantomData, 183 } 184 } 185 bind<T>( &self, task: T, scheduler: S, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + 'static, T::Output: 'static,186 pub(crate) fn bind<T>( 187 &self, 188 task: T, 189 scheduler: S, 190 ) -> (JoinHandle<T::Output>, Option<Notified<S>>) 191 where 192 S: Schedule, 193 T: Future + 'static, 194 T::Output: 'static, 195 { 196 let (task, notified, join) = super::new_task(task, scheduler); 197 198 unsafe { 199 // safety: We just created the task, so we have exclusive access 200 // to the field. 201 task.header().set_owner_id(self.id); 202 } 203 204 if self.is_closed() { 205 drop(notified); 206 task.shutdown(); 207 (join, None) 208 } else { 209 self.with_inner(|inner| { 210 inner.list.push_front(task); 211 }); 212 (join, Some(notified)) 213 } 214 } 215 216 /// Shuts down all tasks in the collection. This call also closes the 217 /// collection, preventing new items from being added. close_and_shutdown_all(&self) where S: Schedule,218 pub(crate) fn close_and_shutdown_all(&self) 219 where 220 S: Schedule, 221 { 222 self.with_inner(|inner| inner.closed = true); 223 224 while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { 225 task.shutdown(); 226 } 227 } 228 remove(&self, task: &Task<S>) -> Option<Task<S>>229 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { 230 let task_id = task.header().get_owner_id(); 231 if task_id == 0 { 232 // The task is unowned. 233 return None; 234 } 235 236 assert_eq!(task_id, self.id); 237 238 self.with_inner(|inner| 239 // safety: We just checked that the provided task is not in some 240 // other linked list. 241 unsafe { inner.list.remove(task.header().into()) }) 242 } 243 244 /// Asserts that the given task is owned by this LocalOwnedTasks and convert 245 /// it to a LocalNotified, giving the thread permission to poll this task. 246 #[inline] assert_owner(&self, task: Notified<S>) -> LocalNotified<S>247 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { 248 assert_eq!(task.header().get_owner_id(), self.id); 249 250 // safety: The task was bound to this LocalOwnedTasks, and the 251 // LocalOwnedTasks is not Send or Sync, so we are on the right thread 252 // for polling this task. 253 LocalNotified { 254 task: task.0, 255 _not_send: PhantomData, 256 } 257 } 258 259 #[inline] with_inner<F, T>(&self, f: F) -> T where F: FnOnce(&mut OwnedTasksInner<S>) -> T,260 fn with_inner<F, T>(&self, f: F) -> T 261 where 262 F: FnOnce(&mut OwnedTasksInner<S>) -> T, 263 { 264 // safety: This type is not Sync, so concurrent calls of this method 265 // can't happen. Furthermore, all uses of this method in this file make 266 // sure that they don't call `with_inner` recursively. 267 self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) 268 } 269 is_closed(&self) -> bool270 pub(crate) fn is_closed(&self) -> bool { 271 self.with_inner(|inner| inner.closed) 272 } 273 is_empty(&self) -> bool274 pub(crate) fn is_empty(&self) -> bool { 275 self.with_inner(|inner| inner.list.is_empty()) 276 } 277 } 278 279 #[cfg(all(test))] 280 mod tests { 281 use super::*; 282 283 // This test may run in parallel with other tests, so we only test that ids 284 // come in increasing order. 285 #[test] test_id_not_broken()286 fn test_id_not_broken() { 287 let mut last_id = get_next_id(); 288 assert_ne!(last_id, 0); 289 290 for _ in 0..1000 { 291 let next_id = get_next_id(); 292 assert_ne!(next_id, 0); 293 assert!(last_id < next_id); 294 last_id = next_id; 295 } 296 } 297 } 298