1 //! Waking mechanism for threads blocked on channel operations.
2 
3 use std::sync::atomic::{AtomicBool, Ordering};
4 use std::thread::{self, ThreadId};
5 
6 use crate::context::Context;
7 use crate::select::{Operation, Selected};
8 use crate::utils::Spinlock;
9 
10 /// Represents a thread blocked on a specific channel operation.
11 pub(crate) struct Entry {
12     /// The operation.
13     pub(crate) oper: Operation,
14 
15     /// Optional packet.
16     pub(crate) packet: usize,
17 
18     /// Context associated with the thread owning this operation.
19     pub(crate) cx: Context,
20 }
21 
22 /// A queue of threads blocked on channel operations.
23 ///
24 /// This data structure is used by threads to register blocking operations and get woken up once
25 /// an operation becomes ready.
26 pub(crate) struct Waker {
27     /// A list of select operations.
28     selectors: Vec<Entry>,
29 
30     /// A list of operations waiting to be ready.
31     observers: Vec<Entry>,
32 }
33 
34 impl Waker {
35     /// Creates a new `Waker`.
36     #[inline]
new() -> Self37     pub(crate) fn new() -> Self {
38         Waker {
39             selectors: Vec::new(),
40             observers: Vec::new(),
41         }
42     }
43 
44     /// Registers a select operation.
45     #[inline]
register(&mut self, oper: Operation, cx: &Context)46     pub(crate) fn register(&mut self, oper: Operation, cx: &Context) {
47         self.register_with_packet(oper, 0, cx);
48     }
49 
50     /// Registers a select operation and a packet.
51     #[inline]
register_with_packet(&mut self, oper: Operation, packet: usize, cx: &Context)52     pub(crate) fn register_with_packet(&mut self, oper: Operation, packet: usize, cx: &Context) {
53         self.selectors.push(Entry {
54             oper,
55             packet,
56             cx: cx.clone(),
57         });
58     }
59 
60     /// Unregisters a select operation.
61     #[inline]
unregister(&mut self, oper: Operation) -> Option<Entry>62     pub(crate) fn unregister(&mut self, oper: Operation) -> Option<Entry> {
63         if let Some((i, _)) = self
64             .selectors
65             .iter()
66             .enumerate()
67             .find(|&(_, entry)| entry.oper == oper)
68         {
69             let entry = self.selectors.remove(i);
70             Some(entry)
71         } else {
72             None
73         }
74     }
75 
76     /// Attempts to find another thread's entry, select the operation, and wake it up.
77     #[inline]
try_select(&mut self) -> Option<Entry>78     pub(crate) fn try_select(&mut self) -> Option<Entry> {
79         let mut entry = None;
80 
81         if !self.selectors.is_empty() {
82             let thread_id = current_thread_id();
83 
84             for i in 0..self.selectors.len() {
85                 // Does the entry belong to a different thread?
86                 if self.selectors[i].cx.thread_id() != thread_id {
87                     // Try selecting this operation.
88                     let sel = Selected::Operation(self.selectors[i].oper);
89                     let res = self.selectors[i].cx.try_select(sel);
90 
91                     if res.is_ok() {
92                         // Provide the packet.
93                         self.selectors[i].cx.store_packet(self.selectors[i].packet);
94                         // Wake the thread up.
95                         self.selectors[i].cx.unpark();
96 
97                         // Remove the entry from the queue to keep it clean and improve
98                         // performance.
99                         entry = Some(self.selectors.remove(i));
100                         break;
101                     }
102                 }
103             }
104         }
105 
106         entry
107     }
108 
109     /// Returns `true` if there is an entry which can be selected by the current thread.
110     #[inline]
can_select(&self) -> bool111     pub(crate) fn can_select(&self) -> bool {
112         if self.selectors.is_empty() {
113             false
114         } else {
115             let thread_id = current_thread_id();
116 
117             self.selectors.iter().any(|entry| {
118                 entry.cx.thread_id() != thread_id && entry.cx.selected() == Selected::Waiting
119             })
120         }
121     }
122 
123     /// Registers an operation waiting to be ready.
124     #[inline]
watch(&mut self, oper: Operation, cx: &Context)125     pub(crate) fn watch(&mut self, oper: Operation, cx: &Context) {
126         self.observers.push(Entry {
127             oper,
128             packet: 0,
129             cx: cx.clone(),
130         });
131     }
132 
133     /// Unregisters an operation waiting to be ready.
134     #[inline]
unwatch(&mut self, oper: Operation)135     pub(crate) fn unwatch(&mut self, oper: Operation) {
136         self.observers.retain(|e| e.oper != oper);
137     }
138 
139     /// Notifies all operations waiting to be ready.
140     #[inline]
notify(&mut self)141     pub(crate) fn notify(&mut self) {
142         for entry in self.observers.drain(..) {
143             if entry.cx.try_select(Selected::Operation(entry.oper)).is_ok() {
144                 entry.cx.unpark();
145             }
146         }
147     }
148 
149     /// Notifies all registered operations that the channel is disconnected.
150     #[inline]
disconnect(&mut self)151     pub(crate) fn disconnect(&mut self) {
152         for entry in self.selectors.iter() {
153             if entry.cx.try_select(Selected::Disconnected).is_ok() {
154                 // Wake the thread up.
155                 //
156                 // Here we don't remove the entry from the queue. Registered threads must
157                 // unregister from the waker by themselves. They might also want to recover the
158                 // packet value and destroy it, if necessary.
159                 entry.cx.unpark();
160             }
161         }
162 
163         self.notify();
164     }
165 }
166 
167 impl Drop for Waker {
168     #[inline]
drop(&mut self)169     fn drop(&mut self) {
170         debug_assert_eq!(self.selectors.len(), 0);
171         debug_assert_eq!(self.observers.len(), 0);
172     }
173 }
174 
175 /// A waker that can be shared among threads without locking.
176 ///
177 /// This is a simple wrapper around `Waker` that internally uses a mutex for synchronization.
178 pub(crate) struct SyncWaker {
179     /// The inner `Waker`.
180     inner: Spinlock<Waker>,
181 
182     /// `true` if the waker is empty.
183     is_empty: AtomicBool,
184 }
185 
186 impl SyncWaker {
187     /// Creates a new `SyncWaker`.
188     #[inline]
new() -> Self189     pub(crate) fn new() -> Self {
190         SyncWaker {
191             inner: Spinlock::new(Waker::new()),
192             is_empty: AtomicBool::new(true),
193         }
194     }
195 
196     /// Registers the current thread with an operation.
197     #[inline]
register(&self, oper: Operation, cx: &Context)198     pub(crate) fn register(&self, oper: Operation, cx: &Context) {
199         let mut inner = self.inner.lock();
200         inner.register(oper, cx);
201         self.is_empty.store(
202             inner.selectors.is_empty() && inner.observers.is_empty(),
203             Ordering::SeqCst,
204         );
205     }
206 
207     /// Unregisters an operation previously registered by the current thread.
208     #[inline]
unregister(&self, oper: Operation) -> Option<Entry>209     pub(crate) fn unregister(&self, oper: Operation) -> Option<Entry> {
210         let mut inner = self.inner.lock();
211         let entry = inner.unregister(oper);
212         self.is_empty.store(
213             inner.selectors.is_empty() && inner.observers.is_empty(),
214             Ordering::SeqCst,
215         );
216         entry
217     }
218 
219     /// Attempts to find one thread (not the current one), select its operation, and wake it up.
220     #[inline]
notify(&self)221     pub(crate) fn notify(&self) {
222         if !self.is_empty.load(Ordering::SeqCst) {
223             let mut inner = self.inner.lock();
224             if !self.is_empty.load(Ordering::SeqCst) {
225                 inner.try_select();
226                 inner.notify();
227                 self.is_empty.store(
228                     inner.selectors.is_empty() && inner.observers.is_empty(),
229                     Ordering::SeqCst,
230                 );
231             }
232         }
233     }
234 
235     /// Registers an operation waiting to be ready.
236     #[inline]
watch(&self, oper: Operation, cx: &Context)237     pub(crate) fn watch(&self, oper: Operation, cx: &Context) {
238         let mut inner = self.inner.lock();
239         inner.watch(oper, cx);
240         self.is_empty.store(
241             inner.selectors.is_empty() && inner.observers.is_empty(),
242             Ordering::SeqCst,
243         );
244     }
245 
246     /// Unregisters an operation waiting to be ready.
247     #[inline]
unwatch(&self, oper: Operation)248     pub(crate) fn unwatch(&self, oper: Operation) {
249         let mut inner = self.inner.lock();
250         inner.unwatch(oper);
251         self.is_empty.store(
252             inner.selectors.is_empty() && inner.observers.is_empty(),
253             Ordering::SeqCst,
254         );
255     }
256 
257     /// Notifies all threads that the channel is disconnected.
258     #[inline]
disconnect(&self)259     pub(crate) fn disconnect(&self) {
260         let mut inner = self.inner.lock();
261         inner.disconnect();
262         self.is_empty.store(
263             inner.selectors.is_empty() && inner.observers.is_empty(),
264             Ordering::SeqCst,
265         );
266     }
267 }
268 
269 impl Drop for SyncWaker {
270     #[inline]
drop(&mut self)271     fn drop(&mut self) {
272         debug_assert_eq!(self.is_empty.load(Ordering::SeqCst), true);
273     }
274 }
275 
/// Returns the `ThreadId` of the calling thread.
///
/// The id is cached in a thread-local, so repeated calls do not go through
/// `thread::current()`. If the thread-local cannot be accessed (for example while
/// thread-local destructors are running), the id is queried from `thread::current()`
/// directly instead.
#[inline]
fn current_thread_id() -> ThreadId {
    thread_local! {
        /// Per-thread cache of this thread's id.
        static THREAD_ID: ThreadId = thread::current().id();
    }

    match THREAD_ID.try_with(|cached| *cached) {
        Ok(id) => id,
        Err(_) => thread::current().id(),
    }
}
288