1 #![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
2 
3 //! Thread-safe, asynchronous counting semaphore.
4 //!
5 //! A `Semaphore` instance holds a set of permits. Permits are used to
6 //! synchronize access to a shared resource.
7 //!
8 //! Before accessing the shared resource, callers acquire a permit from the
9 //! semaphore. Once the permit is acquired, the caller then enters the critical
10 //! section. If no permits are available, then acquiring the semaphore returns
11 //! `Pending`. The task is woken once a permit becomes available.
12 
13 use crate::loom::cell::UnsafeCell;
14 use crate::loom::future::AtomicWaker;
15 use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
16 use crate::loom::thread;
17 
18 use std::cmp;
19 use std::fmt;
20 use std::ptr::{self, NonNull};
21 use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
22 use std::task::Poll::{Pending, Ready};
23 use std::task::{Context, Poll};
24 use std::usize;
25 
26 /// Futures-aware semaphore.
27 pub(crate) struct Semaphore {
28     /// Tracks both the waiter queue tail pointer and the number of remaining
29     /// permits.
30     state: AtomicUsize,
31 
32     /// waiter queue head pointer.
33     head: UnsafeCell<NonNull<Waiter>>,
34 
35     /// Coordinates access to the queue head.
36     rx_lock: AtomicUsize,
37 
38     /// Stub waiter node used as part of the MPSC channel algorithm.
39     stub: Box<Waiter>,
40 }
41 
42 /// A semaphore permit
43 ///
44 /// Tracks the lifecycle of a semaphore permit.
45 ///
46 /// An instance of `Permit` is intended to be used with a **single** instance of
47 /// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
48 /// instances will result in unexpected behavior.
49 ///
50 /// `Permit` does **not** release the permit back to the semaphore on drop. It
51 /// is the user's responsibility to ensure that `Permit::release` is called
52 /// before dropping the permit.
53 #[derive(Debug)]
54 pub(crate) struct Permit {
55     waiter: Option<Box<Waiter>>,
56     state: PermitState,
57 }
58 
/// Error yielded by `Permit::poll_acquire` once the semaphore has been
/// closed.
#[derive(Debug)]
pub(crate) struct AcquireError(());
62 
/// Reasons `Permit::try_acquire` can fail.
#[derive(Debug)]
pub(crate) enum TryAcquireError {
    /// The semaphore was closed; it will not issue any more permits.
    Closed,

    /// Not enough permits were available to satisfy the request right now.
    NoPermits,
}
69 
70 /// Node used to notify the semaphore waiter when permit is available.
71 #[derive(Debug)]
72 struct Waiter {
73     /// Stores waiter state.
74     ///
75     /// See `WaiterState` for more details.
76     state: AtomicUsize,
77 
78     /// Task to wake when a permit is made available.
79     waker: AtomicWaker,
80 
81     /// Next pointer in the queue of waiting senders.
82     next: AtomicPtr<Waiter>,
83 }
84 
/// Snapshot of the semaphore's packed `state` word.
///
/// The two low bits are flags:
///
/// - `NUM_FLAG` (bit 0): when set, the value shifted right by `NUM_SHIFT`
///   is the number of permits currently available. When clear, the rest of
///   the word is a pointer to the tail of the waiter queue.
/// - `CLOSED_FLAG` (bit 1): the semaphore has been closed. This flag can be
///   combined with either mode, so `Waiter` pointers stored here must have
///   both low bits clear (allocations are assumed to be at least 4-byte
///   aligned).
#[derive(Copy, Clone)]
struct SemState(usize);
97 
/// Where a `Permit` is in its acquire lifecycle.
#[derive(Debug, Copy, Clone)]
enum PermitState {
    /// A waiter node has been queued. The value is the total number of
    /// permits the permit is waiting to have assigned.
    Waiting(u16),

    /// The value is the number of permits currently held.
    Acquired(u16),
}
108 
/// Snapshot of a `Waiter`'s packed `state` word (flag bits plus the number of
/// permits still to be acquired).
#[derive(Debug, Copy, Clone)]
struct WaiterState(usize);
112 
// ----- Bits of a `Waiter`'s state word (`WaiterState`) -----

/// Waiter node is in the semaphore queue
const QUEUED: usize = 0b001;

/// Semaphore has been closed, no more permits will be issued.
const CLOSED: usize = 0b010;

/// The permit that owns the `Waiter` dropped.
const DROPPED: usize = 0b100;

/// Represents "one requested permit" in the waiter state
const PERMIT_ONE: usize = 0b1000;

/// How much to shift a permit count to pack it into the waker state
const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();

/// Masks the waiter state to only contain bits tracking number of requested
/// permits (every bit from `PERMIT_ONE` upwards).
const PERMIT_MASK: usize = !(PERMIT_ONE - 1);

// ----- Bits of the semaphore's state word (`SemState`) -----

/// Flag differentiating between available permits and waiter pointers.
///
/// Properly aligned pointers always have a zero least-significant bit, so a
/// set bit marks the value as a permit count rather than a pointer.
const NUM_FLAG: usize = 0b01;

/// Signal the semaphore is closed
const CLOSED_FLAG: usize = 0b10;

/// When representing "numbers", the state has to be shifted this much (to get
/// rid of the flag bits).
const NUM_SHIFT: usize = 2;

