1 //! # Implementation Details
2 //!
3 //! The semaphore is implemented using an intrusive linked list of waiters. An
4 //! atomic counter tracks the number of available permits. If the semaphore does
5 //! not contain the required number of permits, the task attempting to acquire
//! permits places its waker at the end of a queue. When new permits are made
//! available (for example, when previously acquired permits are released),
//! they are assigned to the task at the front of the queue, waking that task
//! once its requested number of permits has been met.
10 //!
11 //! Because waiters are enqueued at the back of the linked list and dequeued
12 //! from the front, the semaphore is fair. Tasks trying to acquire large numbers
13 //! of permits at a time will always be woken eventually, even if many other
14 //! tasks are acquiring smaller numbers of permits. This means that in a
15 //! use-case like tokio's read-write lock, writers will not be starved by
16 //! readers.
17 use crate::loom::cell::UnsafeCell;
18 use crate::loom::sync::atomic::AtomicUsize;
19 use crate::loom::sync::{Mutex, MutexGuard};
20 use crate::util::linked_list::{self, LinkedList};
21 
22 use std::future::Future;
23 use std::marker::PhantomPinned;
24 use std::pin::Pin;
25 use std::ptr::NonNull;
26 use std::sync::atomic::Ordering::*;
27 use std::task::Poll::*;
28 use std::task::{Context, Poll, Waker};
29 use std::{cmp, fmt};
30 
/// An asynchronous counting semaphore which permits waiting on multiple permits at once.
pub(crate) struct Semaphore {
    /// Queue of tasks waiting for permits, together with the `closed` flag.
    /// Guarded by a mutex; the queue and waker fields of queued nodes may only
    /// be touched while this lock is held.
    waiters: Mutex<Waitlist>,
    /// The current number of available permits in the semaphore.
    ///
    /// The count is stored shifted left by `Semaphore::PERMIT_SHIFT`; the low
    /// bit is the `Semaphore::CLOSED` flag.
    permits: AtomicUsize,
}
37 
/// State protected by `Semaphore::waiters`.
struct Waitlist {
    /// Waiters that are still owed permits. New waiters are pushed onto the
    /// front and permits are handed out from the back, so the back holds the
    /// oldest waiter.
    queue: LinkedList<Waiter>,
    /// Set by `Semaphore::close`. Once true, `poll_acquire` reports
    /// `AcquireError` instead of enqueuing.
    closed: bool,
}
42 
/// Error returned by `Semaphore::try_acquire`.
#[derive(Debug)]
pub(crate) enum TryAcquireError {
    /// The semaphore has been closed.
    Closed,
    /// The semaphore did not hold enough permits to satisfy the request.
    NoPermits,
}
/// Error returned by `Semaphore::acquire`.
///
/// The only way to construct it in this module is `AcquireError::closed`,
/// i.e. acquisition failed because the semaphore was closed.
#[derive(Debug)]
pub(crate) struct AcquireError(());
52 
/// Future returned by `Semaphore::acquire`.
///
/// Resolves once `num_permits` permits have been obtained, or with
/// `AcquireError` if the semaphore is closed.
pub(crate) struct Acquire<'a> {
    /// Intrusive wait-list entry. It is linked into the semaphore's queue by
    /// raw pointer, so it must not move while linked; `Acquire` is therefore
    /// only polled through `Pin`.
    node: Waiter,
    /// The semaphore that permits are requested from.
    semaphore: &'a Semaphore,
    /// Total number of permits requested. `Drop` uses it to work out how many
    /// permits were already assigned and must be handed back.
    num_permits: u16,
    /// Set when `poll` returns `Pending` (after `node` was pushed onto the
    /// wait list) and cleared when the future completes successfully. The
    /// semaphore may already have popped `node` while this is still `true`.
    queued: bool,
}
59 
/// An entry in the wait queue.
struct Waiter {
    /// The number of permits this waiter still needs.
    ///
    /// Initialized to the requested count by `Waiter::new` and decremented by
    /// `Waiter::assign_permits`; zero means the request is fully satisfied.
    /// Whether the waiter is queued is tracked separately, by
    /// `Acquire::queued`.
    state: AtomicUsize,

