1 #![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
2
3 //! Thread-safe, asynchronous counting semaphore.
4 //!
5 //! A `Semaphore` instance holds a set of permits. Permits are used to
6 //! synchronize access to a shared resource.
7 //!
8 //! Before accessing the shared resource, callers acquire a permit from the
9 //! semaphore. Once the permit is acquired, the caller then enters the critical
10 //! section. If no permits are available, then acquiring the semaphore returns
11 //! `Pending`. The task is woken once a permit becomes available.
12
13 use crate::loom::cell::UnsafeCell;
14 use crate::loom::future::AtomicWaker;
15 use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
16 use crate::loom::thread;
17
18 use std::cmp;
19 use std::fmt;
20 use std::ptr::{self, NonNull};
21 use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
22 use std::task::Poll::{Pending, Ready};
23 use std::task::{Context, Poll};
24 use std::usize;
25
26 /// Futures-aware semaphore.
27 pub(crate) struct Semaphore {
28 /// Tracks both the waiter queue tail pointer and the number of remaining
29 /// permits.
30 state: AtomicUsize,
31
32 /// waiter queue head pointer.
33 head: UnsafeCell<NonNull<Waiter>>,
34
35 /// Coordinates access to the queue head.
36 rx_lock: AtomicUsize,
37
38 /// Stub waiter node used as part of the MPSC channel algorithm.
39 stub: Box<Waiter>,
40 }
41
42 /// A semaphore permit
43 ///
44 /// Tracks the lifecycle of a semaphore permit.
45 ///
46 /// An instance of `Permit` is intended to be used with a **single** instance of
47 /// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
48 /// instances will result in unexpected behavior.
49 ///
50 /// `Permit` does **not** release the permit back to the semaphore on drop. It
51 /// is the user's responsibility to ensure that `Permit::release` is called
52 /// before dropping the permit.
53 #[derive(Debug)]
54 pub(crate) struct Permit {
55 waiter: Option<Box<Waiter>>,
56 state: PermitState,
57 }
58
/// Error returned by `Permit::poll_acquire` when the semaphore is closed.
#[derive(Debug)]
pub(crate) struct AcquireError(());

/// Error returned by `Permit::try_acquire`.
#[derive(Debug)]
pub(crate) enum TryAcquireError {
    /// The semaphore has been closed.
    Closed,
    /// Not enough permits were available to satisfy the request.
    NoPermits,
}
69
70 /// Node used to notify the semaphore waiter when permit is available.
71 #[derive(Debug)]
72 struct Waiter {
73 /// Stores waiter state.
74 ///
75 /// See `WaiterState` for more details.
76 state: AtomicUsize,
77
78 /// Task to wake when a permit is made available.
79 waker: AtomicWaker,
80
81 /// Next pointer in the queue of waiting senders.
82 next: AtomicPtr<Waiter>,
83 }
84
/// Semaphore state, packed into a single `usize`.
///
/// The two low bits are flags:
///
/// - `NUM_FLAG` (bit 0): when set, the bits above `NUM_SHIFT` hold the number
///   of available permits. When clear, the value encodes a pointer to the tail
///   of the wait queue (the stub node when no task is waiting).
/// - `CLOSED_FLAG` (bit 1): the semaphore has been closed.
#[derive(Copy, Clone)]
struct SemState(usize);

/// Permit state
#[derive(Debug, Copy, Clone)]
enum PermitState {
    /// Waiting for permits to be assigned to the waiter node. Holds the total
    /// number of permits requested.
    Waiting(u16),

    /// Holds the number of permits currently acquired.
    Acquired(u16),
}

/// State of an individual waiter node, packed into a single `usize`.
///
/// The low bits hold the `QUEUED`, `CLOSED` and `DROPPED` flags. The number of
/// permits still to be acquired is stored from the `PERMIT_ONE` bit upwards.
#[derive(Debug, Copy, Clone)]
struct WaiterState(usize);
112
/// Waiter-state bit: the waiter node is linked into the semaphore queue.
const QUEUED: usize = 1;

/// Waiter-state bit: the semaphore has been closed and no more permits will be
/// issued.
const CLOSED: usize = 1 << 1;

/// Waiter-state bit: the `Permit` that owns the `Waiter` has been dropped.
const DROPPED: usize = 1 << 2;

/// Waiter-state unit representing "one requested permit". The permit count
/// occupies this bit and everything above it.
const PERMIT_ONE: usize = 1 << 3;

/// Masks a waiter state down to the bits that count requested permits.
const PERMIT_MASK: usize = !(PERMIT_ONE - 1);

/// Shift applied to a permit count to pack it into the waiter state.
const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();

/// Semaphore-state bit distinguishing a permit count from a waiter pointer.
///
/// Assuming waiter pointers are properly aligned, their least significant bit
/// is always zero. A set bit therefore means the value represents a number.
const NUM_FLAG: usize = 1;

/// Semaphore-state bit signalling that the semaphore is closed.
const CLOSED_FLAG: usize = 1 << 1;

