1 #![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
2
3 //! Thread-safe, asynchronous counting semaphore.
4 //!
5 //! A `Semaphore` instance holds a set of permits. Permits are used to
6 //! synchronize access to a shared resource.
7 //!
8 //! Before accessing the shared resource, callers acquire a permit from the
9 //! semaphore. Once the permit is acquired, the caller then enters the critical
10 //! section. If no permits are available, then acquiring the semaphore returns
11 //! `Pending`. The task is woken once a permit becomes available.
12
13 use crate::loom::cell::UnsafeCell;
14 use crate::loom::future::AtomicWaker;
15 use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
16 use crate::loom::thread;
17
18 use std::cmp;
19 use std::fmt;
20 use std::ptr::{self, NonNull};
21 use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
22 use std::task::Poll::{Pending, Ready};
23 use std::task::{Context, Poll};
24 use std::usize;
25
/// Futures-aware semaphore.
///
/// Available permits are tracked in `state`. When a request cannot be
/// satisfied, the requesting waiter node is pushed onto an intrusive MPSC
/// queue whose tail is also tracked in `state` and whose head is `head`.
pub(crate) struct Semaphore {
    /// Tracks both the waiter queue tail pointer and the number of remaining
    /// permits.
    ///
    /// Encoded as a `SemState`: when the `NUM_FLAG` bit is set the value holds
    /// a permit count, otherwise it holds a pointer to the queue tail (the
    /// stub node when the queue is empty).
    state: AtomicUsize,

    /// Waiter queue head pointer.
    ///
    /// Mutated only by the thread that currently holds `rx_lock`.
    head: UnsafeCell<NonNull<Waiter>>,

    /// Coordinates access to the queue head.
    ///
    /// Bit 0 records a pending close request; the remaining bits (shifted by
    /// one) count permits waiting to be released. The thread that moves this
    /// value away from zero owns the lock and drains the pending work.
    rx_lock: AtomicUsize,

    /// Stub waiter node used as part of the MPSC channel algorithm.
    stub: Box<Waiter>,
}
41
/// A semaphore permit
///
/// Tracks the lifecycle of a semaphore permit.
///
/// An instance of `Permit` is intended to be used with a **single** instance of
/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
/// instances will result in unexpected behavior.
///
/// `Permit` does **not** release the permit back to the semaphore on drop. It
/// is the user's responsibility to ensure that `Permit::release` is called
/// before dropping the permit.
#[derive(Debug)]
pub(crate) struct Permit {
    /// Node used to wait in the semaphore's queue. Allocated lazily the first
    /// time the permit has to wait, then reused by later acquisitions.
    waiter: Option<Box<Waiter>>,

    /// Number of permits requested (while waiting) or held (once acquired).
    state: PermitState,
}
58
/// Error returned by `Permit::poll_acquire`.
///
/// The only cause is that the semaphore has been closed.
#[derive(Debug)]
pub(crate) struct AcquireError(());
62
/// Error returned by `Permit::try_acquire`.
#[derive(Debug)]
pub(crate) enum TryAcquireError {
    /// The semaphore has been closed and will not issue further permits.
    Closed,
    /// Not enough permits were immediately available.
    NoPermits,
}
69
/// Node used to notify the semaphore waiter when permit is available.
///
/// Nodes are linked into the semaphore's intrusive wait queue through `next`.
#[derive(Debug)]
struct Waiter {
    /// Stores waiter state.
    ///
    /// Packed as a `WaiterState`: the `QUEUED`, `CLOSED` and `DROPPED` flag
    /// bits, plus the number of permits still to acquire, shifted up by
    /// `PERMIT_SHIFT`.
    state: AtomicUsize,

    /// Task to wake when a permit is made available.
    waker: AtomicWaker,

    /// Next pointer in the queue of waiters (null until a successor is
    /// linked).
    next: AtomicPtr<Waiter>,
}
84
/// Semaphore state
///
/// The 2 low bits are flags:
///
/// - `NUM_FLAG` (bit 0): set when the value holds a permit count.
/// - `CLOSED_FLAG` (bit 1): set once the semaphore has been closed.
///
/// When `NUM_FLAG` is set, the bits above `NUM_SHIFT` hold the number of
/// available permits. When it is clear, the value is a pointer to the tail of
/// the waiter queue; it points at the stub node when no permits are
/// available and no waiter is queued.
#[derive(Copy, Clone)]
struct SemState(usize);
97
/// Permit state
#[derive(Debug, Copy, Clone)]
enum PermitState {
    /// Currently waiting for permits to be made available and assigned to the
    /// waiter. Holds the total number of permits requested, including any
    /// that have already been assigned to the waiter.
    Waiting(u16),

    /// The number of acquired permits
    Acquired(u16),
}
108
/// State for an individual waiter node.
///
/// Wraps the value stored in `Waiter::state`; the bit layout is given by the
/// `QUEUED`, `CLOSED`, `DROPPED` and `PERMIT_*` constants.
#[derive(Debug, Copy, Clone)]
struct WaiterState(usize);
112
// ----- Waiter state bits (`WaiterState`) -----

/// The waiter node is currently linked into the semaphore's wait queue.
const QUEUED: usize = 0b001;

/// The semaphore has been closed; no more permits will be issued.
const CLOSED: usize = 0b010;

