1 #![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
2 
3 //! Thread-safe, asynchronous counting semaphore.
4 //!
5 //! A `Semaphore` instance holds a set of permits. Permits are used to
6 //! synchronize access to a shared resource.
7 //!
8 //! Before accessing the shared resource, callers acquire a permit from the
9 //! semaphore. Once the permit is acquired, the caller then enters the critical
10 //! section. If no permits are available, then acquiring the semaphore returns
11 //! `Pending`. The task is woken once a permit becomes available.
12 
13 use crate::loom::cell::UnsafeCell;
14 use crate::loom::future::AtomicWaker;
15 use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
16 use crate::loom::thread;
17 
18 use std::cmp;
19 use std::fmt;
20 use std::ptr::{self, NonNull};
21 use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
22 use std::task::Poll::{Pending, Ready};
23 use std::task::{Context, Poll};
24 use std::usize;
25 
/// Futures-aware counting semaphore.
///
/// Tasks that cannot immediately obtain permits are parked as `Waiter` nodes
/// in an intrusive queue owned by the semaphore.
pub(crate) struct Semaphore {
    /// Tracks both the waiter queue tail pointer and the number of remaining
    /// permits. See `SemState` for how the two are packed into one word.
    state: AtomicUsize,

    /// Waiter queue head pointer.
    head: UnsafeCell<NonNull<Waiter>>,

    /// Coordinates access to the queue head.
    ///
    /// `close` sets the lowest bit and `add_permits` adds `n << 1`. Whichever
    /// call observes a previous value of zero processes the queue; other
    /// callers return immediately and leave their request recorded here.
    rx_lock: AtomicUsize,

    /// Stub waiter node used as part of the MPSC channel algorithm.
    stub: Box<Waiter>,
}
41 
/// A semaphore permit
///
/// Tracks the lifecycle of a semaphore permit.
///
/// An instance of `Permit` is intended to be used with a **single** instance of
/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
/// instances will result in unexpected behavior.
///
/// `Permit` does **not** release the permit back to the semaphore on drop. It
/// is the user's responsibility to ensure that `Permit::release` is called
/// before dropping the permit.
#[derive(Debug)]
pub(crate) struct Permit {
    /// Wait-queue node for this permit. It is allocated lazily the first time
    /// the permit has to wait, so it stays `None` until then.
    waiter: Option<Box<Waiter>>,

    /// Whether the permit is waiting or holds permits, and how many.
    state: PermitState,
}
58 
/// Error returned by `Permit::poll_acquire`.
///
/// The only way to construct it in this file is `AcquireError::closed`, which
/// is used when the semaphore has been closed.
#[derive(Debug)]
pub(crate) struct AcquireError(());
62 
/// Error returned by `Permit::try_acquire`.
#[derive(Debug)]
pub(crate) enum TryAcquireError {
    /// The semaphore has been closed.
    Closed,
    /// Not enough permits were available to satisfy the request immediately.
    NoPermits,
}
69 
/// Node used to notify the semaphore waiter when permit is available.
///
/// Nodes are linked into the semaphore's intrusive wait queue through `next`.
#[derive(Debug)]
struct Waiter {
    /// Stores waiter state.
    ///
    /// See `WaiterState` for more details.
    state: AtomicUsize,

    /// Task to wake when a permit is made available.
    waker: AtomicWaker,

    /// Next pointer in the queue of waiters.
    next: AtomicPtr<Waiter>,
}
84 
/// Semaphore state
///
/// The 2 low bits are flags:
///
/// - `NUM_FLAG`: the value holds a permit count rather than a waiter pointer.
/// - `CLOSED_FLAG`: the semaphore has been closed.
///
/// When `NUM_FLAG` is set, the remaining bits (shifted by `NUM_SHIFT`) hold
/// the number of available permits. Otherwise the value is a pointer to the
/// tail of the waiter queue.
#[derive(Copy, Clone)]
struct SemState(usize);
97 
/// Permit state
#[derive(Debug, Copy, Clone)]
enum PermitState {
    /// Currently waiting for permits to be made available and assigned to the
    /// waiter. The payload is the total number of permits requested.
    Waiting(u16),

    /// The number of acquired permits. `Acquired(0)` is the state of a freshly
    /// created permit.
    Acquired(u16),
}
108 
/// State for an individual waiter node
///
/// A snapshot of `Waiter::state`. The low bits are flags (`QUEUED`, `DROPPED`;
/// presumably also `CLOSED` — confirm against the `WaiterState` impl) and the
/// bits from `PERMIT_SHIFT` upward count the permits still to be acquired.
#[derive(Debug, Copy, Clone)]
struct WaiterState(usize);
112 
// ----- Waiter state bits -----

/// Set while the waiter node is linked into the semaphore queue.
const QUEUED: usize = 1 << 0;

/// Semaphore has been closed, no more permits will be issued.
const CLOSED: usize = 1 << 1;

/// Set once the `Permit` owning the `Waiter` has been dropped.
const DROPPED: usize = 1 << 2;

/// A single requested permit as encoded in the waiter state.
const PERMIT_ONE: usize = 1 << 3;

