//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can
//! store non-Send tasks.
//!
//! The collections can be closed to prevent adding new tasks during shutdown of
//! the scheduler with the collection.
8 
9 use crate::future::Future;
10 use crate::loom::cell::UnsafeCell;
11 use crate::loom::sync::Mutex;
12 use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
13 use crate::util::linked_list::{Link, LinkedList};
14 
15 use std::marker::PhantomData;
16 
// The id from the module below is used to verify whether a given task is stored
// in this OwnedTasks, or some other task. The counter starts at one so we can
// use zero for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter is
// overflown, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if the two
// mixed up runtimes happen to have the same id.
25 
cfg_has_atomic_u64! {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Shared counter for owner ids; starts at one because zero is reserved
    // for "not owned by any list".
    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    /// Returns a fresh non-zero id for a task-ownership collection.
    fn get_next_id() -> u64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if id != 0 {
                // Only skipped if the u64 counter has wrapped around to zero.
                return id;
            }
        }
    }
}
40 
cfg_not_has_atomic_u64! {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Fallback for targets without 64-bit atomics: same scheme, but a 32-bit
    // counter that wraps (and is skipped past zero) much sooner.
    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    /// Returns a fresh non-zero id for a task-ownership collection, widened
    /// to `u64` so callers see the same type on every target.
    fn get_next_id() -> u64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if id != 0 {
                // Only skipped if the u32 counter has wrapped around to zero.
                return u64::from(id);
            }
        }
    }
}
55 
/// A thread-safe collection of tasks spawned on a scheduler; only stores tasks
/// whose futures are `Send` (enforced by the bounds on [`OwnedTasks::bind`]).
/// The `id` identifies this collection in each task's owner-id field.
pub(crate) struct OwnedTasks<S: 'static> {
    inner: Mutex<OwnedTasksInner<S>>,
    id: u64,
}
/// A collection of tasks that, unlike [`OwnedTasks`], can hold non-Send tasks.
/// The `PhantomData<*const ()>` marker makes the type neither `Send` nor
/// `Sync`, which is what makes the `UnsafeCell` access in `with_inner` sound.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    inner: UnsafeCell<OwnedTasksInner<S>>,
    id: u64,
    _not_send_or_sync: PhantomData<*const ()>,
}
/// State shared by both collection types: the intrusive list of stored tasks,
/// plus a `closed` flag that makes `bind` reject new tasks during shutdown.
struct OwnedTasksInner<S: 'static> {
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    closed: bool,
}
69 
impl<S: 'static> OwnedTasks<S> {
    /// Creates an empty, open collection with a globally unique owner id.
    pub(crate) fn new() -> Self {
        Self {
            inner: Mutex::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
        }
    }

    /// Binds the provided task to this OwnedTasks instance. This fails if the
    /// OwnedTasks has been closed.
    ///
    /// Returns the task's `JoinHandle` and, if the collection is still open,
    /// the `Notified` used to schedule it. If the collection has been closed,
    /// the task is shut down immediately and `None` is returned instead.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        let mut lock = self.inner.lock();
        if lock.closed {
            // Release the mutex first — `shutdown()` is never invoked while
            // this list's lock is held (see also close_and_shutdown_all).
            drop(lock);
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            lock.list.push_front(task);
            (join, Some(notified))
        }
    }

    /// Asserts that the given task is owned by this OwnedTasks and convert it to
    /// a LocalNotified, giving the thread permission to poll this task.
    ///
    /// Panics if the task's owner id does not match this collection.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), self.id);

        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
        // to poll it on this thread no matter what thread we are on.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        // The first iteration of the loop was unrolled so it can set the
        // closed bool.
        let first_task = {
            let mut lock = self.inner.lock();
            lock.closed = true;
            lock.list.pop_back()
        };
        match first_task {
            Some(task) => task.shutdown(),
            None => return,
        }

        loop {
            // The temporary lock guard is dropped at the end of this `let`
            // statement, so `shutdown()` below runs without the lock held.
            let task = match self.inner.lock().list.pop_back() {
                Some(task) => task,
                None => return,
            };

            task.shutdown();
        }
    }

    /// Removes the given task from the collection, returning ownership of it
    /// if it was present.
    ///
    /// Returns `None` for unowned tasks (owner id zero); panics if the task
    /// belongs to a different collection.
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        let task_id = task.header().get_owner_id();
        if task_id == 0 {
            // The task is unowned.
            return None;
        }

        assert_eq!(task_id, self.id);

        // safety: We just checked that the provided task is not in some other
        // linked list.
        unsafe { self.inner.lock().list.remove(task.header().into()) }
    }

    /// Returns `true` if no tasks are currently stored in the collection.
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.lock().list.is_empty()
    }
}
173 
impl<S: 'static> LocalOwnedTasks<S> {
    /// Creates an empty, open collection with a globally unique owner id.
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

    /// Binds the provided task to this LocalOwnedTasks instance. Unlike
    /// `OwnedTasks::bind`, the future is not required to be `Send`.
    ///
    /// Returns the task's `JoinHandle` and, if the collection is still open,
    /// the `Notified` used to schedule it. If the collection has been closed,
    /// the task is shut down immediately and `None` is returned instead.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        if self.is_closed() {
            drop(notified);
            // Note: shutdown happens outside of `with_inner`, keeping the
            // no-recursive-`with_inner` invariant documented below.
            task.shutdown();
            (join, None)
        } else {
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);

        // Each task is popped inside `with_inner` but shut down outside of
        // it, so `shutdown()` never runs within a `with_inner` call.
        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    /// Removes the given task from the collection, returning ownership of it
    /// if it was present.
    ///
    /// Returns `None` for unowned tasks (owner id zero); panics if the task
    /// belongs to a different collection.
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        let task_id = task.header().get_owner_id();
        if task_id == 0 {
            // The task is unowned.
            return None;
        }

        assert_eq!(task_id, self.id);

        self.with_inner(|inner|
            // safety: We just checked that the provided task is not in some
            // other linked list.
            unsafe { inner.list.remove(task.header().into()) })
    }

    /// Asserts that the given task is owned by this LocalOwnedTasks and convert
    /// it to a LocalNotified, giving the thread permission to poll this task.
    ///
    /// Panics if the task's owner id does not match this collection.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), self.id);

        // safety: The task was bound to this LocalOwnedTasks, and the
        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
        // for polling this task.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Runs `f` with mutable access to the inner state.
    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // safety: This type is not Sync, so concurrent calls of this method
        // can't happen.  Furthermore, all uses of this method in this file make
        // sure that they don't call `with_inner` recursively.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    /// Returns `true` once the collection has been closed to new tasks.
    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    /// Returns `true` if no tasks are currently stored in the collection.
    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}
278 
// `#[cfg(all(test))]` was a redundant single-predicate `all(...)`; the plain
// `test` predicate is equivalent (clippy: non_minimal_cfg).
#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests, so we only test that ids
    // come in increasing order.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();
        assert_ne!(last_id, 0);

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert_ne!(next_id, 0);
            // fetch_add is monotonic, so each id we draw exceeds our last one
            // even if other tests allocate ids concurrently.
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
}
298