/// The `Permit` that owns the `Waiter` has been dropped.
const DROPPED: usize = 0b100;

/// One requested permit, as encoded in the waiter state.
const PERMIT_ONE: usize = 0b1000;

/// Clears every flag bit, leaving only the requested-permit count bits.
const PERMIT_MASK: usize = !(PERMIT_ONE - 1);

/// Bit offset at which a permit count is packed into the waiter state.
const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();

// ----- Semaphore state bits (`SemState`) -----

/// Marks the semaphore state as holding a permit count instead of a waiter
/// pointer.
///
/// Allocations are assumed to be aligned, so a genuine pointer always has
/// its least significant bit clear; a set bit therefore means "number".
const NUM_FLAG: usize = 0b01;

/// Marks the semaphore as closed.
const CLOSED_FLAG: usize = 0b10;

/// Bit offset of the permit count within the semaphore state, leaving room
/// for the two flag bits below it.
const NUM_SHIFT: usize = 2;

/// Largest permit count the semaphore state can represent.
const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
148
149 // ===== impl Semaphore =====
150
impl Semaphore {
    /// Creates a new semaphore with the initial number of permits.
    ///
    /// A `permits` value of zero is allowed: the semaphore then starts with
    /// no available permits and its state points at the stub waiter node.
    ///
    /// # Panics
    ///
    /// Panics if `permits` is greater than `MAX_PERMITS`.
    pub(crate) fn new(permits: usize) -> Semaphore {
        let stub = Box::new(Waiter::new());
        let ptr = NonNull::from(&*stub);

        // Allocations are aligned, so the low bit of a waiter pointer is
        // always free to act as `NUM_FLAG`.
        debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);

        let state = SemState::new(permits, &stub);