/// Selects only the bits of the waiter state that count requested permits,
/// i.e. everything at or above `PERMIT_ONE`.
const PERMIT_MASK: usize = !(PERMIT_ONE - 1);

/// Left shift that packs a permit count into the waiter state.
const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();

// ----- Semaphore state bits -----

/// Flag differentiating between available permits and waiter pointers.
///
/// A properly aligned pointer always has a zero least significant bit, so a
/// set bit marks the value as a number.
const NUM_FLAG: usize = 1 << 0;

/// Signals that the semaphore is closed.
const CLOSED_FLAG: usize = 1 << 1;

/// When representing "numbers", the state is shifted by this much to make
/// room for the flag bits.
const NUM_SHIFT: usize = 2;

/// Maximum number of permits a semaphore can manage.
const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
148 
149 // ===== impl Semaphore =====
150 
151 impl Semaphore {
152     /// Creates a new semaphore with the initial number of permits
153     ///
154     /// # Panics
155     ///
156     /// Panics if `permits` is zero.
new(permits: usize) -> Semaphore157     pub(crate) fn new(permits: usize) -> Semaphore {
158         let stub = Box::new(Waiter::new());
159         let ptr = NonNull::from(&*stub);
160 
161         // Allocations are aligned
162         debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);
163 
164         let state = SemState::new(permits, &stub);
165 
166         Semaphore {
167             state: AtomicUsize::new(state.to_usize()),
168             head: UnsafeCell::new(ptr),
169             rx_lock: AtomicUsize::new(0),
170             stub,
171         }
172     }
173 
174     /// Returns the current number of available permits
available_permits(&self) -> usize175     pub(crate) fn available_permits(&self) -> usize {
176         let curr = SemState(self.state.load(Acquire));
177         curr.available_permits()
178     }
179 
180     /// Tries to acquire the requested number of permits, registering the waiter
181     /// if not enough permits are available.
poll_acquire( &self, cx: &mut Context<'_>, num_permits: u16, permit: &mut Permit, ) -> Poll<Result<(), AcquireError>>182     fn poll_acquire(
183         &self,
184         cx: &mut Context<'_>,
185         num_permits: u16,
186         permit: &mut Permit,
187     ) -> Poll<Result<(), AcquireError>> {
188         self.poll_acquire2(num_permits, || {
189             let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));
190 
191             waiter.waker.register_by_ref(cx.waker());
192 
193             Some(NonNull::from(&**waiter))
194         })
195     }
196 
try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError>197     fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
198         match self.poll_acquire2(num_permits, || None) {
199             Poll::Ready(res) => res.map_err(to_try_acquire),
200             Poll::Pending => Err(TryAcquireError::NoPermits),
201         }
202     }
203 
/// Polls for a permit
///
/// Tries to acquire available permits first. If unable to acquire a
/// sufficient number of permits, the caller's waiter is pushed onto the
/// semaphore's wait queue.
///
/// `get_waiter` is called only when the request cannot be satisfied
/// immediately. Returning `None` means the caller does not want to wait, in
/// which case `Pending` is returned without queuing anything.
///
/// Returns `Ready(Ok(()))` when the permits were acquired, `Ready(Err(_))`
/// when the semaphore is closed, and `Pending` otherwise.
fn poll_acquire2<F>(
    &self,
    num_permits: u16,
    mut get_waiter: F,
) -> Poll<Result<(), AcquireError>>
where
    F: FnMut() -> Option<NonNull<Waiter>>,
{
    let num_permits = num_permits as usize;

    // Load the current state
    let mut curr = SemState(self.state.load(Acquire));

    // Saves a ref to the waiter node. Stays `None` until `get_waiter` has
    // produced a node and that node has been moved to the queued state.
    let mut maybe_waiter: Option<NonNull<Waiter>> = None;

    /// Used in branches where we attempt to push the waiter into the wait
    /// queue but fail due to permits becoming available or the wait queue
    /// transitioning to "closed". In this case, the waiter must be
    /// transitioned back to the "idle" state.
    macro_rules! revert_to_idle {
        () => {
            if let Some(waiter) = maybe_waiter {
                unsafe { waiter.as_ref() }.revert_to_idle();
            }
        };
    }

    // CAS loop: compute the desired next state from `curr` and retry with
    // the freshly observed value whenever another thread got there first.
    loop {
        let mut next = curr;

        if curr.is_closed() {
            revert_to_idle!();
            return Ready(Err(AcquireError::closed()));
        }

        let acquired = next.acquire_permits(num_permits, &self.stub);

        if !acquired {
            // There are not enough available permits to satisfy the
            // request. The permit transitions to a waiting state.
            debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);

            if let Some(waiter) = maybe_waiter.as_ref() {
                // A previous iteration already prepared the node; only the
                // outstanding permit count needs refreshing.
                //
                // Safety: the caller owns the waiter.
                let w = unsafe { waiter.as_ref() };
                w.set_permits_to_acquire(num_permits - curr.available_permits());
            } else {
                // Get the waiter for the permit.
                if let Some(waiter) = get_waiter() {
                    // Safety: the caller owns the waiter.
                    let w = unsafe { waiter.as_ref() };

                    // If there are any currently available permits, the
                    // waiter acquires those immediately and waits for the
                    // remaining permits to become available.
                    if !w.to_queued(num_permits - curr.available_permits()) {
                        // The node is already queued, there is no further work
                        // to do.
                        return Pending;
                    }

                    maybe_waiter = Some(waiter);
                } else {
                    // No waiter, this indicates the caller does not wish to
                    // "wait", so there is nothing left to do.
                    return Pending;
                }
            }

            // Both branches above leave `maybe_waiter` populated (or have
            // already returned), so the unwrap cannot fail here.
            next.set_waiter(maybe_waiter.unwrap());
        }

        debug_assert_ne!(curr.0, 0);
        debug_assert_ne!(next.0, 0);

        match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
            Ok(_) => {
                if acquired {
                    // Successfully acquire permits **without** queuing the
                    // waiter node. The waiter node is not currently in the
                    // queue.
                    revert_to_idle!();
                    return Ready(Ok(()));
                } else {
                    // The node is pushed into the queue, the final step is
                    // to set the node's "next" pointer to return the wait
                    // queue into a consistent state.

                    let prev_waiter =
                        curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));

                    let waiter = maybe_waiter.unwrap();

                    // Link the nodes.
                    //
                    // Safety: the mpsc algorithm guarantees the old tail of
                    // the queue is not removed from the queue during the
                    // push process.
                    unsafe {
                        prev_waiter.as_ref().store_next(waiter);
                    }

                    return Pending;
                }
            }
            Err(actual) => {
                // Lost the race; retry against the state that was observed.
                curr = SemState(actual);
            }
        }
    }
}
321 
322     /// Closes the semaphore. This prevents the semaphore from issuing new
323     /// permits and notifies all pending waiters.
close(&self)324     pub(crate) fn close(&self) {
325         // Acquire the `rx_lock`, setting the "closed" flag on the lock.
326         let prev = self.rx_lock.fetch_or(1, AcqRel);
327 
328         if prev != 0 {
329             // Another thread has the lock and will be responsible for notifying
330             // pending waiters.
331             return;
332         }
333 
334         self.add_permits_locked(0, true);
335     }
336 
337     /// Adds `n` new permits to the semaphore.
add_permits(&self, n: usize)338     pub(crate) fn add_permits(&self, n: usize) {
339         if n == 0 {
340             return;
341         }
342 
343         // TODO: Handle overflow. A panic is not sufficient, the process must
344         // abort.
345         let prev = self.rx_lock.fetch_add(n << 1, AcqRel);
346 
347         if prev != 0 {
348             // Another thread has the lock and will be responsible for notifying
349             // pending waiters.
350             return;
351         }
352 
353         self.add_permits_locked(n, false);
354     }
355 
add_permits_locked(&self, mut rem: usize, mut closed: bool)356     fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
357         while rem > 0 || closed {
358             if closed {
359                 SemState::fetch_set_closed(&self.state, AcqRel);
360             }
361 
362             // Release the permits and notify
363             self.add_permits_locked2(rem, closed);
364 
365             let n = rem << 1;
366 
367             let actual = if closed {
368                 let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
369                 closed = false;
370                 actual
371             } else {
372                 let actual = self.rx_lock.fetch_sub(n, AcqRel);
373                 closed = actual & 1 == 1;
374                 actual
375             };
376 
377             rem = (actual >> 1) - rem;
378         }
379     }
380 
/// Releases a specific amount of permits to the semaphore
///
/// This function is called by `add_permits` after the add lock has been
/// acquired.
///
/// Permits are assigned to queued waiters starting at the queue head. Any
/// permits left once the queue is empty are added back to `state`. When
/// `closed` is `true` the permit count is ignored and the whole queue is
/// drained.
fn add_permits_locked2(&self, mut n: usize, closed: bool) {
    // If closing the semaphore, we want to drain the entire queue. The
    // number of permits being assigned doesn't matter.
    if closed {
        n = usize::MAX;
    }

    'outer: while n > 0 {
        unsafe {
            let mut head = self.head.with(|head| *head);
            let mut next_ptr = head.as_ref().next.load(Acquire);

            let stub = self.stub();

            if head == stub {
                // The stub node indicates an empty queue. Any remaining
                // permits get assigned back to the semaphore.
                let next = match NonNull::new(next_ptr) {
                    Some(next) => next,
                    None => {
                        // This loop is not part of the standard intrusive mpsc
                        // channel algorithm. This is where we atomically
                        // observe that the queue is empty and add `n` to the
                        // available permits.
                        //
                        // This modification to the pop algorithm works because,
                        // at this point, we have not done any work (only done
                        // reading). We have a *pretty* good idea that there is
                        // no concurrent pusher.
                        //
                        // The permits are then atomically added by doing an
                        // AcqRel CAS on `state`. The `state` cell is the
                        // linchpin of the algorithm.
                        //
                        // By successfully CASing `state` w/ AcqRel, we ensure
                        // that, if any thread was racing and entered a push, we
                        // see that and abort pop, retrying as it is
                        // "inconsistent".
                        let mut curr = SemState::load(&self.state, Acquire);

                        loop {
                            if curr.has_waiter(&self.stub) {
                                // A waiter is being added concurrently.
                                // This is the MPSC queue's "inconsistent"
                                // state and we must loop and try again.
                                thread::yield_now();
                                continue 'outer;
                            }

                            // If closing, nothing more to do.
                            if closed {
                                debug_assert!(curr.is_closed(), "state = {:?}", curr);
                                return;
                            }

                            let mut next = curr;
                            next.release_permits(n, &self.stub);

                            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                                Ok(_) => return,
                                Err(actual) => {
                                    curr = SemState(actual);
                                }
                            }
                        }
                    }
                };

                // Skip over the stub: the node after it becomes the head.
                self.head.with_mut(|head| *head = next);
                head = next;
                next_ptr = next.as_ref().next.load(Acquire);
            }

            // `head` points to a waiter; assign permits to it. If all of
            // its requested permits are satisfied, then we can continue,
            // otherwise the node stays in the wait queue.
            if !head.as_ref().assign_permits(&mut n, closed) {
                assert_eq!(n, 0);
                return;
            }

            // The head is fully satisfied. If it already has a successor,
            // advance the head and release the node.
            if let Some(next) = NonNull::new(next_ptr) {
                self.head.with_mut(|head| *head = next);

                self.remove_queued(head, closed);
                continue 'outer;
            }

            let state = SemState::load(&self.state, Acquire);

            // This must always be a pointer as the wait list is not empty.
            let tail = state.waiter().unwrap();

            if tail != head {
                // Inconsistent
                thread::yield_now();
                continue 'outer;
            }

            // `head` is also the tail. Push the stub behind it so that the
            // node gains a successor and can be unlinked.
            self.push_stub(closed);

            next_ptr = head.as_ref().next.load(Acquire);

            if let Some(next) = NonNull::new(next_ptr) {
                self.head.with_mut(|head| *head = next);

                self.remove_queued(head, closed);
                continue 'outer;
            }

            // Inconsistent state, loop
            thread::yield_now();
        }
    }
}
499 
500     /// The wait node has had all of its permits assigned and has been removed
501     /// from the wait queue.
502     ///
503     /// Attempt to remove the QUEUED bit from the node. If additional permits
504     /// are concurrently requested, the node must be pushed back into the wait
505     /// queued.
remove_queued(&self, waiter: NonNull<Waiter>, closed: bool)506     fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
507         let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));
508 
509         loop {
510             if curr.is_dropped() {
511                 // The Permit dropped, it is on us to release the memory
512                 let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
513                 return;
514             }
515 
516             // The node is removed from the queue. We attempt to unset the
517             // queued bit, but concurrently the waiter has requested more
518             // permits. When the waiter requested more permits, it saw the
519             // queued bit set so took no further action. This requires us to
520             // push the node back into the queue.
521             if curr.permits_to_acquire() > 0 {
522                 // More permits are requested. The waiter must be re-queued
523                 unsafe {
524                     self.push_waiter(waiter, closed);
525                 }
526                 return;
527             }
528 
529             let mut next = curr;
530             next.unset_queued();
531 
532             let w = unsafe { waiter.as_ref() };
533 
534             match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
535                 Ok(_) => return,
536                 Err(actual) => {
537                     curr = WaiterState(actual);
538                 }
539             }
540         }
541     }
542 
push_stub(&self, closed: bool)543     unsafe fn push_stub(&self, closed: bool) {
544         self.push_waiter(self.stub(), closed);
545     }
546 
push_waiter(&self, waiter: NonNull<Waiter>, closed: bool)547     unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
548         // Set the next pointer. This does not require an atomic operation as
549         // this node is not accessible. The write will be flushed with the next
550         // operation
551         waiter.as_ref().next.store(ptr::null_mut(), Relaxed);
552 
553         // Update the tail to point to the new node. We need to see the previous
554         // node in order to update the next pointer as well as release `task`
555         // to any other threads calling `push`.
556         let next = SemState::new_ptr(waiter, closed);
557         let prev = SemState(self.state.swap(next.0, AcqRel));
558 
559         debug_assert_eq!(closed, prev.is_closed());
560 
561         // This function is only called when there are pending tasks. Because of
562         // this, the state must *always* be in pointer mode.
563         let prev = prev.waiter().unwrap();
564 
565         // No cycles plz
566         debug_assert_ne!(prev, waiter);
567 
568         // Release `task` to the consume end.
569         prev.as_ref().next.store(waiter.as_ptr(), Release);
570     }
571 
stub(&self) -> NonNull<Waiter>572     fn stub(&self) -> NonNull<Waiter> {
573         unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
574     }
575 }
576 
577 impl Drop for Semaphore {
drop(&mut self)578     fn drop(&mut self) {
579         self.close();
580     }
581 }
582 
583 impl fmt::Debug for Semaphore {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result584     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
585         fmt.debug_struct("Semaphore")
586             .field("state", &SemState::load(&self.state, Relaxed))
587             .field("head", &self.head.with(|ptr| ptr))
588             .field("rx_lock", &self.rx_lock.load(Relaxed))
589             .field("stub", &self.stub)
590             .finish()
591     }
592 }
593 
// NOTE(review): `Send` and `Sync` are asserted manually, presumably because
// the `UnsafeCell` and `NonNull` fields stop the compiler from deriving them.
// Soundness rests on the `state` / `rx_lock` protocol used by the methods of
// `Semaphore` — re-check these impls whenever that protocol changes.
unsafe impl Send for Semaphore {}
unsafe impl Sync for Semaphore {}
596 
597 // ===== impl Permit =====
598 
599 impl Permit {
600     /// Creates a new `Permit`.
601     ///
602     /// The permit begins in the "unacquired" state.
new() -> Permit603     pub(crate) fn new() -> Permit {
604         use PermitState::Acquired;
605 
606         Permit {
607             waiter: None,
608             state: Acquired(0),
609         }
610     }
611 
612     /// Returns `true` if the permit has been acquired
613     #[allow(dead_code)] // may be used later
is_acquired(&self) -> bool614     pub(crate) fn is_acquired(&self) -> bool {
615         match self.state {
616             PermitState::Acquired(num) if num > 0 => true,
617             _ => false,
618         }
619     }
620 
621     /// Tries to acquire the permit. If no permits are available, the current task
622     /// is notified once a new permit becomes available.
poll_acquire( &mut self, cx: &mut Context<'_>, num_permits: u16, semaphore: &Semaphore, ) -> Poll<Result<(), AcquireError>>623     pub(crate) fn poll_acquire(
624         &mut self,
625         cx: &mut Context<'_>,
626         num_permits: u16,
627         semaphore: &Semaphore,
628     ) -> Poll<Result<(), AcquireError>> {
629         use std::cmp::Ordering::*;
630         use PermitState::*;
631 
632         match self.state {
633             Waiting(requested) => {
634                 // There must be a waiter
635                 let waiter = self.waiter.as_ref().unwrap();
636 
637                 match requested.cmp(&num_permits) {
638                     Less => {
639                         let delta = num_permits - requested;
640 
641                         // Request additional permits. If the waiter has been
642                         // dequeued, it must be re-queued.
643                         if !waiter.try_inc_permits_to_acquire(delta as usize) {
644                             let waiter = NonNull::from(&**waiter);
645 
646                             // Ignore the result. The check for
647                             // `permits_to_acquire()` will converge the state as
648                             // needed
649                             let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
650                         }
651 
652                         self.state = Waiting(num_permits);
653                     }
654                     Greater => {
655                         let delta = requested - num_permits;
656                         let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
657 
658                         semaphore.add_permits(to_release);
659                         self.state = Waiting(num_permits);
660                     }
661                     Equal => {}
662                 }
663 
664                 if waiter.permits_to_acquire()? == 0 {
665                     self.state = Acquired(requested);
666                     return Ready(Ok(()));
667                 }
668 
669                 waiter.waker.register_by_ref(cx.waker());
670 
671                 if waiter.permits_to_acquire()? == 0 {
672                     self.state = Acquired(requested);
673                     return Ready(Ok(()));
674                 }
675 
676                 Pending
677             }
678             Acquired(acquired) => {
679                 if acquired >= num_permits {
680                     Ready(Ok(()))
681                 } else {
682                     match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
683                         Ready(()) => {
684                             self.state = Acquired(num_permits);
685                             Ready(Ok(()))
686                         }
687                         Pending => {
688                             self.state = Waiting(num_permits);
689                             Pending
690                         }
691                     }
692                 }
693             }
694         }
695     }
696 
697     /// Tries to acquire the permit.
try_acquire( &mut self, num_permits: u16, semaphore: &Semaphore, ) -> Result<(), TryAcquireError>698     pub(crate) fn try_acquire(
699         &mut self,
700         num_permits: u16,
701         semaphore: &Semaphore,
702     ) -> Result<(), TryAcquireError> {
703         use PermitState::*;
704 
705         match self.state {
706             Waiting(requested) => {
707                 // There must be a waiter
708                 let waiter = self.waiter.as_ref().unwrap();
709 
710                 if requested > num_permits {
711                     let delta = requested - num_permits;
712                     let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
713 
714                     semaphore.add_permits(to_release);
715                     self.state = Waiting(num_permits);
716                 }
717 
718                 let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
719 
720                 if res == 0 {
721                     if requested < num_permits {
722                         // Try to acquire the additional permits
723                         semaphore.try_acquire(num_permits - requested)?;
724                     }
725 
726                     self.state = Acquired(num_permits);
727                     Ok(())
728                 } else {
729                     Err(TryAcquireError::NoPermits)
730                 }
731             }
732             Acquired(acquired) => {
733                 if acquired < num_permits {
734                     semaphore.try_acquire(num_permits - acquired)?;
735                     self.state = Acquired(num_permits);
736                 }
737 
738                 Ok(())
739             }
740         }
741     }
742 
743     /// Releases a permit back to the semaphore
release(&mut self, n: u16, semaphore: &Semaphore)744     pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
745         let n = self.forget(n);
746         semaphore.add_permits(n as usize);
747     }
748 
749     /// Forgets the permit **without** releasing it back to the semaphore.
750     ///
751     /// After calling `forget`, `poll_acquire` is able to acquire new permit
752     /// from the sempahore.
753     ///
754     /// Repeatedly calling `forget` without associated calls to `add_permit`
755     /// will result in the semaphore losing all permits.
756     ///
757     /// Will forget **at most** the number of acquired permits. This number is
758     /// returned.
forget(&mut self, n: u16) -> u16759     pub(crate) fn forget(&mut self, n: u16) -> u16 {
760         use PermitState::*;
761 
762         match self.state {
763             Waiting(requested) => {
764                 let n = cmp::min(n, requested);
765 
766                 // Decrement
767                 let acquired = self
768                     .waiter
769                     .as_ref()
770                     .unwrap()
771                     .try_dec_permits_to_acquire(n as usize) as u16;
772 
773                 if n == requested {
774                     self.state = Acquired(0);
775                 } else if acquired == requested - n {
776                     self.state = Waiting(acquired);
777                 } else {
778                     self.state = Waiting(requested - n);
779                 }
780 
781                 acquired
782             }
783             Acquired(acquired) => {
784                 let n = cmp::min(n, acquired);
785                 self.state = Acquired(acquired - n);
786                 n
787             }
788         }
789     }
790 }
791 
792 impl Default for Permit {
default() -> Self793     fn default() -> Self {
794         Self::new()
795     }
796 }
797 
798 impl Drop for Permit {
drop(&mut self)799     fn drop(&mut self) {
800         if let Some(waiter) = self.waiter.take() {
801             // Set the dropped flag
802             let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
803 
804             if state.is_queued() {
805                 // The waiter is stored in the queue. The semaphore will drop it
806                 std::mem::forget(waiter);
807             }
808         }
809     }
810 }
811 
812 // ===== impl AcquireError ====
813 
814 impl AcquireError {
closed() -> AcquireError815     fn closed() -> AcquireError {
816         AcquireError(())
817     }
818 }
819 
to_try_acquire(_: AcquireError) -> TryAcquireError820 fn to_try_acquire(_: AcquireError) -> TryAcquireError {
821     TryAcquireError::Closed
822 }
823 
824 impl fmt::Display for AcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result825     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
826         write!(fmt, "semaphore closed")
827     }
828 }
829 
// `Error` is implemented with its default methods only; the human-readable
// message comes from the `Display` impl for `AcquireError`.
impl std::error::Error for AcquireError {}
831 
832 // ===== impl TryAcquireError =====
833 
834 impl TryAcquireError {
835     /// Returns `true` if the error was caused by a closed semaphore.
is_closed(&self) -> bool836     pub(crate) fn is_closed(&self) -> bool {
837         match self {
838             TryAcquireError::Closed => true,
839             _ => false,
840         }
841     }
842 
843     /// Returns `true` if the error was caused by calling `try_acquire` on a
844     /// semaphore with no available permits.
is_no_permits(&self) -> bool845     pub(crate) fn is_no_permits(&self) -> bool {
846         match self {
847             TryAcquireError::NoPermits => true,
848             _ => false,
849         }
850     }
851 }
852 
853 impl fmt::Display for TryAcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result854     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
855         match self {
856             TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"),
857             TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"),
858         }
859     }
860 }
861 
// `Error` is implemented with its default methods only; the human-readable
// message comes from the `Display` impl for `TryAcquireError`.
impl std::error::Error for TryAcquireError {}
863 
864 // ===== impl Waiter =====
865 
866 impl Waiter {
new() -> Waiter867     fn new() -> Waiter {
868         Waiter {
869             state: AtomicUsize::new(0),
870             waker: AtomicWaker::new(),
871             next: AtomicPtr::new(ptr::null_mut()),
872         }
873     }
874 
permits_to_acquire(&self) -> Result<usize, AcquireError>875     fn permits_to_acquire(&self) -> Result<usize, AcquireError> {
876         let state = WaiterState(self.state.load(Acquire));
877 
878         if state.is_closed() {
879             Err(AcquireError(()))
880         } else {
881             Ok(state.permits_to_acquire())
882         }
883     }
884 
885     /// Only increments the number of permits *if* the waiter is currently
886     /// queued.
887     ///
888     /// # Returns
889     ///
890     /// `true` if the number of permits to acquire has been incremented. `false`
891     /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`.
try_inc_permits_to_acquire(&self, n: usize) -> bool892     fn try_inc_permits_to_acquire(&self, n: usize) -> bool {
893         let mut curr = WaiterState(self.state.load(Acquire));
894 
895         loop {
896             if !curr.is_queued() {
897                 assert_eq!(0, curr.permits_to_acquire());
898                 return false;
899             }
900 
901             let mut next = curr;
902             next.set_permits_to_acquire(n + curr.permits_to_acquire());
903 
904             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
905                 Ok(_) => return true,
906                 Err(actual) => curr = WaiterState(actual),
907             }
908         }
909     }
910 
911     /// Try to decrement the number of permits to acquire. This returns the
912     /// actual number of permits that were decremented. The delta betweeen `n`
913     /// and the return has been assigned to the permit and the caller must
914     /// assign these back to the semaphore.
try_dec_permits_to_acquire(&self, n: usize) -> usize915     fn try_dec_permits_to_acquire(&self, n: usize) -> usize {
916         let mut curr = WaiterState(self.state.load(Acquire));
917 
918         loop {
919             if !curr.is_queued() {
920                 assert_eq!(0, curr.permits_to_acquire());
921             }
922 
923             let delta = cmp::min(n, curr.permits_to_acquire());
924             let rem = curr.permits_to_acquire() - delta;
925 
926             let mut next = curr;
927             next.set_permits_to_acquire(rem);
928 
929             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
930                 Ok(_) => return n - delta,
931                 Err(actual) => curr = WaiterState(actual),
932             }
933         }
934     }
935 
936     /// Store the number of remaining permits needed to satisfy the waiter and
937     /// transition to the "QUEUED" state.
938     ///
939     /// # Returns
940     ///
941     /// `true` if the `QUEUED` bit was set as part of the transition.
to_queued(&self, num_permits: usize) -> bool942     fn to_queued(&self, num_permits: usize) -> bool {
943         let mut curr = WaiterState(self.state.load(Acquire));
944 
945         // The waiter should **not** be waiting for any permits.
946         debug_assert_eq!(curr.permits_to_acquire(), 0);
947 
948         loop {
949             let mut next = curr;
950             next.set_permits_to_acquire(num_permits);
951             next.set_queued();
952 
953             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
954                 Ok(_) => {
955                     if curr.is_queued() {
956                         return false;
957                     } else {
958                         // Make sure the next pointer is null
959                         self.next.store(ptr::null_mut(), Relaxed);
960                         return true;
961                     }
962                 }
963                 Err(actual) => curr = WaiterState(actual),
964             }
965         }
966     }
967 
968     /// Set the number of permits to acquire.
969     ///
970     /// This function is only called when the waiter is being inserted into the
971     /// wait queue. Because of this, there are no concurrent threads that can
972     /// modify the state and using `store` is safe.
set_permits_to_acquire(&self, num_permits: usize)973     fn set_permits_to_acquire(&self, num_permits: usize) {
974         debug_assert!(WaiterState(self.state.load(Acquire)).is_queued());
975 
976         let mut state = WaiterState(QUEUED);
977         state.set_permits_to_acquire(num_permits);
978 
979         self.state.store(state.0, Release);
980     }
981 
982     /// Assign permits to the waiter.
983     ///
984     /// Returns `true` if the waiter should be removed from the queue
assign_permits(&self, n: &mut usize, closed: bool) -> bool985     fn assign_permits(&self, n: &mut usize, closed: bool) -> bool {
986         let mut curr = WaiterState(self.state.load(Acquire));
987 
988         loop {
989             let mut next = curr;
990 
991             // Number of permits to assign to this waiter
992             let assign = cmp::min(curr.permits_to_acquire(), *n);
993 
994             // Assign the permits
995             next.set_permits_to_acquire(curr.permits_to_acquire() - assign);
996 
997             if closed {
998                 next.set_closed();
999             }
1000 
1001             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
1002                 Ok(_) => {
1003                     // Update `n`
1004                     *n -= assign;
1005 
1006                     if next.permits_to_acquire() == 0 {
1007                         if curr.permits_to_acquire() > 0 {
1008                             self.waker.wake();
1009                         }
1010 
1011                         return true;
1012                     } else {
1013                         return false;
1014                     }
1015                 }
1016                 Err(actual) => curr = WaiterState(actual),
1017             }
1018         }
1019     }
1020 
revert_to_idle(&self)1021     fn revert_to_idle(&self) {
1022         // An idle node is not waiting on any permits
1023         self.state.store(0, Relaxed);
1024     }
1025 
store_next(&self, next: NonNull<Waiter>)1026     fn store_next(&self, next: NonNull<Waiter>) {
1027         self.next.store(next.as_ptr(), Release);
1028     }
1029 }
1030 
1031 // ===== impl SemState =====
1032 
1033 impl SemState {
1034     /// Returns a new default `State` value.
new(permits: usize, stub: &Waiter) -> SemState1035     fn new(permits: usize, stub: &Waiter) -> SemState {
1036         assert!(permits <= MAX_PERMITS);
1037 
1038         if permits > 0 {
1039             SemState((permits << NUM_SHIFT) | NUM_FLAG)
1040         } else {
1041             SemState(stub as *const _ as usize)
1042         }
1043     }
1044 
1045     /// Returns a `State` tracking `ptr` as the tail of the queue.
new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState1046     fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState {
1047         let mut val = tail.as_ptr() as usize;
1048 
1049         if closed {
1050             val |= CLOSED_FLAG;
1051         }
1052 
1053         SemState(val)
1054     }
1055 
1056     /// Returns the amount of remaining capacity
available_permits(self) -> usize1057     fn available_permits(self) -> usize {
1058         if !self.has_available_permits() {
1059             return 0;
1060         }
1061 
1062         self.0 >> NUM_SHIFT
1063     }
1064 
1065     /// Returns `true` if the state has permits that can be claimed by a waiter.
has_available_permits(self) -> bool1066     fn has_available_permits(self) -> bool {
1067         self.0 & NUM_FLAG == NUM_FLAG
1068     }
1069 
has_waiter(self, stub: &Waiter) -> bool1070     fn has_waiter(self, stub: &Waiter) -> bool {
1071         !self.has_available_permits() && !self.is_stub(stub)
1072     }
1073 
1074     /// Tries to atomically acquire specified number of permits.
1075     ///
1076     /// # Return
1077     ///
1078     /// Returns `true` if the specified number of permits were acquired, `false`
1079     /// otherwise. Returning false does not mean that there are no more
1080     /// available permits.
acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool1081     fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool {
1082         debug_assert!(num > 0);
1083 
1084         if self.available_permits() < num {
1085             return false;
1086         }
1087 
1088         debug_assert!(self.waiter().is_none());
1089 
1090         self.0 -= num << NUM_SHIFT;
1091 
1092         if self.0 == NUM_FLAG {
1093             // Set the state to the stub pointer.
1094             self.0 = stub as *const _ as usize;
1095         }
1096 
1097         true
1098     }
1099 
1100     /// Releases permits
1101     ///
1102     /// Returns `true` if the permits were accepted.
release_permits(&mut self, permits: usize, stub: &Waiter)1103     fn release_permits(&mut self, permits: usize, stub: &Waiter) {
1104         debug_assert!(permits > 0);
1105 
1106         if self.is_stub(stub) {
1107             self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
1108             return;
1109         }
1110 
1111         debug_assert!(self.has_available_permits());
1112 
1113         self.0 += permits << NUM_SHIFT;
1114     }
1115 
is_waiter(self) -> bool1116     fn is_waiter(self) -> bool {
1117         self.0 & NUM_FLAG == 0
1118     }
1119 
1120     /// Returns the waiter, if one is set.
waiter(self) -> Option<NonNull<Waiter>>1121     fn waiter(self) -> Option<NonNull<Waiter>> {
1122         if self.is_waiter() {
1123             let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
1124 
1125             Some(waiter)
1126         } else {
1127             None
1128         }
1129     }
1130 
1131     /// Assumes `self` represents a pointer
as_ptr(self) -> *mut Waiter1132     fn as_ptr(self) -> *mut Waiter {
1133         (self.0 & !CLOSED_FLAG) as *mut Waiter
1134     }
1135 
1136     /// Sets to a pointer to a waiter.
1137     ///
1138     /// This can only be done from the full state.
set_waiter(&mut self, waiter: NonNull<Waiter>)1139     fn set_waiter(&mut self, waiter: NonNull<Waiter>) {
1140         let waiter = waiter.as_ptr() as usize;
1141         debug_assert!(!self.is_closed());
1142 
1143         self.0 = waiter;
1144     }
1145 
is_stub(self, stub: &Waiter) -> bool1146     fn is_stub(self, stub: &Waiter) -> bool {
1147         self.as_ptr() as usize == stub as *const _ as usize
1148     }
1149 
1150     /// Loads the state from an AtomicUsize.
load(cell: &AtomicUsize, ordering: Ordering) -> SemState1151     fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1152         let value = cell.load(ordering);
1153         SemState(value)
1154     }
1155 
fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState1156     fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1157         let value = cell.fetch_or(CLOSED_FLAG, ordering);
1158         SemState(value)
1159     }
1160 
is_closed(self) -> bool1161     fn is_closed(self) -> bool {
1162         self.0 & CLOSED_FLAG == CLOSED_FLAG
1163     }
1164 
1165     /// Converts the state into a `usize` representation.
to_usize(self) -> usize1166     fn to_usize(self) -> usize {
1167         self.0
1168     }
1169 }
1170 
1171 impl fmt::Debug for SemState {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result1172     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1173         let mut fmt = fmt.debug_struct("SemState");
1174 
1175         if self.is_waiter() {
1176             fmt.field("state", &"<waiter>");
1177         } else {
1178             fmt.field("permits", &self.available_permits());
1179         }
1180 
1181         fmt.finish()
1182     }
1183 }
1184 
1185 // ===== impl WaiterState =====
1186 
1187 impl WaiterState {
permits_to_acquire(self) -> usize1188     fn permits_to_acquire(self) -> usize {
1189         self.0 >> PERMIT_SHIFT
1190     }
1191 
set_permits_to_acquire(&mut self, val: usize)1192     fn set_permits_to_acquire(&mut self, val: usize) {
1193         self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK)
1194     }
1195 
is_queued(self) -> bool1196     fn is_queued(self) -> bool {
1197         self.0 & QUEUED == QUEUED
1198     }
1199 
set_queued(&mut self)1200     fn set_queued(&mut self) {
1201         self.0 |= QUEUED;
1202     }
1203 
is_closed(self) -> bool1204     fn is_closed(self) -> bool {
1205         self.0 & CLOSED == CLOSED
1206     }
1207 
set_closed(&mut self)1208     fn set_closed(&mut self) {
1209         self.0 |= CLOSED;
1210     }
1211 
unset_queued(&mut self)1212     fn unset_queued(&mut self) {
1213         assert!(self.is_queued());
1214         self.0 -= QUEUED;
1215     }
1216 
is_dropped(self) -> bool1217     fn is_dropped(self) -> bool {
1218         self.0 & DROPPED == DROPPED
1219     }
1220 }
1221