/// How far a permit count is shifted within the semaphore state. This leaves
/// room for `NUM_FLAG` and `CLOSED_FLAG`.
const NUM_SHIFT: usize = 2;

/// Maximum number of permits a semaphore can manage.
const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
148
149 // ===== impl Semaphore =====
150
impl Semaphore {
    /// Creates a new semaphore with the initial number of permits
    ///
    /// A semaphore created with zero permits starts with an empty wait queue
    /// (its state points at the stub node).
    ///
    /// # Panics
    ///
    /// Panics if `permits` is greater than `MAX_PERMITS`.
    pub(crate) fn new(permits: usize) -> Semaphore {
        let stub = Box::new(Waiter::new());
        let ptr = NonNull::from(&*stub);

        // Allocations are aligned, so the low bit of the stub address is free
        // to serve as `NUM_FLAG`.
        debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);

        let state = SemState::new(permits, &stub);

        Semaphore {
            state: AtomicUsize::new(state.to_usize()),
            head: UnsafeCell::new(ptr),
            rx_lock: AtomicUsize::new(0),
            stub,
        }
    }

    /// Returns the current number of available permits
    ///
    /// This is 0 while the state holds a wait-queue pointer rather than a
    /// permit count.
    pub(crate) fn available_permits(&self) -> usize {
        let curr = SemState(self.state.load(Acquire));
        curr.available_permits()
    }

    /// Tries to acquire the requested number of permits, registering the waiter
    /// if not enough permits are available.
    ///
    /// The permit's waiter node is allocated on first use. The task's waker is
    /// registered on it before `poll_acquire2` tries to queue it.
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        num_permits: u16,
        permit: &mut Permit,
    ) -> Poll<Result<(), AcquireError>> {
        self.poll_acquire2(num_permits, || {
            let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));

            waiter.waker.register_by_ref(cx.waker());

            Some(NonNull::from(&**waiter))
        })
    }

    /// Attempts to acquire `num_permits` permits without ever queuing.
    ///
    /// The `get_waiter` closure always returns `None`, so `poll_acquire2`
    /// reports `Pending` instead of queuing. That result is mapped to
    /// `TryAcquireError::NoPermits`, and a closed semaphore is mapped to
    /// `TryAcquireError::Closed`.
    fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
        match self.poll_acquire2(num_permits, || None) {
            Poll::Ready(res) => res.map_err(to_try_acquire),
            Poll::Pending => Err(TryAcquireError::NoPermits),
        }
    }

    /// Polls for a permit
    ///
    /// Tries to acquire available permits first. If unable to acquire a
    /// sufficient number of permits, the caller's waiter is pushed onto the
    /// semaphore's wait queue.
    ///
    /// `get_waiter` is called at most once, the first time queuing is needed.
    /// Returning `None` from it means "do not wait", and `Pending` is returned
    /// without touching the queue.
    fn poll_acquire2<F>(
        &self,
        num_permits: u16,
        mut get_waiter: F,
    ) -> Poll<Result<(), AcquireError>>
    where
        F: FnMut() -> Option<NonNull<Waiter>>,
    {
        let num_permits = num_permits as usize;

        // Load the current state
        let mut curr = SemState(self.state.load(Acquire));

        // Saves a ref to the waiter node
        let mut maybe_waiter: Option<NonNull<Waiter>> = None;

        /// Used in branches where we attempt to push the waiter into the wait
        /// queue but fail due to permits becoming available or the wait queue
        /// transitioning to "closed". In this case, the waiter must be
        /// transitioned back to the "idle" state.
        macro_rules! revert_to_idle {
            () => {
                if let Some(waiter) = maybe_waiter {
                    unsafe { waiter.as_ref() }.revert_to_idle();
                }
            };
        }

        loop {
            let mut next = curr;

            if curr.is_closed() {
                revert_to_idle!();
                return Ready(Err(AcquireError::closed()));
            }

            let acquired = next.acquire_permits(num_permits, &self.stub);

            if !acquired {
                // There are not enough available permits to satisfy the
                // request. The permit transitions to a waiting state.
                debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);

                if let Some(waiter) = maybe_waiter.as_ref() {
                    // A previous iteration already marked the waiter QUEUED
                    // but lost the CAS below. Recompute how many permits it
                    // still needs against the freshly observed state.
                    //
                    // Safety: the caller owns the waiter.
                    let w = unsafe { waiter.as_ref() };
                    w.set_permits_to_acquire(num_permits - curr.available_permits());
                } else {
                    // Get the waiter for the permit.
                    if let Some(waiter) = get_waiter() {
                        // Safety: the caller owns the waiter.
                        let w = unsafe { waiter.as_ref() };

                        // If there are any currently available permits, the
                        // waiter acquires those immediately and waits for the
                        // remaining permits to become available.
                        if !w.to_queued(num_permits - curr.available_permits()) {
                            // The node is already queued, there is no further work
                            // to do.
                            return Pending;
                        }

                        maybe_waiter = Some(waiter);
                    } else {
                        // No waiter, this indicates the caller does not wish to
                        // "wait", so there is nothing left to do.
                        return Pending;
                    }
                }

                // Make the waiter the new tail of the queue in the proposed
                // state. Any permits that were available are accounted for by
                // the subtraction above.
                next.set_waiter(maybe_waiter.unwrap());
            }

            debug_assert_ne!(curr.0, 0);
            debug_assert_ne!(next.0, 0);

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    if acquired {
                        // Successfully acquire permits **without** queuing the
                        // waiter node. The waiter node is not currently in the
                        // queue.
                        revert_to_idle!();
                        return Ready(Ok(()));
                    } else {
                        // The node is pushed into the queue, the final step is
                        // to set the node's "next" pointer to return the wait
                        // queue into a consistent state.

                        let prev_waiter =
                            curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));

                        let waiter = maybe_waiter.unwrap();

                        // Link the nodes.
                        //
                        // Safety: the mpsc algorithm guarantees the old tail of
                        // the queue is not removed from the queue during the
                        // push process.
                        unsafe {
                            prev_waiter.as_ref().store_next(waiter);
                        }

                        return Pending;
                    }
                }
                Err(actual) => {
                    curr = SemState(actual);
                }
            }
        }
    }

    /// Closes the semaphore. This prevents the semaphore from issuing new
    /// permits and notifies all pending waiters.
    ///
    /// Sets bit 0 of `rx_lock`. If the lock was already held, the holder sees
    /// the bit in `add_permits_locked` and performs the close on our behalf.
    pub(crate) fn close(&self) {
        // Acquire the `rx_lock`, setting the "closed" flag on the lock.
        let prev = self.rx_lock.fetch_or(1, AcqRel);

        if prev != 0 {
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(0, true);
    }

    /// Adds `n` new permits to the semaphore.
    ///
    /// The permits are first accumulated in `rx_lock` as `n << 1`, which
    /// leaves bit 0 for the close flag. The thread that moves `rx_lock` away
    /// from zero becomes the lock holder and hands permits out to waiters. Any
    /// other thread leaves its permits for that holder to pick up.
    pub(crate) fn add_permits(&self, n: usize) {
        if n == 0 {
            return;
        }

        // TODO: Handle overflow. A panic is not sufficient, the process must
        // abort.
        let prev = self.rx_lock.fetch_add(n << 1, AcqRel);

        if prev != 0 {
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(n, false);
    }

    /// Drains the work accumulated in `rx_lock` while this thread holds it.
    ///
    /// `rem` is the number of permits this thread must distribute, and `closed`
    /// says whether a close must be performed. After each round, the handled
    /// amount is subtracted from `rx_lock`. Permits or a close request added by
    /// other threads in the meantime are picked up on the next iteration. The
    /// loop ends once nothing is left.
    fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
        while rem > 0 || closed {
            if closed {
                SemState::fetch_set_closed(&self.state, AcqRel);
            }

            // Release the permits and notify
            self.add_permits_locked2(rem, closed);

            let n = rem << 1;

            let actual = if closed {
                let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
                closed = false;
                actual
            } else {
                let actual = self.rx_lock.fetch_sub(n, AcqRel);
                closed = actual & 1 == 1;
                actual
            };

            rem = (actual >> 1) - rem;
        }
    }

    /// Releases a specific amount of permits to the semaphore
    ///
    /// This function is called by `add_permits_locked` (on behalf of
    /// `add_permits` or `close`) while the calling thread holds `rx_lock`.
    ///
    /// It walks the wait queue from `head`, assigning up to `n` permits to each
    /// waiter in order. Fully satisfied waiters are unlinked (see
    /// `remove_queued`). Permits left over once the queue is empty are added
    /// back to `state`. When `closed` is set, the whole queue is drained and
    /// every visited waiter is marked closed.
    fn add_permits_locked2(&self, mut n: usize, closed: bool) {
        // If closing the semaphore, we want to drain the entire queue. The
        // number of permits being assigned doesn't matter.
        if closed {
            n = usize::MAX;
        }

        'outer: while n > 0 {
            unsafe {
                let mut head = self.head.with(|head| *head);
                let mut next_ptr = head.as_ref().next.load(Acquire);

                let stub = self.stub();

                if head == stub {
                    // The stub node indicates an empty queue. Any remaining
                    // permits get assigned back to the semaphore.
                    let next = match NonNull::new(next_ptr) {
                        Some(next) => next,
                        None => {
                            // This loop is not part of the standard intrusive mpsc
                            // channel algorithm. This is where we atomically pop
                            // the last task and add `n` to the remaining capacity.
                            //
                            // This modification to the pop algorithm works because,
                            // at this point, we have not done any work (only done
                            // reading). We have a *pretty* good idea that there is
                            // no concurrent pusher.
                            //
                            // The capacity is then atomically added by doing an
                            // AcqRel CAS on `state`. The `state` cell is the
                            // linchpin of the algorithm.
                            //
                            // By successfully CASing `head` w/ AcqRel, we ensure
                            // that, if any thread was racing and entered a push, we
                            // see that and abort pop, retrying as it is
                            // "inconsistent".
                            let mut curr = SemState::load(&self.state, Acquire);

                            loop {
                                if curr.has_waiter(&self.stub) {
                                    // A waiter is being added concurrently.
                                    // This is the MPSC queue's "inconsistent"
                                    // state and we must loop and try again.
                                    thread::yield_now();
                                    continue 'outer;
                                }

                                // If closing, nothing more to do.
                                if closed {
                                    debug_assert!(curr.is_closed(), "state = {:?}", curr);
                                    return;
                                }

                                let mut next = curr;
                                next.release_permits(n, &self.stub);

                                match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                                    Ok(_) => return,
                                    Err(actual) => {
                                        curr = SemState(actual);
                                    }
                                }
                            }
                        }
                    };

                    // Skip over the stub: the first real waiter becomes `head`.
                    self.head.with_mut(|head| *head = next);
                    head = next;
                    next_ptr = next.as_ref().next.load(Acquire);
                }

                // `head` points to a waiter assign permits to the waiter. If
                // all requested permits are satisfied, then we can continue,
                // otherwise the node stays in the wait queue.
                if !head.as_ref().assign_permits(&mut n, closed) {
                    assert_eq!(n, 0);
                    return;
                }

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    self.remove_queued(head, closed);
                    continue 'outer;
                }

                let state = SemState::load(&self.state, Acquire);

                // This must always be a pointer as the wait list is not empty.
                let tail = state.waiter().unwrap();

                if tail != head {
                    // Inconsistent: a pusher has swapped the tail but has not
                    // yet linked `head.next`. Back off and retry.
                    thread::yield_now();
                    continue 'outer;
                }

                // `head` is the last node. Push the stub behind it so `head`
                // can be unlinked while the queue stays well-formed.
                self.push_stub(closed);

                next_ptr = head.as_ref().next.load(Acquire);

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    self.remove_queued(head, closed);
                    continue 'outer;
                }

                // Inconsistent state, loop
                thread::yield_now();
            }
        }
    }

    /// The wait node has had all of its permits assigned and has been removed
    /// from the wait queue.
    ///
    /// Attempt to remove the QUEUED bit from the node. If additional permits
    /// are concurrently requested, the node must be pushed back into the wait
    /// queued.
    ///
    /// If the owning `Permit` has been dropped, the node's memory is freed
    /// here.
    fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
        let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));

        loop {
            if curr.is_dropped() {
                // The Permit dropped, it is on us to release the memory
                let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
                return;
            }

            // The node is removed from the queue. We attempt to unset the
            // queued bit, but concurrently the waiter has requested more
            // permits. When the waiter requested more permits, it saw the
            // queued bit set so took no further action. This requires us to
            // push the node back into the queue.
            if curr.permits_to_acquire() > 0 {
                // More permits are requested. The waiter must be re-queued
                unsafe {
                    self.push_waiter(waiter, closed);
                }
                return;
            }

            let mut next = curr;
            next.unset_queued();

            let w = unsafe { waiter.as_ref() };

            match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return,
                Err(actual) => {
                    curr = WaiterState(actual);
                }
            }
        }
    }

    /// Pushes the stub node onto the tail of the wait queue.
    ///
    /// # Safety
    ///
    /// Same requirements as `push_waiter`.
    unsafe fn push_stub(&self, closed: bool) {
        self.push_waiter(self.stub(), closed);
    }

    /// Pushes `waiter` onto the tail of the wait queue.
    ///
    /// # Safety
    ///
    /// NOTE(review): inferred from the call sites, confirm before relying on
    /// it. `waiter` must point to a live `Waiter` that is not currently linked
    /// into the queue. Both call sites run inside `add_permits_locked2`, i.e.
    /// with `rx_lock` held. The state must already be in pointer mode (see the
    /// `unwrap` below).
    unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
        // Set the next pointer. This does not require an atomic operation as
        // this node is not accessible. The write will be flushed with the next
        // operation
        waiter.as_ref().next.store(ptr::null_mut(), Relaxed);

        // Update the tail to point to the new node. We need to see the previous
        // node in order to update the next pointer as well as release `task`
        // to any other threads calling `push`.
        let next = SemState::new_ptr(waiter, closed);
        let prev = SemState(self.state.swap(next.0, AcqRel));

        debug_assert_eq!(closed, prev.is_closed());

        // This function is only called when there are pending tasks. Because of
        // this, the state must *always* be in pointer mode.
        let prev = prev.waiter().unwrap();

        // No cycles plz
        debug_assert_ne!(prev, waiter);

        // Release `task` to the consume end.
        prev.as_ref().next.store(waiter.as_ptr(), Release);
    }

    /// Returns a `NonNull` pointer to the stub node.
    fn stub(&self) -> NonNull<Waiter> {
        // SAFETY: the pointer is derived from a reference into the live
        // `Box<Waiter>`, so it is never null.
        unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
    }
}
576
577 impl Drop for Semaphore {
drop(&mut self)578 fn drop(&mut self) {
579 self.close();
580 }
581 }
582
583 impl fmt::Debug for Semaphore {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result584 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
585 fmt.debug_struct("Semaphore")
586 .field("state", &SemState::load(&self.state, Relaxed))
587 .field("head", &self.head.with(|ptr| ptr))
588 .field("rx_lock", &self.rx_lock.load(Relaxed))
589 .field("stub", &self.stub)
590 .finish()
591 }
592 }
593
// SAFETY: `Semaphore` holds raw `NonNull<Waiter>` pointers (inside `head`),
// which opt it out of the auto `Send`/`Sync` impls. The `head` cell is only
// read or written from `add_permits_locked2`, which runs solely on the thread
// that took `rx_lock` (observed it at 0 in `close`/`add_permits`). All other
// shared state is accessed through atomics.
// NOTE(review): this also relies on `AtomicWaker` being safe to share across
// threads; confirm.
unsafe impl Send for Semaphore {}
unsafe impl Sync for Semaphore {}
596
597 // ===== impl Permit =====
598
599 impl Permit {
600 /// Creates a new `Permit`.
601 ///
602 /// The permit begins in the "unacquired" state.
new() -> Permit603 pub(crate) fn new() -> Permit {
604 use PermitState::Acquired;
605
606 Permit {
607 waiter: None,
608 state: Acquired(0),
609 }
610 }
611
612 /// Returns `true` if the permit has been acquired
613 #[allow(dead_code)] // may be used later
is_acquired(&self) -> bool614 pub(crate) fn is_acquired(&self) -> bool {
615 match self.state {
616 PermitState::Acquired(num) if num > 0 => true,
617 _ => false,
618 }
619 }
620
621 /// Tries to acquire the permit. If no permits are available, the current task
622 /// is notified once a new permit becomes available.
poll_acquire( &mut self, cx: &mut Context<'_>, num_permits: u16, semaphore: &Semaphore, ) -> Poll<Result<(), AcquireError>>623 pub(crate) fn poll_acquire(
624 &mut self,
625 cx: &mut Context<'_>,
626 num_permits: u16,
627 semaphore: &Semaphore,
628 ) -> Poll<Result<(), AcquireError>> {
629 use std::cmp::Ordering::*;
630 use PermitState::*;
631
632 match self.state {
633 Waiting(requested) => {
634 // There must be a waiter
635 let waiter = self.waiter.as_ref().unwrap();
636
637 match requested.cmp(&num_permits) {
638 Less => {
639 let delta = num_permits - requested;
640
641 // Request additional permits. If the waiter has been
642 // dequeued, it must be re-queued.
643 if !waiter.try_inc_permits_to_acquire(delta as usize) {
644 let waiter = NonNull::from(&**waiter);
645
646 // Ignore the result. The check for
647 // `permits_to_acquire()` will converge the state as
648 // needed
649 let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
650 }
651
652 self.state = Waiting(num_permits);
653 }
654 Greater => {
655 let delta = requested - num_permits;
656 let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
657
658 semaphore.add_permits(to_release);
659 self.state = Waiting(num_permits);
660 }
661 Equal => {}
662 }
663
664 if waiter.permits_to_acquire()? == 0 {
665 self.state = Acquired(requested);
666 return Ready(Ok(()));
667 }
668
669 waiter.waker.register_by_ref(cx.waker());
670
671 if waiter.permits_to_acquire()? == 0 {
672 self.state = Acquired(requested);
673 return Ready(Ok(()));
674 }
675
676 Pending
677 }
678 Acquired(acquired) => {
679 if acquired >= num_permits {
680 Ready(Ok(()))
681 } else {
682 match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
683 Ready(()) => {
684 self.state = Acquired(num_permits);
685 Ready(Ok(()))
686 }
687 Pending => {
688 self.state = Waiting(num_permits);
689 Pending
690 }
691 }
692 }
693 }
694 }
695 }
696
697 /// Tries to acquire the permit.
try_acquire( &mut self, num_permits: u16, semaphore: &Semaphore, ) -> Result<(), TryAcquireError>698 pub(crate) fn try_acquire(
699 &mut self,
700 num_permits: u16,
701 semaphore: &Semaphore,
702 ) -> Result<(), TryAcquireError> {
703 use PermitState::*;
704
705 match self.state {
706 Waiting(requested) => {
707 // There must be a waiter
708 let waiter = self.waiter.as_ref().unwrap();
709
710 if requested > num_permits {
711 let delta = requested - num_permits;
712 let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
713
714 semaphore.add_permits(to_release);
715 self.state = Waiting(num_permits);
716 }
717
718 let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
719
720 if res == 0 {
721 if requested < num_permits {
722 // Try to acquire the additional permits
723 semaphore.try_acquire(num_permits - requested)?;
724 }
725
726 self.state = Acquired(num_permits);
727 Ok(())
728 } else {
729 Err(TryAcquireError::NoPermits)
730 }
731 }
732 Acquired(acquired) => {
733 if acquired < num_permits {
734 semaphore.try_acquire(num_permits - acquired)?;
735 self.state = Acquired(num_permits);
736 }
737
738 Ok(())
739 }
740 }
741 }
742
743 /// Releases a permit back to the semaphore
release(&mut self, n: u16, semaphore: &Semaphore)744 pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
745 let n = self.forget(n);
746 semaphore.add_permits(n as usize);
747 }
748
749 /// Forgets the permit **without** releasing it back to the semaphore.
750 ///
751 /// After calling `forget`, `poll_acquire` is able to acquire new permit
752 /// from the sempahore.
753 ///
754 /// Repeatedly calling `forget` without associated calls to `add_permit`
755 /// will result in the semaphore losing all permits.
756 ///
757 /// Will forget **at most** the number of acquired permits. This number is
758 /// returned.
forget(&mut self, n: u16) -> u16759 pub(crate) fn forget(&mut self, n: u16) -> u16 {
760 use PermitState::*;
761
762 match self.state {
763 Waiting(requested) => {
764 let n = cmp::min(n, requested);
765
766 // Decrement
767 let acquired = self
768 .waiter
769 .as_ref()
770 .unwrap()
771 .try_dec_permits_to_acquire(n as usize) as u16;
772
773 if n == requested {
774 self.state = Acquired(0);
775 } else if acquired == requested - n {
776 self.state = Waiting(acquired);
777 } else {
778 self.state = Waiting(requested - n);
779 }
780
781 acquired
782 }
783 Acquired(acquired) => {
784 let n = cmp::min(n, acquired);
785 self.state = Acquired(acquired - n);
786 n
787 }
788 }
789 }
790 }
791
792 impl Default for Permit {
default() -> Self793 fn default() -> Self {
794 Self::new()
795 }
796 }
797
798 impl Drop for Permit {
drop(&mut self)799 fn drop(&mut self) {
800 if let Some(waiter) = self.waiter.take() {
801 // Set the dropped flag
802 let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
803
804 if state.is_queued() {
805 // The waiter is stored in the queue. The semaphore will drop it
806 std::mem::forget(waiter);
807 }
808 }
809 }
810 }
811
812 // ===== impl AcquireError ====
813
814 impl AcquireError {
closed() -> AcquireError815 fn closed() -> AcquireError {
816 AcquireError(())
817 }
818 }
819
to_try_acquire(_: AcquireError) -> TryAcquireError820 fn to_try_acquire(_: AcquireError) -> TryAcquireError {
821 TryAcquireError::Closed
822 }
823
824 impl fmt::Display for AcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result825 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
826 write!(fmt, "semaphore closed")
827 }
828 }
829
830 impl std::error::Error for AcquireError {}
831
832 // ===== impl TryAcquireError =====
833
834 impl TryAcquireError {
835 /// Returns `true` if the error was caused by a closed semaphore.
is_closed(&self) -> bool836 pub(crate) fn is_closed(&self) -> bool {
837 match self {
838 TryAcquireError::Closed => true,
839 _ => false,
840 }
841 }
842
843 /// Returns `true` if the error was caused by calling `try_acquire` on a
844 /// semaphore with no available permits.
is_no_permits(&self) -> bool845 pub(crate) fn is_no_permits(&self) -> bool {
846 match self {
847 TryAcquireError::NoPermits => true,
848 _ => false,
849 }
850 }
851 }
852
853 impl fmt::Display for TryAcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result854 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
855 match self {
856 TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"),
857 TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"),
858 }
859 }
860 }
861
862 impl std::error::Error for TryAcquireError {}
863
864 // ===== impl Waiter =====
865
impl Waiter {
    /// Creates an idle waiter: state 0 (no flags, no permits requested) and a
    /// null `next` pointer.
    fn new() -> Waiter {
        Waiter {
            state: AtomicUsize::new(0),
            waker: AtomicWaker::new(),
            next: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Returns the number of permits this waiter is still waiting for.
    ///
    /// Returns `Err(AcquireError)` once the waiter has been marked closed.
    fn permits_to_acquire(&self) -> Result<usize, AcquireError> {
        let state = WaiterState(self.state.load(Acquire));

        if state.is_closed() {
            Err(AcquireError(()))
        } else {
            Ok(state.permits_to_acquire())
        }
    }

    /// Only increments the number of permits *if* the waiter is currently
    /// queued.
    ///
    /// # Returns
    ///
    /// `true` if the number of permits to acquire has been incremented. `false`
    /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`.
    fn try_inc_permits_to_acquire(&self, n: usize) -> bool {
        let mut curr = WaiterState(self.state.load(Acquire));

        loop {
            if !curr.is_queued() {
                // An unqueued waiter has nothing outstanding.
                assert_eq!(0, curr.permits_to_acquire());
                return false;
            }

            let mut next = curr;
            next.set_permits_to_acquire(n + curr.permits_to_acquire());

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return true,
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Try to decrement the number of permits to acquire. This returns the
    /// delta between `n` and the number of permits actually removed from the
    /// outstanding request. Those permits have already been assigned to the
    /// permit, and the caller must assign them back to the semaphore.
    fn try_dec_permits_to_acquire(&self, n: usize) -> usize {
        let mut curr = WaiterState(self.state.load(Acquire));

        loop {
            if !curr.is_queued() {
                assert_eq!(0, curr.permits_to_acquire());
            }

            // Only the still-outstanding part of the request can be cancelled.
            let delta = cmp::min(n, curr.permits_to_acquire());
            let rem = curr.permits_to_acquire() - delta;

            let mut next = curr;
            next.set_permits_to_acquire(rem);

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return n - delta,
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Store the number of remaining permits needed to satisfy the waiter and
    /// transition to the "QUEUED" state.
    ///
    /// # Returns
    ///
    /// `true` if the `QUEUED` bit was set as part of the transition.
    fn to_queued(&self, num_permits: usize) -> bool {
        let mut curr = WaiterState(self.state.load(Acquire));

        // The waiter should **not** be waiting for any permits.
        debug_assert_eq!(curr.permits_to_acquire(), 0);

        loop {
            let mut next = curr;
            next.set_permits_to_acquire(num_permits);
            next.set_queued();

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    if curr.is_queued() {
                        return false;
                    } else {
                        // Make sure the next pointer is null
                        self.next.store(ptr::null_mut(), Relaxed);
                        return true;
                    }
                }
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Set the number of permits to acquire.
    ///
    /// This function is only called when the waiter is being inserted into the
    /// wait queue. Because of this, there are no concurrent threads that can
    /// modify the state and using `store` is safe.
    ///
    /// The stored value is `QUEUED` plus the permit count. Any other flag bits
    /// are overwritten.
    fn set_permits_to_acquire(&self, num_permits: usize) {
        debug_assert!(WaiterState(self.state.load(Acquire)).is_queued());

        let mut state = WaiterState(QUEUED);
        state.set_permits_to_acquire(num_permits);

        self.state.store(state.0, Release);
    }

    /// Assign permits to the waiter.
    ///
    /// Assigns up to `*n` permits and decrements `*n` by the amount assigned.
    /// When `closed` is set, the waiter is also marked closed. The task is woken
    /// when this call completes an outstanding request.
    ///
    /// Returns `true` if the waiter should be removed from the queue
    fn assign_permits(&self, n: &mut usize, closed: bool) -> bool {
        let mut curr = WaiterState(self.state.load(Acquire));

        loop {
            let mut next = curr;

            // Number of permits to assign to this waiter
            let assign = cmp::min(curr.permits_to_acquire(), *n);

            // Assign the permits
            next.set_permits_to_acquire(curr.permits_to_acquire() - assign);

            if closed {
                next.set_closed();
            }

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    // Update `n`
                    *n -= assign;

                    if next.permits_to_acquire() == 0 {
                        // Only wake if this call is what completed the request.
                        if curr.permits_to_acquire() > 0 {
                            self.waker.wake();
                        }

                        return true;
                    } else {
                        return false;
                    }
                }
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Resets the waiter to the idle state (0), clearing `QUEUED` and any
    /// requested permits.
    ///
    /// NOTE(review): the `Relaxed` store presumably relies on this node not
    /// having been linked into the queue by the calling `poll_acquire2`;
    /// confirm.
    fn revert_to_idle(&self) {
        // An idle node is not waiting on any permits
        self.state.store(0, Relaxed);
    }

    /// Links `next` after this node. The `Release` store pairs with the
    /// `Acquire` loads of `next` in `Semaphore::add_permits_locked2`.
    fn store_next(&self, next: NonNull<Waiter>) {
        self.next.store(next.as_ptr(), Release);
    }
}
1030
1031 // ===== impl SemState =====
1032
1033 impl SemState {
1034 /// Returns a new default `State` value.
new(permits: usize, stub: &Waiter) -> SemState1035 fn new(permits: usize, stub: &Waiter) -> SemState {
1036 assert!(permits <= MAX_PERMITS);
1037
1038 if permits > 0 {
1039 SemState((permits << NUM_SHIFT) | NUM_FLAG)
1040 } else {
1041 SemState(stub as *const _ as usize)
1042 }
1043 }
1044
1045 /// Returns a `State` tracking `ptr` as the tail of the queue.
new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState1046 fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState {
1047 let mut val = tail.as_ptr() as usize;
1048
1049 if closed {
1050 val |= CLOSED_FLAG;
1051 }
1052
1053 SemState(val)
1054 }
1055
1056 /// Returns the amount of remaining capacity
available_permits(self) -> usize1057 fn available_permits(self) -> usize {
1058 if !self.has_available_permits() {
1059 return 0;
1060 }
1061
1062 self.0 >> NUM_SHIFT
1063 }
1064
1065 /// Returns `true` if the state has permits that can be claimed by a waiter.
has_available_permits(self) -> bool1066 fn has_available_permits(self) -> bool {
1067 self.0 & NUM_FLAG == NUM_FLAG
1068 }
1069
has_waiter(self, stub: &Waiter) -> bool1070 fn has_waiter(self, stub: &Waiter) -> bool {
1071 !self.has_available_permits() && !self.is_stub(stub)
1072 }
1073
1074 /// Tries to atomically acquire specified number of permits.
1075 ///
1076 /// # Return
1077 ///
1078 /// Returns `true` if the specified number of permits were acquired, `false`
1079 /// otherwise. Returning false does not mean that there are no more
1080 /// available permits.
acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool1081 fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool {
1082 debug_assert!(num > 0);
1083
1084 if self.available_permits() < num {
1085 return false;
1086 }
1087
1088 debug_assert!(self.waiter().is_none());
1089
1090 self.0 -= num << NUM_SHIFT;
1091
1092 if self.0 == NUM_FLAG {
1093 // Set the state to the stub pointer.
1094 self.0 = stub as *const _ as usize;
1095 }
1096
1097 true
1098 }
1099
1100 /// Releases permits
1101 ///
1102 /// Returns `true` if the permits were accepted.
release_permits(&mut self, permits: usize, stub: &Waiter)1103 fn release_permits(&mut self, permits: usize, stub: &Waiter) {
1104 debug_assert!(permits > 0);
1105
1106 if self.is_stub(stub) {
1107 self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
1108 return;
1109 }
1110
1111 debug_assert!(self.has_available_permits());
1112
1113 self.0 += permits << NUM_SHIFT;
1114 }
1115
is_waiter(self) -> bool1116 fn is_waiter(self) -> bool {
1117 self.0 & NUM_FLAG == 0
1118 }
1119
1120 /// Returns the waiter, if one is set.
waiter(self) -> Option<NonNull<Waiter>>1121 fn waiter(self) -> Option<NonNull<Waiter>> {
1122 if self.is_waiter() {
1123 let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
1124
1125 Some(waiter)
1126 } else {
1127 None
1128 }
1129 }
1130
1131 /// Assumes `self` represents a pointer
as_ptr(self) -> *mut Waiter1132 fn as_ptr(self) -> *mut Waiter {
1133 (self.0 & !CLOSED_FLAG) as *mut Waiter
1134 }
1135
1136 /// Sets to a pointer to a waiter.
1137 ///
1138 /// This can only be done from the full state.
set_waiter(&mut self, waiter: NonNull<Waiter>)1139 fn set_waiter(&mut self, waiter: NonNull<Waiter>) {
1140 let waiter = waiter.as_ptr() as usize;
1141 debug_assert!(!self.is_closed());
1142
1143 self.0 = waiter;
1144 }
1145
is_stub(self, stub: &Waiter) -> bool1146 fn is_stub(self, stub: &Waiter) -> bool {
1147 self.as_ptr() as usize == stub as *const _ as usize
1148 }
1149
1150 /// Loads the state from an AtomicUsize.
load(cell: &AtomicUsize, ordering: Ordering) -> SemState1151 fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1152 let value = cell.load(ordering);
1153 SemState(value)
1154 }
1155
fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState1156 fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1157 let value = cell.fetch_or(CLOSED_FLAG, ordering);
1158 SemState(value)
1159 }
1160
is_closed(self) -> bool1161 fn is_closed(self) -> bool {
1162 self.0 & CLOSED_FLAG == CLOSED_FLAG
1163 }
1164
1165 /// Converts the state into a `usize` representation.
to_usize(self) -> usize1166 fn to_usize(self) -> usize {
1167 self.0
1168 }
1169 }
1170
1171 impl fmt::Debug for SemState {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result1172 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1173 let mut fmt = fmt.debug_struct("SemState");
1174
1175 if self.is_waiter() {
1176 fmt.field("state", &"<waiter>");
1177 } else {
1178 fmt.field("permits", &self.available_permits());
1179 }
1180
1181 fmt.finish()
1182 }
1183 }
1184
1185 // ===== impl WaiterState =====
1186
1187 impl WaiterState {
permits_to_acquire(self) -> usize1188 fn permits_to_acquire(self) -> usize {
1189 self.0 >> PERMIT_SHIFT
1190 }
1191
set_permits_to_acquire(&mut self, val: usize)1192 fn set_permits_to_acquire(&mut self, val: usize) {
1193 self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK)
1194 }
1195
is_queued(self) -> bool1196 fn is_queued(self) -> bool {
1197 self.0 & QUEUED == QUEUED
1198 }
1199
set_queued(&mut self)1200 fn set_queued(&mut self) {
1201 self.0 |= QUEUED;
1202 }
1203
is_closed(self) -> bool1204 fn is_closed(self) -> bool {
1205 self.0 & CLOSED == CLOSED
1206 }
1207
set_closed(&mut self)1208 fn set_closed(&mut self) {
1209 self.0 |= CLOSED;
1210 }
1211
unset_queued(&mut self)1212 fn unset_queued(&mut self) {
1213 assert!(self.is_queued());
1214 self.0 -= QUEUED;
1215 }
1216
is_dropped(self) -> bool1217 fn is_dropped(self) -> bool {
1218 self.0 & DROPPED == DROPPED
1219 }
1220 }
1221