        // The queue starts out empty: `head` points at the stub node.
        Semaphore {
            state: AtomicUsize::new(state.to_usize()),
            head: UnsafeCell::new(ptr),
            rx_lock: AtomicUsize::new(0),
            stub,
        }
    }

    /// Returns the current number of available permits.
    ///
    /// Returns 0 whenever the state holds a waiter pointer rather than a
    /// permit count.
    pub(crate) fn available_permits(&self) -> usize {
        let curr = SemState(self.state.load(Acquire));
        curr.available_permits()
    }

    /// Tries to acquire the requested number of permits, registering the waiter
    /// if not enough permits are available.
    ///
    /// The permit's waiter node is allocated on first use, and the task's
    /// waker is registered on it before it can be pushed onto the wait queue.
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        num_permits: u16,
        permit: &mut Permit,
    ) -> Poll<Result<(), AcquireError>> {
        self.poll_acquire2(num_permits, || {
            let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));

            waiter.waker.register_by_ref(cx.waker());

            Some(NonNull::from(&**waiter))
        })
    }

    /// Tries to acquire `num_permits` without queuing a waiter.
    ///
    /// Returns `TryAcquireError::NoPermits` when the permits are not
    /// immediately available and `TryAcquireError::Closed` when the semaphore
    /// is closed.
    fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
        // Supplying no waiter makes `poll_acquire2` return `Pending` instead of
        // queuing.
        match self.poll_acquire2(num_permits, || None) {
            Poll::Ready(res) => res.map_err(to_try_acquire),
            Poll::Pending => Err(TryAcquireError::NoPermits),
        }
    }

    /// Polls for a permit
    ///
    /// Tries to acquire available permits first. If unable to acquire a
    /// sufficient number of permits, the caller's waiter is pushed onto the
    /// semaphore's wait queue.
    ///
    /// `get_waiter` is called at most once, the first time the request cannot
    /// be satisfied. If it returns `None`, the caller does not want to wait
    /// and `Pending` is returned without touching the queue.
    ///
    /// When some permits are available but not enough, they are consumed and
    /// credited to the waiter, which then only waits for the remainder.
    fn poll_acquire2<F>(
        &self,
        num_permits: u16,
        mut get_waiter: F,
    ) -> Poll<Result<(), AcquireError>>
    where
        F: FnMut() -> Option<NonNull<Waiter>>,
    {
        let num_permits = num_permits as usize;

        // Load the current state
        let mut curr = SemState(self.state.load(Acquire));

        // Saves a ref to the waiter node once it has been transitioned to
        // "queued" (but not necessarily linked into the queue yet).
        let mut maybe_waiter: Option<NonNull<Waiter>> = None;

        /// Used in branches where we attempt to push the waiter into the wait
        /// queue but fail due to permits becoming available or the wait queue
        /// transitioning to "closed". In this case, the waiter must be
        /// transitioned back to the "idle" state.
        macro_rules! revert_to_idle {
            () => {
                if let Some(waiter) = maybe_waiter {
                    unsafe { waiter.as_ref() }.revert_to_idle();
                }
            };
        }

        loop {
            let mut next = curr;

            if curr.is_closed() {
                revert_to_idle!();
                return Ready(Err(AcquireError::closed()));
            }

            // On success `next` has the permits deducted; on failure `next`
            // is left unchanged.
            let acquired = next.acquire_permits(num_permits, &self.stub);

            if !acquired {
                // There are not enough available permits to satisfy the
                // request. The permit transitions to a waiting state.
                debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);

                if let Some(waiter) = maybe_waiter.as_ref() {
                    // A previous CAS attempt failed; recompute the remainder
                    // against the freshly observed state.
                    //
                    // Safety: the caller owns the waiter.
                    let w = unsafe { waiter.as_ref() };
                    w.set_permits_to_acquire(num_permits - curr.available_permits());
                } else {
                    // Get the waiter for the permit.
                    if let Some(waiter) = get_waiter() {
                        // Safety: the caller owns the waiter.
                        let w = unsafe { waiter.as_ref() };

                        // If there are any currently available permits, the
                        // waiter acquires those immediately and waits for the
                        // remaining permits to become available.
                        if !w.to_queued(num_permits - curr.available_permits()) {
                            // The node is already queued, there is no further work
                            // to do.
                            return Pending;
                        }

                        maybe_waiter = Some(waiter);
                    } else {
                        // No waiter, this indicates the caller does not wish to
                        // "wait", so there is nothing left to do.
                        return Pending;
                    }
                }

                // The waiter becomes the new tail of the queue once the CAS
                // below succeeds.
                next.set_waiter(maybe_waiter.unwrap());
            }

            debug_assert_ne!(curr.0, 0);
            debug_assert_ne!(next.0, 0);

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    if acquired {
                        // Successfully acquire permits **without** queuing the
                        // waiter node. The waiter node is not currently in the
                        // queue.
                        revert_to_idle!();
                        return Ready(Ok(()));
                    } else {
                        // The node is pushed into the queue, the final step is
                        // to set the node's "next" pointer to return the wait
                        // queue into a consistent state.

                        // A numeric (permit-count) state means the queue was
                        // empty, so the previous tail is the stub node.
                        let prev_waiter =
                            curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));

                        let waiter = maybe_waiter.unwrap();

                        // Link the nodes.
                        //
                        // Safety: the mpsc algorithm guarantees the old tail of
                        // the queue is not removed from the queue during the
                        // push process.
                        unsafe {
                            prev_waiter.as_ref().store_next(waiter);
                        }

                        return Pending;
                    }
                }
                Err(actual) => {
                    // Lost a race with another thread; retry against the
                    // freshly observed state.
                    curr = SemState(actual);
                }
            }
        }
    }

    /// Closes the semaphore. This prevents the semaphore from issuing new
    /// permits and notifies all pending waiters.
    ///
    /// If another thread currently holds `rx_lock`, that thread observes the
    /// close flag set here and performs the close on this thread's behalf.
    pub(crate) fn close(&self) {
        // Acquire the `rx_lock`, setting the "closed" flag on the lock.
        let prev = self.rx_lock.fetch_or(1, AcqRel);

        if prev != 0 {
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(0, true);
    }
    /// Adds `n` new permits to the semaphore.
    ///
    /// The permits are first recorded in `rx_lock`. If no other thread holds
    /// the lock, this thread takes it and assigns the permits to queued
    /// waiters, returning any leftover permits to the semaphore; otherwise
    /// the current lock holder picks them up.
    ///
    /// The semaphore can represent at most `MAX_PERMITS` permits. Exceeding
    /// that limit is not currently detected here (see the TODO below).
    pub(crate) fn add_permits(&self, n: usize) {
        if n == 0 {
            return;
        }

        // TODO: Handle overflow. A panic is not sufficient, the process must
        // abort.
        let prev = self.rx_lock.fetch_add(n << 1, AcqRel);

        if prev != 0 {
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(n, false);
    }

    /// Processes pending work while holding `rx_lock`.
    ///
    /// `rem` is the number of permits this thread recorded in `rx_lock` and
    /// `closed` is `true` when a close must be performed. Each iteration
    /// handles the current batch, subtracts it from `rx_lock`, and picks up
    /// any permits (or close request) other threads added in the meantime.
    /// The lock is released when `rx_lock` drops back to zero.
    fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
        while rem > 0 || closed {
            if closed {
                SemState::fetch_set_closed(&self.state, AcqRel);
            }

            // Release the permits and notify
            self.add_permits_locked2(rem, closed);

            let n = rem << 1;

            let actual = if closed {
                // Also clear the close-request bit that was just handled.
                let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
                closed = false;
                actual
            } else {
                let actual = self.rx_lock.fetch_sub(n, AcqRel);
                // Another thread may have requested a close while we held
                // the lock.
                closed = actual & 1 == 1;
                actual
            };

            // Permits added by other threads while this batch was processed.
            rem = (actual >> 1) - rem;
        }
    }

    /// Releases a specific amount of permits to the semaphore
    ///
    /// This function is called by `add_permits` after the add lock has been
    /// acquired. Only the lock holder reads or updates `head`, which is what
    /// makes the `UnsafeCell` accesses below sound.
    fn add_permits_locked2(&self, mut n: usize, closed: bool) {
        // If closing the semaphore, we want to drain the entire queue. The
        // number of permits being assigned doesn't matter.
        if closed {
            n = usize::MAX;
        }

        'outer: while n > 0 {
            unsafe {
                let mut head = self.head.with(|head| *head);
                let mut next_ptr = head.as_ref().next.load(Acquire);

                let stub = self.stub();

                if head == stub {
                    // The stub node indicates an empty queue. Any remaining
                    // permits get assigned back to the semaphore.
                    let next = match NonNull::new(next_ptr) {
                        Some(next) => next,
                        None => {
                            // This loop is not part of the standard intrusive mpsc
                            // channel algorithm. This is where we atomically pop
                            // the last task and add `n` to the remaining capacity.
                            //
                            // This modification to the pop algorithm works because,
                            // at this point, we have not done any work (only done
                            // reading). We have a *pretty* good idea that there is
                            // no concurrent pusher.
                            //
                            // The capacity is then atomically added by doing an
                            // AcqRel CAS on `state`. The `state` cell is the
                            // linchpin of the algorithm.
                            //
                            // By successfully CASing `head` w/ AcqRel, we ensure
                            // that, if any thread was racing and entered a push, we
                            // see that and abort pop, retrying as it is
                            // "inconsistent".
                            let mut curr = SemState::load(&self.state, Acquire);

                            loop {
                                if curr.has_waiter(&self.stub) {
                                    // A waiter is being added concurrently.
                                    // This is the MPSC queue's "inconsistent"
                                    // state and we must loop and try again.
                                    thread::yield_now();
                                    continue 'outer;
                                }

                                // If closing, nothing more to do.
                                if closed {
                                    debug_assert!(curr.is_closed(), "state = {:?}", curr);
                                    return;
                                }

                                let mut next = curr;
                                next.release_permits(n, &self.stub);

                                match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                                    Ok(_) => return,
                                    Err(actual) => {
                                        curr = SemState(actual);
                                    }
                                }
                            }
                        }
                    };

                    // Skip over the stub node to the first real waiter.
                    self.head.with_mut(|head| *head = next);
                    head = next;
                    next_ptr = next.as_ref().next.load(Acquire);
                }

                // `head` points to a waiter assign permits to the waiter. If
                // all requested permits are satisfied, then we can continue,
                // otherwise the node stays in the wait queue.
                if !head.as_ref().assign_permits(&mut n, closed) {
                    assert_eq!(n, 0);
                    return;
                }

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    // `head` (the local) still refers to the node just popped.
                    self.remove_queued(head, closed);
                    continue 'outer;
                }

                let state = SemState::load(&self.state, Acquire);

                // This must always be a pointer as the wait list is not empty.
                let tail = state.waiter().unwrap();

                if tail != head {
                    // Inconsistent
                    thread::yield_now();
                    continue 'outer;
                }

                // `head` is the last node: push the stub behind it so `head`
                // gains a successor and can be popped.
                self.push_stub(closed);

                next_ptr = head.as_ref().next.load(Acquire);

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    self.remove_queued(head, closed);
                    continue 'outer;
                }

                // Inconsistent state, loop
                thread::yield_now();
            }
        }
    }

    /// The wait node has had all of its permits assigned and has been removed
    /// from the wait queue.
    ///
    /// Attempt to remove the QUEUED bit from the node. If additional permits
    /// are concurrently requested, the node must be pushed back into the wait
    /// queue. If the owning `Permit` has been dropped, the node is freed here.
    fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
        let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));

        loop {
            if curr.is_dropped() {
                // The Permit dropped, it is on us to release the memory
                let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
                return;
            }

            // The node is removed from the queue. We attempt to unset the
            // queued bit, but concurrently the waiter has requested more
            // permits. When the waiter requested more permits, it saw the
            // queued bit set so took no further action. This requires us to
            // push the node back into the queue.
            if curr.permits_to_acquire() > 0 {
                // More permits are requested. The waiter must be re-queued
                unsafe {
                    self.push_waiter(waiter, closed);
                }
                return;
            }

            let mut next = curr;
            next.unset_queued();

            let w = unsafe { waiter.as_ref() };

            match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return,
                Err(actual) => {
                    curr = WaiterState(actual);
                }
            }
        }
    }

    /// Pushes the stub node back onto the tail of the wait queue.
    ///
    /// # Safety
    ///
    /// Same requirements as `push_waiter`: the stub must not currently be
    /// reachable from the queue, and the state must be in pointer mode.
    unsafe fn push_stub(&self, closed: bool) {
        self.push_waiter(self.stub(), closed);
    }

    /// Pushes `waiter` onto the tail of the wait queue.
    ///
    /// # Safety
    ///
    /// `waiter` must point to a live `Waiter` that is not currently reachable
    /// from the queue. The semaphore state must be in pointer mode (there are
    /// pending waiters), which the `unwrap` below relies on.
    unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
        // Set the next pointer. This does not require an atomic operation as
        // this node is not accessible. The write will be flushed with the next
        // operation
        waiter.as_ref().next.store(ptr::null_mut(), Relaxed);

        // Update the tail to point to the new node. We need to see the previous
        // node in order to update the next pointer as well as release `task`
        // to any other threads calling `push`.
        let next = SemState::new_ptr(waiter, closed);
        let prev = SemState(self.state.swap(next.0, AcqRel));

        debug_assert_eq!(closed, prev.is_closed());

        // This function is only called when there are pending tasks. Because of
        // this, the state must *always* be in pointer mode.
        let prev = prev.waiter().unwrap();

        // No cycles plz
        debug_assert_ne!(prev, waiter);

        // Release `task` to the consume end.
        prev.as_ref().next.store(waiter.as_ptr(), Release);
    }

    /// Returns a pointer to the stub node. The stub is owned by `self.stub`
    /// for the semaphore's whole lifetime, so the pointer is never null.
    fn stub(&self) -> NonNull<Waiter> {
        unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
    }
}
577
578 impl Drop for Semaphore {
drop(&mut self)579 fn drop(&mut self) {
580 self.close();
581 }
582 }
583
584 impl fmt::Debug for Semaphore {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result585 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
586 fmt.debug_struct("Semaphore")
587 .field("state", &SemState::load(&self.state, Relaxed))
588 .field("head", &self.head.with(|ptr| ptr))
589 .field("rx_lock", &self.rx_lock.load(Relaxed))
590 .field("stub", &self.stub)
591 .finish()
592 }
593 }
594
// SAFETY: state shared between threads lives in atomics (`state`, `rx_lock`,
// and the atomics inside each `Waiter`). The one exception is `head`, which is
// only mutated by the thread that acquired `rx_lock` (see `add_permits` and
// `close`). The stub node is owned by the semaphore through `stub`.
// NOTE(review): the `Debug` impl reads `head` without holding `rx_lock`;
// confirm this unsynchronized read is acceptable.
unsafe impl Send for Semaphore {}
unsafe impl Sync for Semaphore {}
597
598 // ===== impl Permit =====
599
600 impl Permit {
601 /// Creates a new `Permit`.
602 ///
603 /// The permit begins in the "unacquired" state.
new() -> Permit604 pub(crate) fn new() -> Permit {
605 use PermitState::Acquired;
606
607 Permit {
608 waiter: None,
609 state: Acquired(0),
610 }
611 }
612
613 /// Returns `true` if the permit has been acquired
614 #[allow(dead_code)] // may be used later
is_acquired(&self) -> bool615 pub(crate) fn is_acquired(&self) -> bool {
616 match self.state {
617 PermitState::Acquired(num) if num > 0 => true,
618 _ => false,
619 }
620 }
621
622 /// Tries to acquire the permit. If no permits are available, the current task
623 /// is notified once a new permit becomes available.
poll_acquire( &mut self, cx: &mut Context<'_>, num_permits: u16, semaphore: &Semaphore, ) -> Poll<Result<(), AcquireError>>624 pub(crate) fn poll_acquire(
625 &mut self,
626 cx: &mut Context<'_>,
627 num_permits: u16,
628 semaphore: &Semaphore,
629 ) -> Poll<Result<(), AcquireError>> {
630 use std::cmp::Ordering::*;
631 use PermitState::*;
632
633 match self.state {
634 Waiting(requested) => {
635 // There must be a waiter
636 let waiter = self.waiter.as_ref().unwrap();
637
638 match requested.cmp(&num_permits) {
639 Less => {
640 let delta = num_permits - requested;
641
642 // Request additional permits. If the waiter has been
643 // dequeued, it must be re-queued.
644 if !waiter.try_inc_permits_to_acquire(delta as usize) {
645 let waiter = NonNull::from(&**waiter);
646
647 // Ignore the result. The check for
648 // `permits_to_acquire()` will converge the state as
649 // needed
650 let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
651 }
652
653 self.state = Waiting(num_permits);
654 }
655 Greater => {
656 let delta = requested - num_permits;
657 let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
658
659 semaphore.add_permits(to_release);
660 self.state = Waiting(num_permits);
661 }
662 Equal => {}
663 }
664
665 if waiter.permits_to_acquire()? == 0 {
666 self.state = Acquired(requested);
667 return Ready(Ok(()));
668 }
669
670 waiter.waker.register_by_ref(cx.waker());
671
672 if waiter.permits_to_acquire()? == 0 {
673 self.state = Acquired(requested);
674 return Ready(Ok(()));
675 }
676
677 Pending
678 }
679 Acquired(acquired) => {
680 if acquired >= num_permits {
681 Ready(Ok(()))
682 } else {
683 match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
684 Ready(()) => {
685 self.state = Acquired(num_permits);
686 Ready(Ok(()))
687 }
688 Pending => {
689 self.state = Waiting(num_permits);
690 Pending
691 }
692 }
693 }
694 }
695 }
696 }
697
698 /// Tries to acquire the permit.
try_acquire( &mut self, num_permits: u16, semaphore: &Semaphore, ) -> Result<(), TryAcquireError>699 pub(crate) fn try_acquire(
700 &mut self,
701 num_permits: u16,
702 semaphore: &Semaphore,
703 ) -> Result<(), TryAcquireError> {
704 use PermitState::*;
705
706 match self.state {
707 Waiting(requested) => {
708 // There must be a waiter
709 let waiter = self.waiter.as_ref().unwrap();
710
711 if requested > num_permits {
712 let delta = requested - num_permits;
713 let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
714
715 semaphore.add_permits(to_release);
716 self.state = Waiting(num_permits);
717 }
718
719 let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
720
721 if res == 0 {
722 if requested < num_permits {
723 // Try to acquire the additional permits
724 semaphore.try_acquire(num_permits - requested)?;
725 }
726
727 self.state = Acquired(num_permits);
728 Ok(())
729 } else {
730 Err(TryAcquireError::NoPermits)
731 }
732 }
733 Acquired(acquired) => {
734 if acquired < num_permits {
735 semaphore.try_acquire(num_permits - acquired)?;
736 self.state = Acquired(num_permits);
737 }
738
739 Ok(())
740 }
741 }
742 }
743
744 /// Releases a permit back to the semaphore
release(&mut self, n: u16, semaphore: &Semaphore)745 pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
746 let n = self.forget(n);
747 semaphore.add_permits(n as usize);
748 }
749
750 /// Forgets the permit **without** releasing it back to the semaphore.
751 ///
752 /// After calling `forget`, `poll_acquire` is able to acquire new permit
753 /// from the semaphore.
754 ///
755 /// Repeatedly calling `forget` without associated calls to `add_permit`
756 /// will result in the semaphore losing all permits.
757 ///
758 /// Will forget **at most** the number of acquired permits. This number is
759 /// returned.
forget(&mut self, n: u16) -> u16760 pub(crate) fn forget(&mut self, n: u16) -> u16 {
761 use PermitState::*;
762
763 match self.state {
764 Waiting(requested) => {
765 let n = cmp::min(n, requested);
766
767 // Decrement
768 let acquired = self
769 .waiter
770 .as_ref()
771 .unwrap()
772 .try_dec_permits_to_acquire(n as usize) as u16;
773
774 if n == requested {
775 self.state = Acquired(0);
776 } else if acquired == requested - n {
777 self.state = Waiting(acquired);
778 } else {
779 self.state = Waiting(requested - n);
780 }
781
782 acquired
783 }
784 Acquired(acquired) => {
785 let n = cmp::min(n, acquired);
786 self.state = Acquired(acquired - n);
787 n
788 }
789 }
790 }
791 }
792
793 impl Default for Permit {
default() -> Self794 fn default() -> Self {
795 Self::new()
796 }
797 }
798
799 impl Drop for Permit {
drop(&mut self)800 fn drop(&mut self) {
801 if let Some(waiter) = self.waiter.take() {
802 // Set the dropped flag
803 let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
804
805 if state.is_queued() {
806 // The waiter is stored in the queue. The semaphore will drop it
807 std::mem::forget(waiter);
808 }
809 }
810 }
811 }
812
813 // ===== impl AcquireError ====
814
815 impl AcquireError {
closed() -> AcquireError816 fn closed() -> AcquireError {
817 AcquireError(())
818 }
819 }
820
to_try_acquire(_: AcquireError) -> TryAcquireError821 fn to_try_acquire(_: AcquireError) -> TryAcquireError {
822 TryAcquireError::Closed
823 }
824
825 impl fmt::Display for AcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result826 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
827 write!(fmt, "semaphore closed")
828 }
829 }
830
// `AcquireError` has no underlying cause, so the trait's default methods are used.
impl std::error::Error for AcquireError {}
832
833 // ===== impl TryAcquireError =====
834
835 impl TryAcquireError {
836 /// Returns `true` if the error was caused by a closed semaphore.
is_closed(&self) -> bool837 pub(crate) fn is_closed(&self) -> bool {
838 match self {
839 TryAcquireError::Closed => true,
840 _ => false,
841 }
842 }
843
844 /// Returns `true` if the error was caused by calling `try_acquire` on a
845 /// semaphore with no available permits.
is_no_permits(&self) -> bool846 pub(crate) fn is_no_permits(&self) -> bool {
847 match self {
848 TryAcquireError::NoPermits => true,
849 _ => false,
850 }
851 }
852 }
853
854 impl fmt::Display for TryAcquireError {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result855 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
856 match self {
857 TryAcquireError::Closed => write!(fmt, "semaphore closed"),
858 TryAcquireError::NoPermits => write!(fmt, "no permits available"),
859 }
860 }
861 }
862
// `TryAcquireError` has no underlying cause, so the trait's default methods are used.
impl std::error::Error for TryAcquireError {}
864
865 // ===== impl Waiter =====
866
impl Waiter {
    /// Creates an idle waiter: not queued, requesting no permits, with a
    /// fresh `AtomicWaker` and no successor in the queue.
    fn new() -> Waiter {
        Waiter {
            state: AtomicUsize::new(0),
            waker: AtomicWaker::new(),
            next: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Returns the number of permits this waiter is still waiting for.
    ///
    /// # Errors
    ///
    /// Returns `AcquireError` if the `CLOSED` bit is set in the waiter state.
    fn permits_to_acquire(&self) -> Result<usize, AcquireError> {
        let state = WaiterState(self.state.load(Acquire));

        if state.is_closed() {
            Err(AcquireError(()))
        } else {
            Ok(state.permits_to_acquire())
        }
    }

    /// Only increments the number of permits *if* the waiter is currently
    /// queued.
    ///
    /// # Returns
    ///
    /// `true` if the number of permits to acquire has been incremented. `false`
    /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`.
    fn try_inc_permits_to_acquire(&self, n: usize) -> bool {
        let mut curr = WaiterState(self.state.load(Acquire));

        loop {
            if !curr.is_queued() {
                assert_eq!(0, curr.permits_to_acquire());
                return false;
            }

            let mut next = curr;
            next.set_permits_to_acquire(n + curr.permits_to_acquire());

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return true,
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Tries to reduce the number of permits this waiter still needs by `n`.
    ///
    /// Only the still-outstanding part of `n` can be cancelled; the rest of
    /// `n` has already been assigned to the waiter. The return value is that
    /// already-assigned amount (`n` minus the number of permits actually
    /// decremented), which the caller must hand back to the semaphore.
    ///
    /// Returns 0 if the waiter has been closed.
    fn try_dec_permits_to_acquire(&self, n: usize) -> usize {
        let mut curr = WaiterState(self.state.load(Acquire));

        loop {
            if curr.is_closed() {
                return 0;
            }

            if !curr.is_queued() {
                // A waiter that is not queued has no outstanding request, so
                // all of `n` was already assigned.
                assert_eq!(0, curr.permits_to_acquire());
            }

            // `delta` is the number of outstanding permits cancelled.
            let delta = cmp::min(n, curr.permits_to_acquire());
            let rem = curr.permits_to_acquire() - delta;

            let mut next = curr;
            next.set_permits_to_acquire(rem);

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => return n - delta,
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Store the number of remaining permits needed to satisfy the waiter and
    /// transition to the "QUEUED" state.
    ///
    /// # Returns
    ///
    /// `true` if the `QUEUED` bit was set as part of the transition (the
    /// waiter was not queued before). `false` if it was already queued.
    fn to_queued(&self, num_permits: usize) -> bool {
        let mut curr = WaiterState(self.state.load(Acquire));

        // The waiter should **not** be waiting for any permits.
        debug_assert_eq!(curr.permits_to_acquire(), 0);

        loop {
            let mut next = curr;
            next.set_permits_to_acquire(num_permits);
            next.set_queued();

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    if curr.is_queued() {
                        return false;
                    } else {
                        // Make sure the next pointer is null
                        self.next.store(ptr::null_mut(), Relaxed);
                        return true;
                    }
                }
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Set the number of permits to acquire.
    ///
    /// This function is only called when the waiter is being inserted into the
    /// wait queue. Because of this, there are no concurrent threads that can
    /// modify the state and using `store` is safe.
    ///
    /// Note that the whole state is overwritten with `QUEUED` plus the new
    /// permit count; any other flag bits are cleared.
    fn set_permits_to_acquire(&self, num_permits: usize) {
        debug_assert!(WaiterState(self.state.load(Acquire)).is_queued());

        let mut state = WaiterState(QUEUED);
        state.set_permits_to_acquire(num_permits);

        self.state.store(state.0, Release);
    }

    /// Assign permits to the waiter.
    ///
    /// Assigns up to `*n` permits, decrementing `*n` by the amount assigned.
    /// When `closed` is `true`, the `CLOSED` bit is also set. The task is
    /// woken when this call satisfies a non-zero outstanding request.
    ///
    /// Returns `true` if the waiter should be removed from the queue
    fn assign_permits(&self, n: &mut usize, closed: bool) -> bool {
        let mut curr = WaiterState(self.state.load(Acquire));

        loop {
            let mut next = curr;

            // Number of permits to assign to this waiter
            let assign = cmp::min(curr.permits_to_acquire(), *n);

            // Assign the permits
            next.set_permits_to_acquire(curr.permits_to_acquire() - assign);

            if closed {
                next.set_closed();
            }

            match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
                Ok(_) => {
                    // Update `n`
                    *n -= assign;

                    if next.permits_to_acquire() == 0 {
                        if curr.permits_to_acquire() > 0 {
                            self.waker.wake();
                        }

                        return true;
                    } else {
                        return false;
                    }
                }
                Err(actual) => curr = WaiterState(actual),
            }
        }
    }

    /// Resets the waiter to the idle state (0): not queued, requesting no
    /// permits, no flags set.
    ///
    /// NOTE(review): uses a plain relaxed store, so it presumably relies on
    /// the caller exclusively owning a node that is not linked into the
    /// queue; confirm at call sites.
    fn revert_to_idle(&self) {
        // An idle node is not waiting on any permits
        self.state.store(0, Relaxed);
    }

    /// Links `next` as this node's successor in the queue, with `Release`
    /// ordering so the consumer that loads the pointer sees the initialized
    /// node.
    fn store_next(&self, next: NonNull<Waiter>) {
        self.next.store(next.as_ptr(), Release);
    }
}
1035
1036 // ===== impl SemState =====
1037
1038 impl SemState {
1039 /// Returns a new default `State` value.
new(permits: usize, stub: &Waiter) -> SemState1040 fn new(permits: usize, stub: &Waiter) -> SemState {
1041 assert!(permits <= MAX_PERMITS);
1042
1043 if permits > 0 {
1044 SemState((permits << NUM_SHIFT) | NUM_FLAG)
1045 } else {
1046 SemState(stub as *const _ as usize)
1047 }
1048 }
1049
1050 /// Returns a `State` tracking `ptr` as the tail of the queue.
new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState1051 fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState {
1052 let mut val = tail.as_ptr() as usize;
1053
1054 if closed {
1055 val |= CLOSED_FLAG;
1056 }
1057
1058 SemState(val)
1059 }
1060
1061 /// Returns the amount of remaining capacity
available_permits(self) -> usize1062 fn available_permits(self) -> usize {
1063 if !self.has_available_permits() {
1064 return 0;
1065 }
1066
1067 self.0 >> NUM_SHIFT
1068 }
1069
1070 /// Returns `true` if the state has permits that can be claimed by a waiter.
has_available_permits(self) -> bool1071 fn has_available_permits(self) -> bool {
1072 self.0 & NUM_FLAG == NUM_FLAG
1073 }
1074
has_waiter(self, stub: &Waiter) -> bool1075 fn has_waiter(self, stub: &Waiter) -> bool {
1076 !self.has_available_permits() && !self.is_stub(stub)
1077 }
1078
1079 /// Tries to atomically acquire specified number of permits.
1080 ///
1081 /// # Return
1082 ///
1083 /// Returns `true` if the specified number of permits were acquired, `false`
1084 /// otherwise. Returning false does not mean that there are no more
1085 /// available permits.
acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool1086 fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool {
1087 debug_assert!(num > 0);
1088
1089 if self.available_permits() < num {
1090 return false;
1091 }
1092
1093 debug_assert!(self.waiter().is_none());
1094
1095 self.0 -= num << NUM_SHIFT;
1096
1097 if self.0 == NUM_FLAG {
1098 // Set the state to the stub pointer.
1099 self.0 = stub as *const _ as usize;
1100 }
1101
1102 true
1103 }
1104
1105 /// Releases permits
1106 ///
1107 /// Returns `true` if the permits were accepted.
release_permits(&mut self, permits: usize, stub: &Waiter)1108 fn release_permits(&mut self, permits: usize, stub: &Waiter) {
1109 debug_assert!(permits > 0);
1110
1111 if self.is_stub(stub) {
1112 self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
1113 return;
1114 }
1115
1116 debug_assert!(self.has_available_permits());
1117
1118 self.0 += permits << NUM_SHIFT;
1119 }
1120
is_waiter(self) -> bool1121 fn is_waiter(self) -> bool {
1122 self.0 & NUM_FLAG == 0
1123 }
1124
1125 /// Returns the waiter, if one is set.
waiter(self) -> Option<NonNull<Waiter>>1126 fn waiter(self) -> Option<NonNull<Waiter>> {
1127 if self.is_waiter() {
1128 let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
1129
1130 Some(waiter)
1131 } else {
1132 None
1133 }
1134 }
1135
1136 /// Assumes `self` represents a pointer
as_ptr(self) -> *mut Waiter1137 fn as_ptr(self) -> *mut Waiter {
1138 (self.0 & !CLOSED_FLAG) as *mut Waiter
1139 }
1140
1141 /// Sets to a pointer to a waiter.
1142 ///
1143 /// This can only be done from the full state.
set_waiter(&mut self, waiter: NonNull<Waiter>)1144 fn set_waiter(&mut self, waiter: NonNull<Waiter>) {
1145 let waiter = waiter.as_ptr() as usize;
1146 debug_assert!(!self.is_closed());
1147
1148 self.0 = waiter;
1149 }
1150
is_stub(self, stub: &Waiter) -> bool1151 fn is_stub(self, stub: &Waiter) -> bool {
1152 self.as_ptr() as usize == stub as *const _ as usize
1153 }
1154
1155 /// Loads the state from an AtomicUsize.
load(cell: &AtomicUsize, ordering: Ordering) -> SemState1156 fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1157 let value = cell.load(ordering);
1158 SemState(value)
1159 }
1160
fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState1161 fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1162 let value = cell.fetch_or(CLOSED_FLAG, ordering);
1163 SemState(value)
1164 }
1165
is_closed(self) -> bool1166 fn is_closed(self) -> bool {
1167 self.0 & CLOSED_FLAG == CLOSED_FLAG
1168 }
1169
1170 /// Converts the state into a `usize` representation.
to_usize(self) -> usize1171 fn to_usize(self) -> usize {
1172 self.0
1173 }
1174 }
1175
1176 impl fmt::Debug for SemState {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result1177 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1178 let mut fmt = fmt.debug_struct("SemState");
1179
1180 if self.is_waiter() {
1181 fmt.field("state", &"<waiter>");
1182 } else {
1183 fmt.field("permits", &self.available_permits());
1184 }
1185
1186 fmt.finish()
1187 }
1188 }
1189
1190 // ===== impl WaiterState =====
1191
1192 impl WaiterState {
permits_to_acquire(self) -> usize1193 fn permits_to_acquire(self) -> usize {
1194 self.0 >> PERMIT_SHIFT
1195 }
1196
set_permits_to_acquire(&mut self, val: usize)1197 fn set_permits_to_acquire(&mut self, val: usize) {
1198 self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK)
1199 }
1200
is_queued(self) -> bool1201 fn is_queued(self) -> bool {
1202 self.0 & QUEUED == QUEUED
1203 }
1204
set_queued(&mut self)1205 fn set_queued(&mut self) {
1206 self.0 |= QUEUED;
1207 }
1208
is_closed(self) -> bool1209 fn is_closed(self) -> bool {
1210 self.0 & CLOSED == CLOSED
1211 }
1212
set_closed(&mut self)1213 fn set_closed(&mut self) {
1214 self.0 |= CLOSED;
1215 }
1216
unset_queued(&mut self)1217 fn unset_queued(&mut self) {
1218 assert!(self.is_queued());
1219 self.0 -= QUEUED;
1220 }
1221
is_dropped(self) -> bool1222 fn is_dropped(self) -> bool {
1223 self.0 & DROPPED == DROPPED
1224 }
1225 }
1226