/// Maximum number of permits a semaphore can manage
const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
148 
149 // ===== impl Semaphore =====
150 
impl Semaphore {
    /// Creates a new semaphore with the initial number of permits
    ///
    /// # Panics
    ///
    /// Panics if `permits` is zero.
    ///
    /// NOTE(review): no such check is visible in this function; presumably it
    /// is performed by `SemState::new` — confirm.
    pub(crate) fn new(permits: usize) -> Semaphore {
        // The stub is boxed so that `ptr` (stored in `head` below) stays valid
        // after the returned `Semaphore` is moved.
        let stub = Box::new(Waiter::new());
        let ptr = NonNull::from(&*stub);

        // Allocations are aligned, so the low bit is free to serve as
        // `NUM_FLAG` in the state word.
        debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);

        let state = SemState::new(permits, &stub);

        Semaphore {
            state: AtomicUsize::new(state.to_usize()),
            // The queue starts out empty: `head` points at the stub node.
            head: UnsafeCell::new(ptr),
            rx_lock: AtomicUsize::new(0),
            stub,
        }
    }

    /// Returns the current number of available permits
    ///
    /// This is only a snapshot; concurrent acquires and releases may change
    /// the value immediately afterwards.
    pub(crate) fn available_permits(&self) -> usize {
        let curr = SemState(self.state.load(Acquire));
        curr.available_permits()
    }

    /// Tries to acquire the requested number of permits, registering the waiter
    /// if not enough permits are available.
    ///
    /// When the request cannot be satisfied immediately, the permit's `Waiter`
    /// node is allocated (on first use) and `cx`'s waker is registered on it
    /// before the node is offered to the wait queue.
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        num_permits: u16,
        permit: &mut Permit,
    ) -> Poll<Result<(), AcquireError>> {
        self.poll_acquire2(num_permits, || {
            let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));

            waiter.waker.register_by_ref(cx.waker());

            Some(NonNull::from(&**waiter))
        })
    }

    /// Attempts to acquire `num_permits` without waiting.
    ///
    /// The `get_waiter` closure always yields `None`, so `poll_acquire2` never
    /// queues anything. A `Pending` result therefore means "not enough permits
    /// right now" and is reported as `TryAcquireError::NoPermits`; a closed
    /// semaphore is reported as `TryAcquireError::Closed`.
    fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
        match self.poll_acquire2(num_permits, || None) {
            Poll::Ready(res) => res.map_err(to_try_acquire),
            Poll::Pending => Err(TryAcquireError::NoPermits),
        }
    }

    /// Polls for a permit
    ///
    /// Tries to acquire available permits first. If unable to acquire a
    /// sufficient number of permits, the caller's waiter is pushed onto the
    /// semaphore's wait queue.
    ///
    /// `get_waiter` is called at most once, and only when there are not enough
    /// permits. Returning `None` from it means the caller does not want to
    /// wait.
    ///
    /// Returns:
    ///
    /// - `Ready(Ok(()))` if all `num_permits` were taken from the semaphore.
    /// - `Ready(Err(_))` if the semaphore is closed.
    /// - `Pending` in any of these cases: the waiter was pushed onto the
    ///   queue, the waiter was already queued, or no waiter was supplied.
    fn poll_acquire2<F>(
        &self,
        num_permits: u16,
        mut get_waiter: F,
    ) -> Poll<Result<(), AcquireError>>
    where
        F: FnMut() -> Option<NonNull<Waiter>>,
    {
        let num_permits = num_permits as usize;

        // Load the current state
        let mut curr = SemState(self.state.load(Acquire));

        // Saves a ref to the waiter node
        let mut maybe_waiter: Option<NonNull<Waiter>> = None;

        /// Used in branches where we attempt to push the waiter into the wait
        /// queue but fail due to permits becoming available or the wait queue
        /// transitioning to "closed". In this case, the waiter must be
        /// transitioned back to the "idle" state.
        macro_rules! revert_to_idle {
            () => {
                if let Some(waiter) = maybe_waiter {
                    unsafe { waiter.as_ref() }.revert_to_idle();
                }
            };
        }

        // CAS loop: compute the desired next state from `curr`, then try to
        // publish it. On contention, retry with the freshly observed state.
        loop {
            let mut next = curr;

            if curr.is_closed() {
                revert_to_idle!();
                return Ready(Err(AcquireError::closed()));
            }

            // `acquire_permits` only updates the local copy `next`. Nothing is
            // published until the CAS below succeeds.
            let acquired = next.acquire_permits(num_permits, &self.stub);

            if !acquired {
                // There are not enough available permits to satisfy the
                // request. The permit transitions to a waiting state.
                debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);

                if let Some(waiter) = maybe_waiter.as_ref() {
                    // Retrying after a failed CAS: refresh the outstanding
                    // count against the newly observed state.
                    //
                    // Safety: the caller owns the waiter.
                    let w = unsafe { waiter.as_ref() };
                    w.set_permits_to_acquire(num_permits - curr.available_permits());
                } else {
                    // Get the waiter for the permit.
                    if let Some(waiter) = get_waiter() {
                        // Safety: the caller owns the waiter.
                        let w = unsafe { waiter.as_ref() };

                        // If there are any currently available permits, the
                        // waiter acquires those immediately and waits for the
                        // remaining permits to become available.
                        if !w.to_queued(num_permits - curr.available_permits()) {
                            // The node is already queued, there is no further work
                            // to do.
                            return Pending;
                        }

                        maybe_waiter = Some(waiter);
                    } else {
                        // No waiter, this indicates the caller does not wish to
                        // "wait", so there is nothing left to do.
                        return Pending;
                    }
                }

                // The waiter becomes the new tail of the queue.
                next.set_waiter(maybe_waiter.unwrap());
            }

            debug_assert_ne!(curr.0, 0);
            debug_assert_ne!(next.0, 0);

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    if acquired {
                        // Successfully acquire permits **without** queuing the
                        // waiter node. The waiter node is not currently in the
                        // queue.
                        revert_to_idle!();
                        return Ready(Ok(()));
                    } else {
                        // The node is pushed into the queue, the final step is
                        // to set the node's "next" pointer to return the wait
                        // queue into a consistent state.

                        // If the state held a permit count rather than a tail
                        // pointer, the queue was empty and the stub is the
                        // previous tail.
                        let prev_waiter =
                            curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));

                        let waiter = maybe_waiter.unwrap();

                        // Link the nodes.
                        //
                        // Safety: the mpsc algorithm guarantees the old tail of
                        // the queue is not removed from the queue during the
                        // push process.
                        unsafe {
                            prev_waiter.as_ref().store_next(waiter);
                        }

                        return Pending;
                    }
                }
                Err(actual) => {
                    curr = SemState(actual);
                }
            }
        }
    }

    /// Closes the semaphore. This prevents the semaphore from issuing new
    /// permits and notifies all pending waiters.
    ///
    /// Bit 0 of `rx_lock` records the close request. If another thread
    /// already holds the lock, that thread performs the close on our behalf.
    pub(crate) fn close(&self) {
        // Acquire the `rx_lock`, setting the "closed" flag on the lock.
        let prev = self.rx_lock.fetch_or(1, AcqRel);

        if prev != 0 {
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(0, true);
    }
    /// Adds `n` new permits to the semaphore.
    ///
    /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded.
    ///
    /// The permits are first recorded in `rx_lock`, shifted left by one
    /// because bit 0 is the close flag. The thread whose `fetch_add` observes
    /// a previous value of zero becomes the lock holder and distributes all
    /// accumulated permits. Any other thread returns immediately.
    ///
    /// NOTE(review): no panic is raised in this function itself, and
    /// `MAX_PERMITS` is defined as `usize::MAX >> NUM_SHIFT` (a shift by 2),
    /// which disagrees with the `>> 3` stated above — confirm which limit is
    /// enforced and where.
    pub(crate) fn add_permits(&self, n: usize) {
        if n == 0 {
            return;
        }

        // TODO: Handle overflow. A panic is not sufficient, the process must
        // abort.
        let prev = self.rx_lock.fetch_add(n << 1, AcqRel);

        if prev != 0 {
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(n, false);
    }

    /// Processes queued work while holding `rx_lock`.
    ///
    /// Each iteration handles `rem` permits (and a close request, if `closed`).
    /// The following `fetch_sub` removes exactly that work from `rx_lock`, and
    /// its return value shows what other threads added in the meantime. The
    /// loop ends once no permits and no close request remain, which leaves
    /// `rx_lock` at zero so that another thread can become the holder.
    fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
        while rem > 0 || closed {
            if closed {
                SemState::fetch_set_closed(&self.state, AcqRel);
            }

            // Release the permits and notify
            self.add_permits_locked2(rem, closed);

            // Permit counts are stored in `rx_lock` shifted past the close bit.
            let n = rem << 1;

            let actual = if closed {
                // Also clear the close request that was just handled.
                let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
                closed = false;
                actual
            } else {
                let actual = self.rx_lock.fetch_sub(n, AcqRel);
                // A close may have been requested while we were working.
                closed = actual & 1 == 1;
                actual
            };

            // Permits added by other threads since this iteration started.
            rem = (actual >> 1) - rem;
        }
    }

    /// Releases a specific amount of permits to the semaphore
    ///
    /// This function is called (via `add_permits_locked`) by `add_permits`
    /// and `close` once `rx_lock` is held.
    ///
    /// Permits are assigned to queued waiters in queue order, starting at
    /// `head`. Whatever remains once the queue is empty is added back to the
    /// semaphore's available count. When `closed` is true, the whole queue is
    /// drained instead.
    fn add_permits_locked2(&self, mut n: usize, closed: bool) {
        // If closing the semaphore, we want to drain the entire queue. The
        // number of permits being assigned doesn't matter.
        if closed {
            n = usize::MAX;
        }

        'outer: while n > 0 {
            unsafe {
                // Only the `rx_lock` holder touches `head`, so plain reads and
                // writes through the `UnsafeCell` are used here.
                let mut head = self.head.with(|head| *head);
                let mut next_ptr = head.as_ref().next.load(Acquire);

                let stub = self.stub();

                if head == stub {
                    // The stub node indicates an empty queue. Any remaining
                    // permits get assigned back to the semaphore.
                    let next = match NonNull::new(next_ptr) {
                        Some(next) => next,
                        None => {
                            // This loop is not part of the standard intrusive mpsc
                            // channel algorithm. This is where we atomically pop
                            // the last task and add `n` to the remaining capacity.
                            //
                            // This modification to the pop algorithm works because,
                            // at this point, we have not done any work (only done
                            // reading). We have a *pretty* good idea that there is
                            // no concurrent pusher.
                            //
                            // The capacity is then atomically added by doing an
                            // AcqRel CAS on `state`. The `state` cell is the
                            // linchpin of the algorithm.
                            //
                            // By successfully CASing `head` w/ AcqRel, we ensure
                            // that, if any thread was racing and entered a push, we
                            // see that and abort pop, retrying as it is
                            // "inconsistent".
                            let mut curr = SemState::load(&self.state, Acquire);

                            loop {
                                if curr.has_waiter(&self.stub) {
                                    // A waiter is being added concurrently.
                                    // This is the MPSC queue's "inconsistent"
                                    // state and we must loop and try again.
                                    thread::yield_now();
                                    continue 'outer;
                                }

                                // If closing, nothing more to do.
                                if closed {
                                    debug_assert!(curr.is_closed(), "state = {:?}", curr);
                                    return;
                                }

                                let mut next = curr;
                                next.release_permits(n, &self.stub);

                                match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                                    Ok(_) => return,
                                    Err(actual) => {
                                        curr = SemState(actual);
                                    }
                                }
                            }
                        }
                    };

                    // Skip over the stub to the first real waiter.
                    self.head.with_mut(|head| *head = next);
                    head = next;
                    next_ptr = next.as_ref().next.load(Acquire);
                }

                // `head` points to a waiter assign permits to the waiter. If
                // all requested permits are satisfied, then we can continue,
                // otherwise the node stays in the wait queue.
                if !head.as_ref().assign_permits(&mut n, closed) {
                    assert_eq!(n, 0);
                    return;
                }

                // The head waiter is satisfied. If it has a successor, pop it.
                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    self.remove_queued(head, closed);
                    continue 'outer;
                }

                let state = SemState::load(&self.state, Acquire);

                // This must always be a pointer as the wait list is not empty.
                let tail = state.waiter().unwrap();

                if tail != head {
                    // Inconsistent
                    thread::yield_now();
                    continue 'outer;
                }

                // `head` is the last node. Re-insert the stub behind it so the
                // node can be popped without leaving the queue empty of nodes.
                self.push_stub(closed);

                next_ptr = head.as_ref().next.load(Acquire);

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    self.remove_queued(head, closed);
                    continue 'outer;
                }

                // Inconsistent state, loop
                thread::yield_now();
            }
        }
    }

    /// The wait node has had all of its permits assigned and has been removed
    /// from the wait queue.
    ///
    /// Attempt to remove the QUEUED bit from the node. If additional permits
    /// are concurrently requested, the node must be pushed back into the wait
    /// queue. If the owning `Permit` has already been dropped, the node's
    /// memory is freed here.
    fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
        let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));

        loop {
            if curr.is_dropped() {
                // The Permit dropped, it is on us to release the memory
                let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
                return;
            }

            // The node is removed from the queue. We attempt to unset the
            // queued bit, but concurrently the waiter has requested more
            // permits. When the waiter requested more permits, it saw the
            // queued bit set so took no further action. This requires us to
            // push the node back into the queue.
            if curr.permits_to_acquire() > 0 {
                // More permits are requested. The waiter must be re-queued
                unsafe {
                    self.push_waiter(waiter, closed);
                }
                return;
            }

            let mut next = curr;
            next.unset_queued();

            let w = unsafe { waiter.as_ref() };

            match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return,
                Err(actual) => {
                    curr = WaiterState(actual);
                }
            }
        }
    }

    /// Pushes the stub node back onto the tail of the wait queue.
    ///
    /// # Safety
    ///
    /// Same requirements as `push_waiter`.
    unsafe fn push_stub(&self, closed: bool) {
        self.push_waiter(self.stub(), closed);
    }

    /// Appends `waiter` to the tail of the wait queue.
    ///
    /// # Safety
    ///
    /// `waiter` must point to a live `Waiter` that is not currently linked into
    /// the queue. The state word must already be in pointer mode, because the
    /// previous tail is `unwrap`ped below. NOTE(review): these preconditions
    /// are inferred from the body — confirm against callers.
    unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
        // Set the next pointer. This does not require an atomic operation as
        // this node is not accessible. The write will be flushed with the next
        // operation
        waiter.as_ref().next.store(ptr::null_mut(), Relaxed);

        // Update the tail to point to the new node. We need to see the previous
        // node in order to update the next pointer as well as release `task`
        // to any other threads calling `push`.
        let next = SemState::new_ptr(waiter, closed);
        let prev = SemState(self.state.swap(next.0, AcqRel));

        debug_assert_eq!(closed, prev.is_closed());

        // This function is only called when there are pending tasks. Because of
        // this, the state must *always* be in pointer mode.
        let prev = prev.waiter().unwrap();

        // No cycles plz
        debug_assert_ne!(prev, waiter);

        // Release `task` to the consume end.
        prev.as_ref().next.store(waiter.as_ptr(), Release);
    }

    /// Returns a pointer to the stub node. It is used to test `head` for
    /// emptiness and to push the stub back onto the queue.
    fn stub(&self) -> NonNull<Waiter> {
        unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
    }
}
577 
578 impl Drop for Semaphore {
drop(&mut self)579     fn drop(&mut self) {
580         self.close();
581     }
582 }
583 
584 impl fmt::Debug for Semaphore {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result585     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
586         fmt.debug_struct("Semaphore")
587             .field("state", &SemState::load(&self.state, Relaxed))
588             .field("head", &self.head.with(|ptr| ptr))
589             .field("rx_lock", &self.rx_lock.load(Relaxed))
590             .field("stub", &self.stub)
591             .finish()
592     }
593 }
594 
// SAFETY: the auto traits are lost because `head` stores a `NonNull<Waiter>`
// inside an `UnsafeCell`. `head` is only read or written by the thread that
// holds `rx_lock`: `add_permits` and `close` reach the code touching it only
// after observing the lock word at zero. The other fields are atomics, plus a
// `Box<Waiter>` whose fields are themselves atomics / `AtomicWaker`.
unsafe impl Send for Semaphore {}
unsafe impl Sync for Semaphore {}
597 
598 // ===== impl Permit =====
599 
600 impl Permit {
601     /// Creates a new `Permit`.
602     ///
603     /// The permit begins in the "unacquired" state.
new() -> Permit604     pub(crate) fn new() -> Permit {
605         use PermitState::Acquired;
606 
607         Permit {
608             waiter: None,
609             state: Acquired(0),
610         }
611     }
612 
613     /// Returns `true` if the permit has been acquired
614     #[allow(dead_code)] // may be used later
is_acquired(&self) -> bool615     pub(crate) fn is_acquired(&self) -> bool {
616         match self.state {
617             PermitState::Acquired(num) if num > 0 => true,
618             _ => false,
619         }
620     }
621 
622     /// Tries to acquire the permit. If no permits are available, the current task
623     /// is notified once a new permit becomes available.
poll_acquire( &mut self, cx: &mut Context<'_>, num_permits: u16, semaphore: &Semaphore, ) -> Poll<Result<(), AcquireError>>624     pub(crate) fn poll_acquire(
625         &mut self,
626         cx: &mut Context<'_>,
627         num_permits: u16,
628         semaphore: &Semaphore,
629     ) -> Poll<Result<(), AcquireError>> {
630         use std::cmp::Ordering::*;
631         use PermitState::*;
632 
633         match self.state {
634             Waiting(requested) => {
635                 // There must be a waiter
636                 let waiter = self.waiter.as_ref().unwrap();
637 
638                 match requested.cmp(&num_permits) {
639                     Less => {
640                         let delta = num_permits - requested;
641 
642                         // Request additional permits. If the waiter has been
643                         // dequeued, it must be re-queued.
644                         if !waiter.try_inc_permits_to_acquire(delta as usize) {
645                             let waiter = NonNull::from(&**waiter);
646 
647                             // Ignore the result. The check for
648                             // `permits_to_acquire()` will converge the state as
649                             // needed
650                             let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
651                         }
652 
653                         self.state = Waiting(num_permits);
654                     }
655                     Greater => {
656                         let delta = requested - num_permits;
657                         let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
658 
659                         semaphore.add_permits(to_release);
660                         self.state = Waiting(num_permits);
661                     }
662                     Equal => {}
663                 }
664 
665                 if waiter.permits_to_acquire()? == 0 {
666                     self.state = Acquired(requested);
667                     return Ready(Ok(()));
668                 }
669 
670                 waiter.waker.register_by_ref(cx.waker());
671 
672                 if waiter.permits_to_acquire()? == 0 {
673                     self.state = Acquired(requested);
674                     return Ready(Ok(()));
675                 }
676 
677                 Pending
678             }
679             Acquired(acquired) => {
680                 if acquired >= num_permits {
681                     Ready(Ok(()))
682                 } else {
683                     match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
684                         Ready(()) => {
685                             self.state = Acquired(num_permits);
686                             Ready(Ok(()))
687                         }
688                         Pending => {
689                             self.state = Waiting(num_permits);
690                             Pending
691                         }
692                     }
693                 }
694             }
695         }
696     }
697 
698     /// Tries to acquire the permit.
try_acquire( &mut self, num_permits: u16, semaphore: &Semaphore, ) -> Result<(), TryAcquireError>699     pub(crate) fn try_acquire(
700         &mut self,
701         num_permits: u16,
702         semaphore: &Semaphore,
703     ) -> Result<(), TryAcquireError> {
704         use PermitState::*;
705 
706         match self.state {
707             Waiting(requested) => {
708                 // There must be a waiter
709                 let waiter = self.waiter.as_ref().unwrap();
710 
711                 if requested > num_permits {
712                     let delta = requested - num_permits;
713                     let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
714 
715                     semaphore.add_permits(to_release);
716                     self.state = Waiting(num_permits);
717                 }
718 
719                 let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
720 
721                 if res == 0 {
722                     if requested < num_permits {
723                         // Try to acquire the additional permits
724                         semaphore.try_acquire(num_permits - requested)?;
725                     }
726 
727                     self.state = Acquired(num_permits);
728                     Ok(())
729                 } else {
730                     Err(TryAcquireError::NoPermits)
731                 }
732             }
733             Acquired(acquired) => {
734                 if acquired < num_permits {
735                     semaphore.try_acquire(num_permits - acquired)?;
736                     self.state = Acquired(num_permits);
737                 }
738 
739                 Ok(())
740             }
741         }
742     }
743 
744     /// Releases a permit back to the semaphore
release(&mut self, n: u16, semaphore: &Semaphore)745     pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
746         let n = self.forget(n);
747         semaphore.add_permits(n as usize);
748     }
749 
750     /// Forgets the permit **without** releasing it back to the semaphore.
751     ///
752     /// After calling `forget`, `poll_acquire` is able to acquire new permit
753     /// from the semaphore.
754     ///
755     /// Repeatedly calling `forget` without associated calls to `add_permit`
756     /// will result in the semaphore losing all permits.
757     ///
758     /// Will forget **at most** the number of acquired permits. This number is
759     /// returned.
forget(&mut self, n: u16) -> u16760     pub(crate) fn forget(&mut self, n: u16) -> u16 {
761         use PermitState::*;
762 
763         match self.state {
764             Waiting(requested) => {
765                 let n = cmp::min(n, requested);
766 
767                 // Decrement
768                 let acquired = self
769                     .waiter
770                     .as_ref()
771                     .unwrap()
772                     .try_dec_permits_to_acquire(n as usize) as u16;
773 
774                 if n == requested {
775                     self.state = Acquired(0);
776                 } else if acquired == requested - n {
777                     self.state = Waiting(acquired);
778                 } else {
779                     self.state = Waiting(requested - n);
780                 }
781 
782                 acquired
783             }
784             Acquired(acquired) => {
785                 let n = cmp::min(n, acquired);
786                 self.state = Acquired(acquired - n);
787                 n
788             }
789         }
790     }
791 }
792 
793 impl Default for Permit {
default() -> Self794     fn default() -> Self {
795         Self::new()
796     }
797 }
798 
799 impl Drop for Permit {
drop(&mut self)800     fn drop(&mut self) {
801         if let Some(waiter) = self.waiter.take() {
802             // Set the dropped flag
803             let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
804 
805             if state.is_queued() {
806                 // The waiter is stored in the queue. The semaphore will drop it
807                 std::mem::forget(waiter);
808             }
809         }
810     }
811 }
812 
// ===== impl AcquireError =====
814 
815 impl AcquireError {
closed() -> AcquireError816     fn closed() -> AcquireError {
817         AcquireError(())
818     }
819 }
820 
to_try_acquire(_: AcquireError) -> TryAcquireError821 fn to_try_acquire(_: AcquireError) -> TryAcquireError {
822     TryAcquireError::Closed
823 }
824 
825 impl fmt::Display for AcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result826     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
827         write!(fmt, "semaphore closed")
828     }
829 }
830 
831 impl std::error::Error for AcquireError {}
832 
833 // ===== impl TryAcquireError =====
834 
835 impl TryAcquireError {
836     /// Returns `true` if the error was caused by a closed semaphore.
is_closed(&self) -> bool837     pub(crate) fn is_closed(&self) -> bool {
838         match self {
839             TryAcquireError::Closed => true,
840             _ => false,
841         }
842     }
843 
844     /// Returns `true` if the error was caused by calling `try_acquire` on a
845     /// semaphore with no available permits.
is_no_permits(&self) -> bool846     pub(crate) fn is_no_permits(&self) -> bool {
847         match self {
848             TryAcquireError::NoPermits => true,
849             _ => false,
850         }
851     }
852 }
853 
854 impl fmt::Display for TryAcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result855     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
856         match self {
857             TryAcquireError::Closed => write!(fmt, "semaphore closed"),
858             TryAcquireError::NoPermits => write!(fmt, "no permits available"),
859         }
860     }
861 }
862 
863 impl std::error::Error for TryAcquireError {}
864 
865 // ===== impl Waiter =====
866 
867 impl Waiter {
new() -> Waiter868     fn new() -> Waiter {
869         Waiter {
870             state: AtomicUsize::new(0),
871             waker: AtomicWaker::new(),
872             next: AtomicPtr::new(ptr::null_mut()),
873         }
874     }
875 
permits_to_acquire(&self) -> Result<usize, AcquireError>876     fn permits_to_acquire(&self) -> Result<usize, AcquireError> {
877         let state = WaiterState(self.state.load(Acquire));
878 
879         if state.is_closed() {
880             Err(AcquireError(()))
881         } else {
882             Ok(state.permits_to_acquire())
883         }
884     }
885 
886     /// Only increments the number of permits *if* the waiter is currently
887     /// queued.
888     ///
889     /// # Returns
890     ///
891     /// `true` if the number of permits to acquire has been incremented. `false`
892     /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`.
try_inc_permits_to_acquire(&self, n: usize) -> bool893     fn try_inc_permits_to_acquire(&self, n: usize) -> bool {
894         let mut curr = WaiterState(self.state.load(Acquire));
895 
896         loop {
897             if !curr.is_queued() {
898                 assert_eq!(0, curr.permits_to_acquire());
899                 return false;
900             }
901 
902             let mut next = curr;
903             next.set_permits_to_acquire(n + curr.permits_to_acquire());
904 
905             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
906                 Ok(_) => return true,
907                 Err(actual) => curr = WaiterState(actual),
908             }
909         }
910     }
911 
912     /// Try to decrement the number of permits to acquire. This returns the
913     /// actual number of permits that were decremented. The delta between `n`
914     /// and the return has been assigned to the permit and the caller must
915     /// assign these back to the semaphore.
try_dec_permits_to_acquire(&self, n: usize) -> usize916     fn try_dec_permits_to_acquire(&self, n: usize) -> usize {
917         let mut curr = WaiterState(self.state.load(Acquire));
918 
919         loop {
920             if curr.is_closed() {
921                 return 0;
922             }
923 
924             if !curr.is_queued() {
925                 assert_eq!(0, curr.permits_to_acquire());
926             }
927 
928             let delta = cmp::min(n, curr.permits_to_acquire());
929             let rem = curr.permits_to_acquire() - delta;
930 
931             let mut next = curr;
932             next.set_permits_to_acquire(rem);
933 
934             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
935                 Ok(_) => return n - delta,
936                 Err(actual) => curr = WaiterState(actual),
937             }
938         }
939     }
940 
941     /// Store the number of remaining permits needed to satisfy the waiter and
942     /// transition to the "QUEUED" state.
943     ///
944     /// # Returns
945     ///
946     /// `true` if the `QUEUED` bit was set as part of the transition.
to_queued(&self, num_permits: usize) -> bool947     fn to_queued(&self, num_permits: usize) -> bool {
948         let mut curr = WaiterState(self.state.load(Acquire));
949 
950         // The waiter should **not** be waiting for any permits.
951         debug_assert_eq!(curr.permits_to_acquire(), 0);
952 
953         loop {
954             let mut next = curr;
955             next.set_permits_to_acquire(num_permits);
956             next.set_queued();
957 
958             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
959                 Ok(_) => {
960                     if curr.is_queued() {
961                         return false;
962                     } else {
963                         // Make sure the next pointer is null
964                         self.next.store(ptr::null_mut(), Relaxed);
965                         return true;
966                     }
967                 }
968                 Err(actual) => curr = WaiterState(actual),
969             }
970         }
971     }
972 
973     /// Set the number of permits to acquire.
974     ///
975     /// This function is only called when the waiter is being inserted into the
976     /// wait queue. Because of this, there are no concurrent threads that can
977     /// modify the state and using `store` is safe.
set_permits_to_acquire(&self, num_permits: usize)978     fn set_permits_to_acquire(&self, num_permits: usize) {
979         debug_assert!(WaiterState(self.state.load(Acquire)).is_queued());
980 
981         let mut state = WaiterState(QUEUED);
982         state.set_permits_to_acquire(num_permits);
983 
984         self.state.store(state.0, Release);
985     }
986 
987     /// Assign permits to the waiter.
988     ///
989     /// Returns `true` if the waiter should be removed from the queue
assign_permits(&self, n: &mut usize, closed: bool) -> bool990     fn assign_permits(&self, n: &mut usize, closed: bool) -> bool {
991         let mut curr = WaiterState(self.state.load(Acquire));
992 
993         loop {
994             let mut next = curr;
995 
996             // Number of permits to assign to this waiter
997             let assign = cmp::min(curr.permits_to_acquire(), *n);
998 
999             // Assign the permits
1000             next.set_permits_to_acquire(curr.permits_to_acquire() - assign);
1001 
1002             if closed {
1003                 next.set_closed();
1004             }
1005 
1006             match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
1007                 Ok(_) => {
1008                     // Update `n`
1009                     *n -= assign;
1010 
1011                     if next.permits_to_acquire() == 0 {
1012                         if curr.permits_to_acquire() > 0 {
1013                             self.waker.wake();
1014                         }
1015 
1016                         return true;
1017                     } else {
1018                         return false;
1019                     }
1020                 }
1021                 Err(actual) => curr = WaiterState(actual),
1022             }
1023         }
1024     }
1025 
revert_to_idle(&self)1026     fn revert_to_idle(&self) {
1027         // An idle node is not waiting on any permits
1028         self.state.store(0, Relaxed);
1029     }
1030 
store_next(&self, next: NonNull<Waiter>)1031     fn store_next(&self, next: NonNull<Waiter>) {
1032         self.next.store(next.as_ptr(), Release);
1033     }
1034 }
1035 
1036 // ===== impl SemState =====
1037 
1038 impl SemState {
1039     /// Returns a new default `State` value.
new(permits: usize, stub: &Waiter) -> SemState1040     fn new(permits: usize, stub: &Waiter) -> SemState {
1041         assert!(permits <= MAX_PERMITS);
1042 
1043         if permits > 0 {
1044             SemState((permits << NUM_SHIFT) | NUM_FLAG)
1045         } else {
1046             SemState(stub as *const _ as usize)
1047         }
1048     }
1049 
1050     /// Returns a `State` tracking `ptr` as the tail of the queue.
new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState1051     fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState {
1052         let mut val = tail.as_ptr() as usize;
1053 
1054         if closed {
1055             val |= CLOSED_FLAG;
1056         }
1057 
1058         SemState(val)
1059     }
1060 
1061     /// Returns the amount of remaining capacity
available_permits(self) -> usize1062     fn available_permits(self) -> usize {
1063         if !self.has_available_permits() {
1064             return 0;
1065         }
1066 
1067         self.0 >> NUM_SHIFT
1068     }
1069 
1070     /// Returns `true` if the state has permits that can be claimed by a waiter.
has_available_permits(self) -> bool1071     fn has_available_permits(self) -> bool {
1072         self.0 & NUM_FLAG == NUM_FLAG
1073     }
1074 
has_waiter(self, stub: &Waiter) -> bool1075     fn has_waiter(self, stub: &Waiter) -> bool {
1076         !self.has_available_permits() && !self.is_stub(stub)
1077     }
1078 
1079     /// Tries to atomically acquire specified number of permits.
1080     ///
1081     /// # Return
1082     ///
1083     /// Returns `true` if the specified number of permits were acquired, `false`
1084     /// otherwise. Returning false does not mean that there are no more
1085     /// available permits.
acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool1086     fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool {
1087         debug_assert!(num > 0);
1088 
1089         if self.available_permits() < num {
1090             return false;
1091         }
1092 
1093         debug_assert!(self.waiter().is_none());
1094 
1095         self.0 -= num << NUM_SHIFT;
1096 
1097         if self.0 == NUM_FLAG {
1098             // Set the state to the stub pointer.
1099             self.0 = stub as *const _ as usize;
1100         }
1101 
1102         true
1103     }
1104 
1105     /// Releases permits
1106     ///
1107     /// Returns `true` if the permits were accepted.
release_permits(&mut self, permits: usize, stub: &Waiter)1108     fn release_permits(&mut self, permits: usize, stub: &Waiter) {
1109         debug_assert!(permits > 0);
1110 
1111         if self.is_stub(stub) {
1112             self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
1113             return;
1114         }
1115 
1116         debug_assert!(self.has_available_permits());
1117 
1118         self.0 += permits << NUM_SHIFT;
1119     }
1120 
is_waiter(self) -> bool1121     fn is_waiter(self) -> bool {
1122         self.0 & NUM_FLAG == 0
1123     }
1124 
1125     /// Returns the waiter, if one is set.
waiter(self) -> Option<NonNull<Waiter>>1126     fn waiter(self) -> Option<NonNull<Waiter>> {
1127         if self.is_waiter() {
1128             let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
1129 
1130             Some(waiter)
1131         } else {
1132             None
1133         }
1134     }
1135 
1136     /// Assumes `self` represents a pointer
as_ptr(self) -> *mut Waiter1137     fn as_ptr(self) -> *mut Waiter {
1138         (self.0 & !CLOSED_FLAG) as *mut Waiter
1139     }
1140 
1141     /// Sets to a pointer to a waiter.
1142     ///
1143     /// This can only be done from the full state.
set_waiter(&mut self, waiter: NonNull<Waiter>)1144     fn set_waiter(&mut self, waiter: NonNull<Waiter>) {
1145         let waiter = waiter.as_ptr() as usize;
1146         debug_assert!(!self.is_closed());
1147 
1148         self.0 = waiter;
1149     }
1150 
is_stub(self, stub: &Waiter) -> bool1151     fn is_stub(self, stub: &Waiter) -> bool {
1152         self.as_ptr() as usize == stub as *const _ as usize
1153     }
1154 
1155     /// Loads the state from an AtomicUsize.
load(cell: &AtomicUsize, ordering: Ordering) -> SemState1156     fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1157         let value = cell.load(ordering);
1158         SemState(value)
1159     }
1160 
fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState1161     fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1162         let value = cell.fetch_or(CLOSED_FLAG, ordering);
1163         SemState(value)
1164     }
1165 
is_closed(self) -> bool1166     fn is_closed(self) -> bool {
1167         self.0 & CLOSED_FLAG == CLOSED_FLAG
1168     }
1169 
1170     /// Converts the state into a `usize` representation.
to_usize(self) -> usize1171     fn to_usize(self) -> usize {
1172         self.0
1173     }
1174 }
1175 
1176 impl fmt::Debug for SemState {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result1177     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1178         let mut fmt = fmt.debug_struct("SemState");
1179 
1180         if self.is_waiter() {
1181             fmt.field("state", &"<waiter>");
1182         } else {
1183             fmt.field("permits", &self.available_permits());
1184         }
1185 
1186         fmt.finish()
1187     }
1188 }
1189 
1190 // ===== impl WaiterState =====
1191 
1192 impl WaiterState {
permits_to_acquire(self) -> usize1193     fn permits_to_acquire(self) -> usize {
1194         self.0 >> PERMIT_SHIFT
1195     }
1196 
set_permits_to_acquire(&mut self, val: usize)1197     fn set_permits_to_acquire(&mut self, val: usize) {
1198         self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK)
1199     }
1200 
is_queued(self) -> bool1201     fn is_queued(self) -> bool {
1202         self.0 & QUEUED == QUEUED
1203     }
1204 
set_queued(&mut self)1205     fn set_queued(&mut self) {
1206         self.0 |= QUEUED;
1207     }
1208 
is_closed(self) -> bool1209     fn is_closed(self) -> bool {
1210         self.0 & CLOSED == CLOSED
1211     }
1212 
set_closed(&mut self)1213     fn set_closed(&mut self) {
1214         self.0 |= CLOSED;
1215     }
1216 
unset_queued(&mut self)1217     fn unset_queued(&mut self) {
1218         assert!(self.is_queued());
1219         self.0 -= QUEUED;
1220     }
1221 
is_dropped(self) -> bool1222     fn is_dropped(self) -> bool {
1223         self.0 & DROPPED == DROPPED
1224     }
1225 }
1226