    /// The waker to notify the task awaiting permits.
    ///
    /// Registered by `Semaphore::poll_acquire`, and taken by `Semaphore::close`
    /// and `Semaphore::add_permits_locked` when the waiter is popped.
    ///
    /// # Safety
    ///
    /// This may only be accessed while the wait queue is locked.
    waker: UnsafeCell<Option<Waker>>,

    /// Intrusive linked-list pointers.
    ///
    /// # Safety
    ///
    /// This may only be accessed while the wait queue is locked.
    ///
    /// TODO: Ideally, we would be able to use loom to enforce that
    /// this isn't accessed concurrently. However, it is difficult to
    /// use a `UnsafeCell` here, since the `Link` trait requires _returning_
    /// references to `Pointers`, and `UnsafeCell` requires that checked access
    /// take place inside a closure. We should consider changing `Pointers` to
    /// use `UnsafeCell` internally.
    pointers: linked_list::Pointers<Waiter>,

    /// Should not be `Unpin`: the node is referenced by raw pointer from the
    /// wait list and must stay put while linked.
    _p: PhantomPinned,
}
92 
93 impl Semaphore {
94     /// The maximum number of permits which a semaphore can hold.
95     ///
96     /// Note that this reserves three bits of flags in the permit counter, but
97     /// we only actually use one of them. However, the previous semaphore
98     /// implementation used three bits, so we will continue to reserve them to
99     /// avoid a breaking change if additional flags need to be aadded in the
100     /// future.
101     pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3;
102     const CLOSED: usize = 1;
103     const PERMIT_SHIFT: usize = 1;
104 
105     /// Creates a new semaphore with the initial number of permits
new(permits: usize) -> Self106     pub(crate) fn new(permits: usize) -> Self {
107         assert!(
108             permits <= Self::MAX_PERMITS,
109             "a semaphore may not have more than MAX_PERMITS permits ({})",
110             Self::MAX_PERMITS
111         );
112         Self {
113             permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
114             waiters: Mutex::new(Waitlist {
115                 queue: LinkedList::new(),
116                 closed: false,
117             }),
118         }
119     }
120 
121     /// Returns the current number of available permits
available_permits(&self) -> usize122     pub(crate) fn available_permits(&self) -> usize {
123         self.permits.load(Acquire) >> Self::PERMIT_SHIFT
124     }
125 
126     /// Adds `n` new permits to the semaphore.
release(&self, added: usize)127     pub(crate) fn release(&self, added: usize) {
128         if added == 0 {
129             return;
130         }
131 
132         // Assign permits to the wait queue
133         self.add_permits_locked(added, self.waiters.lock().unwrap());
134     }
135 
136     /// Closes the semaphore. This prevents the semaphore from issuing new
137     /// permits and notifies all pending waiters.
138     // This will be used once the bounded MPSC is updated to use the new
139     // semaphore implementation.
140     #[allow(dead_code)]
close(&self)141     pub(crate) fn close(&self) {
142         let mut waiters = self.waiters.lock().unwrap();
143         // If the semaphore's permits counter has enough permits for an
144         // unqueued waiter to acquire all the permits it needs immediately,
145         // it won't touch the wait list. Therefore, we have to set a bit on
146         // the permit counter as well. However, we must do this while
147         // holding the lock --- otherwise, if we set the bit and then wait
148         // to acquire the lock we'll enter an inconsistent state where the
149         // permit counter is closed, but the wait list is not.
150         self.permits.fetch_or(Self::CLOSED, Release);
151         waiters.closed = true;
152         while let Some(mut waiter) = waiters.queue.pop_back() {
153             let waker = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) };
154             if let Some(waker) = waker {
155                 waker.wake();
156             }
157         }
158     }
159 
try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError>160     pub(crate) fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
161         let mut curr = self.permits.load(Acquire);
162         let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
163         loop {
164             // Has the semaphore closed?git
165             if curr & Self::CLOSED > 0 {
166                 return Err(TryAcquireError::Closed);
167             }
168 
169             // Are there enough permits remaining?
170             if curr < num_permits {
171                 return Err(TryAcquireError::NoPermits);
172             }
173 
174             let next = curr - num_permits;
175 
176             match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
177                 Ok(_) => return Ok(()),
178                 Err(actual) => curr = actual,
179             }
180         }
181     }
182 
acquire(&self, num_permits: u16) -> Acquire<'_>183     pub(crate) fn acquire(&self, num_permits: u16) -> Acquire<'_> {
184         Acquire::new(self, num_permits)
185     }
186 
187     /// Release `rem` permits to the semaphore's wait list, starting from the
188     /// end of the queue.
189     ///
190     /// If `rem` exceeds the number of permits needed by the wait list, the
191     /// remainder are assigned back to the semaphore.
add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>)192     fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
193         let mut wakers: [Option<Waker>; 8] = Default::default();
194         let mut lock = Some(waiters);
195         let mut is_empty = false;
196         while rem > 0 {
197             let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
198             'inner: for slot in &mut wakers[..] {
199                 // Was the waiter assigned enough permits to wake it?
200                 match waiters.queue.last() {
201                     Some(waiter) => {
202                         if !waiter.assign_permits(&mut rem) {
203                             break 'inner;
204                         }
205                     }
206                     None => {
207                         is_empty = true;
208                         // If we assigned permits to all the waiters in the queue, and there are
209                         // still permits left over, assign them back to the semaphore.
210                         break 'inner;
211                     }
212                 };
213                 let mut waiter = waiters.queue.pop_back().unwrap();
214                 *slot = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) };
215             }
216 
217             if rem > 0 && is_empty {
218                 let permits = rem << Self::PERMIT_SHIFT;
219                 assert!(
220                     permits < Self::MAX_PERMITS,
221                     "cannot add more than MAX_PERMITS permits ({})",
222                     Self::MAX_PERMITS
223                 );
224                 let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release);
225                 assert!(
226                     prev + permits <= Self::MAX_PERMITS,
227                     "number of added permits ({}) would overflow MAX_PERMITS ({})",
228                     rem,
229                     Self::MAX_PERMITS
230                 );
231                 rem = 0;
232             }
233 
234             drop(waiters); // release the lock
235 
236             wakers
237                 .iter_mut()
238                 .filter_map(Option::take)
239                 .for_each(Waker::wake);
240         }
241 
242         assert_eq!(rem, 0);
243     }
244 
poll_acquire( &self, cx: &mut Context<'_>, num_permits: u16, node: Pin<&mut Waiter>, queued: bool, ) -> Poll<Result<(), AcquireError>>245     fn poll_acquire(
246         &self,
247         cx: &mut Context<'_>,
248         num_permits: u16,
249         node: Pin<&mut Waiter>,
250         queued: bool,
251     ) -> Poll<Result<(), AcquireError>> {
252         let mut acquired = 0;
253 
254         let needed = if queued {
255             node.state.load(Acquire) << Self::PERMIT_SHIFT
256         } else {
257             (num_permits as usize) << Self::PERMIT_SHIFT
258         };
259 
260         let mut lock = None;
261         // First, try to take the requested number of permits from the
262         // semaphore.
263         let mut curr = self.permits.load(Acquire);
264         let mut waiters = loop {
265             // Has the semaphore closed?
266             if curr & Self::CLOSED > 0 {
267                 return Ready(Err(AcquireError::closed()));
268             }
269 
270             let mut remaining = 0;
271             let total = curr
272                 .checked_add(acquired)
273                 .expect("number of permits must not overflow");
274             let (next, acq) = if total >= needed {
275                 let next = curr - (needed - acquired);
276                 (next, needed >> Self::PERMIT_SHIFT)
277             } else {
278                 remaining = (needed - acquired) - curr;
279                 (0, curr >> Self::PERMIT_SHIFT)
280             };
281 
282             if remaining > 0 && lock.is_none() {
283                 // No permits were immediately available, so this permit will
284                 // (probably) need to wait. We'll need to acquire a lock on the
285                 // wait queue before continuing. We need to do this _before_ the
286                 // CAS that sets the new value of the semaphore's `permits`
287                 // counter. Otherwise, if we subtract the permits and then
288                 // acquire the lock, we might miss additional permits being
289                 // added while waiting for the lock.
290                 lock = Some(self.waiters.lock().unwrap());
291             }
292 
293             match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
294                 Ok(_) => {
295                     acquired += acq;
296                     if remaining == 0 {
297                         if !queued {
298                             return Ready(Ok(()));
299                         } else if lock.is_none() {
300                             break self.waiters.lock().unwrap();
301                         }
302                     }
303                     break lock.expect("lock must be acquired before waiting");
304                 }
305                 Err(actual) => curr = actual,
306             }
307         };
308 
309         if waiters.closed {
310             return Ready(Err(AcquireError::closed()));
311         }
312 
313         if node.assign_permits(&mut acquired) {
314             self.add_permits_locked(acquired, waiters);
315             return Ready(Ok(()));
316         }
317 
318         assert_eq!(acquired, 0);
319 
320         // Otherwise, register the waker & enqueue the node.
321         node.waker.with_mut(|waker| {
322             // Safety: the wait list is locked, so we may modify the waker.
323             let waker = unsafe { &mut *waker };
324             // Do we need to register the new waker?
325             if waker
326                 .as_ref()
327                 .map(|waker| !waker.will_wake(cx.waker()))
328                 .unwrap_or(true)
329             {
330                 *waker = Some(cx.waker().clone());
331             }
332         });
333 
334         // If the waiter is not already in the wait queue, enqueue it.
335         if !queued {
336             let node = unsafe {
337                 let node = Pin::into_inner_unchecked(node) as *mut _;
338                 NonNull::new_unchecked(node)
339             };
340 
341             waiters.queue.push_front(node);
342         }
343 
344         Pending
345     }
346 }
347 
348 impl fmt::Debug for Semaphore {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result349     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
350         fmt.debug_struct("Semaphore")
351             .field("permits", &self.permits.load(Relaxed))
352             .finish()
353     }
354 }
355 
356 impl Waiter {
new(num_permits: u16) -> Self357     fn new(num_permits: u16) -> Self {
358         Waiter {
359             waker: UnsafeCell::new(None),
360             state: AtomicUsize::new(num_permits as usize),
361             pointers: linked_list::Pointers::new(),
362             _p: PhantomPinned,
363         }
364     }
365 
366     /// Assign permits to the waiter.
367     ///
368     /// Returns `true` if the waiter should be removed from the queue
assign_permits(&self, n: &mut usize) -> bool369     fn assign_permits(&self, n: &mut usize) -> bool {
370         let mut curr = self.state.load(Acquire);
371         loop {
372             let assign = cmp::min(curr, *n);
373             let next = curr - assign;
374             match self.state.compare_exchange(curr, next, AcqRel, Acquire) {
375                 Ok(_) => {
376                     *n -= assign;
377                     return next == 0;
378                 }
379                 Err(actual) => curr = actual,
380             }
381         }
382     }
383 }
384 
385 impl Future for Acquire<'_> {
386     type Output = Result<(), AcquireError>;
387 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>388     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
389         let (node, semaphore, needed, queued) = self.project();
390         match semaphore.poll_acquire(cx, needed, node, *queued) {
391             Pending => {
392                 *queued = true;
393                 Pending
394             }
395             Ready(r) => {
396                 r?;
397                 *queued = false;
398                 Ready(Ok(()))
399             }
400         }
401     }
402 }
403 
404 impl<'a> Acquire<'a> {
new(semaphore: &'a Semaphore, num_permits: u16) -> Self405     fn new(semaphore: &'a Semaphore, num_permits: u16) -> Self {
406         Self {
407             node: Waiter::new(num_permits),
408             semaphore,
409             num_permits,
410             queued: false,
411         }
412     }
413 
project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u16, &mut bool)414     fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u16, &mut bool) {
415         fn is_unpin<T: Unpin>() {}
416         unsafe {
417             // Safety: all fields other than `node` are `Unpin`
418 
419             is_unpin::<&Semaphore>();
420             is_unpin::<&mut bool>();
421             is_unpin::<u16>();
422 
423             let this = self.get_unchecked_mut();
424             (
425                 Pin::new_unchecked(&mut this.node),
426                 &this.semaphore,
427                 this.num_permits,
428                 &mut this.queued,
429             )
430         }
431     }
432 }
433 
impl Drop for Acquire<'_> {
    /// Unlinks this future's node from the wait list (if it was ever queued)
    /// and returns any permits that had already been assigned to it.
    fn drop(&mut self) {
        // If the future is completed, there is no node in the wait list, so we
        // can skip acquiring the lock.
        if !self.queued {
            return;
        }

        // This is where we ensure safety. The future is being dropped,
        // which means we must ensure that the waiter entry is no longer stored
        // in the linked list.
        let mut waiters = match self.semaphore.waiters.lock() {
            Ok(lock) => lock,
            // Removing the node from the linked list is necessary to ensure
            // safety. Even if the lock was poisoned, we need to make sure it is
            // removed from the linked list before dropping it --- otherwise,
            // the list will contain a dangling pointer to this node.
            Err(e) => e.into_inner(),
        };

        // remove the entry from the list
        // NOTE(review): the semaphore may already have popped this node
        // (`add_permits_locked` or `close`) while `queued` is still true. This
        // relies on `LinkedList::remove` tolerating a node that is no longer
        // linked; confirm in `util::linked_list`.
        let node = NonNull::from(&mut self.node);
        // Safety: we have locked the wait list.
        unsafe { waiters.queue.remove(node) };

        // `state` holds the permits still outstanding, so the difference is what
        // the semaphore already assigned to this waiter. Hand those back so
        // they can go to other waiters or back to the counter.
        let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
        if acquired_permits > 0 {
            // NOTE(review): if the lock was poisoned above, `add_permits_locked`
            // re-locks with `.unwrap()` when it needs more than one wake batch,
            // which would panic inside `drop`.
            self.semaphore.add_permits_locked(acquired_permits, waiters);
        }
    }
}
465 
// Safety: the `Acquire` future is not `Sync` automatically because it contains
// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the
// `UnsafeCell` is only accessed through the future when it is borrowed mutably
// (either in `poll` or in `drop`). Other threads can also reach the node's
// waker through the wait list (`Semaphore::close`,
// `Semaphore::add_permits_locked`), but only while holding the `waiters` lock.
// Therefore, it is safe (although not particularly _useful_) for the future to
// be borrowed immutably across threads.
unsafe impl Sync for Acquire<'_> {}
472 
473 // ===== impl AcquireError ====
474 
475 impl AcquireError {
closed() -> AcquireError476     fn closed() -> AcquireError {
477         AcquireError(())
478     }
479 }
480 
481 impl fmt::Display for AcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result482     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
483         write!(fmt, "semaphore closed")
484     }
485 }
486 
487 impl std::error::Error for AcquireError {}
488 
489 // ===== impl TryAcquireError =====
490 
491 impl TryAcquireError {
492     /// Returns `true` if the error was caused by a closed semaphore.
493     #[allow(dead_code)] // may be used later!
is_closed(&self) -> bool494     pub(crate) fn is_closed(&self) -> bool {
495         match self {
496             TryAcquireError::Closed => true,
497             _ => false,
498         }
499     }
500 
501     /// Returns `true` if the error was caused by calling `try_acquire` on a
502     /// semaphore with no available permits.
503     #[allow(dead_code)] // may be used later!
is_no_permits(&self) -> bool504     pub(crate) fn is_no_permits(&self) -> bool {
505         match self {
506             TryAcquireError::NoPermits => true,
507             _ => false,
508         }
509     }
510 }
511 
512 impl fmt::Display for TryAcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result513     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
514         match self {
515             TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"),
516             TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"),
517         }
518     }
519 }
520 
521 impl std::error::Error for TryAcquireError {}
522 
/// # Safety
///
/// `Waiter` is forced to be !Unpin.
unsafe impl linked_list::Link for Waiter {
    // XXX: ideally, we would be able to use `Pin` here, to enforce the
    // invariant that list entries may not move while in the list. However, we
    // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle`
    // type would require `Semaphore` to be generic over a lifetime. We can't
    // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether
    // or not they dereference to an `!Unpin` target.
    type Handle = NonNull<Waiter>;
    type Target = Waiter;

    /// The handle already is the raw node pointer, so this is the identity.
    fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> {
        *handle
    }

    /// Inverse of `as_raw`; also the identity, since `Handle` is `NonNull<Waiter>`.
    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
        ptr
    }

    /// Returns a pointer to the intrusive `pointers` field embedded in `target`.
    /// Callers must only dereference it while the wait list is locked.
    unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
        NonNull::from(&mut target.as_mut().pointers)
    }
}
548