1 //! Thread-safe, asynchronous counting semaphore.
2 //!
3 //! A `Semaphore` instance holds a set of permits. Permits are used to
4 //! synchronize access to a shared resource.
5 //!
6 //! Before accessing the shared resource, callers acquire a permit from the
7 //! semaphore. Once the permit is acquired, the caller then enters the critical
8 //! section. If no permits are available, then acquiring the semaphore returns
9 //! `NotReady`. The task is notified once a permit becomes available.
10 
11 use loom::{
12     futures::AtomicTask,
13     sync::{
14         atomic::{AtomicPtr, AtomicUsize},
15         CausalCell,
16     },
17     yield_now,
18 };
19 
20 use futures::Poll;
21 
22 use std::fmt;
23 use std::ptr::{self, NonNull};
24 use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
25 use std::sync::Arc;
26 use std::usize;
27 
/// Futures-aware semaphore.
///
/// Permits and waiters share a single atomic word (`state`): it holds either
/// a count of available permits or a pointer to the tail of an intrusive
/// queue of waiting `WaiterNode`s. The queue is consumed from `head` by
/// whichever thread currently holds `rx_lock`.
pub struct Semaphore {
    /// Tracks both the waiter queue tail pointer and the number of remaining
    /// permits.
    state: AtomicUsize,

    /// Waiter queue head pointer.
    ///
    /// After construction it is only dereferenced and updated by `pop`, which
    /// runs on behalf of the `rx_lock` holder.
    head: CausalCell<NonNull<WaiterNode>>,

    /// Coordinates access to the queue head.
    ///
    /// Bit 0 records a pending `close`. The remaining bits accumulate permits
    /// deposited by `add_permits` (shifted left by one). A caller that sees a
    /// previous value of zero becomes the lock holder and drains the work.
    rx_lock: AtomicUsize,

    /// Stub waiter node used as part of the MPSC channel algorithm.
    ///
    /// Its heap address is stored in `head` and, in pointer mode, in `state`.
    /// Keeping it boxed means that address stays valid when the `Semaphore`
    /// itself is moved.
    stub: Box<WaiterNode>,
}
43 
/// A semaphore permit
///
/// Tracks the lifecycle of a semaphore permit.
///
/// An instance of `Permit` is intended to be used with a **single** instance of
/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
/// instances will result in unexpected behavior.
///
/// `Permit` does **not** release the permit back to the semaphore on drop. It
/// is the user's responsibility to ensure that `Permit::release` is called
/// before dropping the permit.
#[derive(Debug)]
pub struct Permit {
    /// Node used to wait in the semaphore's queue. It is allocated lazily,
    /// the first time this permit has to wait, and reused afterwards.
    waiter: Option<Arc<WaiterNode>>,

    /// Where this permit currently is in its acquire/release lifecycle.
    state: PermitState,
}
60 
/// Error returned by `Permit::poll_acquire`.
///
/// The only failure cause is a closed semaphore, so the error carries no data.
#[derive(Debug)]
pub struct AcquireError(());
64 
/// Error returned by `Permit::try_acquire`.
///
/// Use `is_closed` / `is_no_permits` to tell the failure causes apart.
#[derive(Debug)]
pub struct TryAcquireError {
    /// Why the acquire attempt failed.
    kind: ErrorKind,
}
70 
/// Failure cause carried by `TryAcquireError`.
#[derive(Debug)]
enum ErrorKind {
    /// The semaphore has been closed; no permit will ever be issued.
    Closed,
    /// No permit was available at the time of the attempt.
    NoPermits,
}
76 
/// Node used to notify the semaphore waiter when permit is available.
///
/// A node is shared via `Arc` between its owning `Permit` and the semaphore's
/// intrusive wait queue. While queued, the queue holds a strong reference
/// that was leaked with `Arc::into_raw`.
#[derive(Debug)]
struct WaiterNode {
    /// Stores waiter state.
    ///
    /// See `NodeState` for more details.
    state: AtomicUsize,

    /// Task to notify when a permit is made available.
    task: AtomicTask,

    /// Next pointer in the queue of waiting permits (null when this node is
    /// the tail).
    next: AtomicPtr<WaiterNode>,
}
91 
/// Semaphore state
///
/// A copy of the packed `usize` stored in `Semaphore::state`. The 2 low bits
/// are flags:
///
/// - `NUM_FLAG` (bit 0): set when the value holds a permit count. The count
///   itself lives in the bits above `NUM_SHIFT`.
/// - `CLOSED_FLAG` (bit 1): presumably marks a closed semaphore (see
///   `is_closed` / `fetch_set_closed`). TODO: confirm against the `SemState`
///   impl.
///
/// When `NUM_FLAG` is clear, no permits are available and the remaining bits
/// form a pointer to the tail of the waiter queue (`new_ptr(tail, closed)`
/// suggests the closed flag can be set in this mode too). A semaphore created
/// with zero permits starts in pointer mode, with the internal stub node as
/// the tail.
#[derive(Copy, Clone)]
struct SemState(usize);
104 
/// Permit state
///
/// Tracked locally by each `Permit`. The cross-thread part of the handshake
/// lives in the waiter node's `NodeState`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum PermitState {
    /// The permit has not been requested.
    Idle,

    /// Currently waiting for a permit to be made available and assigned to the
    /// waiter.
    Waiting,

    /// The permit has been acquired.
    Acquired,
}
118 
/// Waiter node state
///
/// Transitions performed by the `WaiterNode` methods:
///
/// - `Idle` -> `QueuedWaiting` and `Queued` -> `QueuedWaiting`: `to_queued_waiting`
/// - `QueuedWaiting` -> `Queued`: `cancel_interest`
/// - `QueuedWaiting` -> `Assigned` / `Closed`, and `Queued` -> `Idle`: `notify`
/// - `Assigned` -> `Idle`: `acquire2`
/// - any -> `Idle`: `revert_to_idle` (only for a node that was never queued)
///
/// Stored as a `usize` inside `WaiterNode::state`, hence `repr(usize)`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(usize)]
enum NodeState {
    /// Not waiting for a permit and the node is not in the wait queue.
    ///
    /// This is the initial state.
    Idle = 0,

    /// Not waiting for a permit but the node is in the wait queue.
    ///
    /// This happens when the waiter has previously requested a permit, but has
    /// since canceled the request. The node cannot be removed by the waiter, so
    /// this state informs the receiver to skip the node when it pops it from
    /// the wait queue.
    Queued = 1,

    /// Waiting for a permit and the node is in the wait queue.
    QueuedWaiting = 2,

    /// The waiter has been assigned a permit and the node has been removed from
    /// the queue.
    Assigned = 3,

    /// The semaphore has been closed. No more permits will be issued.
    Closed = 4,
}
146 
147 // ===== impl Semaphore =====
148 
impl Semaphore {
    /// Creates a new semaphore with the initial number of permits
    ///
    /// A semaphore created with zero permits is valid. It starts with no
    /// available permits, and its state points at the internal stub node as
    /// the (empty) waiter queue tail.
    ///
    /// # Panics
    ///
    /// Panics if `permits` is greater than `MAX_PERMITS`, the largest count
    /// that fits in the packed state word (checked in `SemState::new`).
    pub fn new(permits: usize) -> Semaphore {
        // The stub's heap address is captured here and stored in `head` (and,
        // with zero permits, in `state`). Boxing keeps that address stable
        // when the returned `Semaphore` is moved.
        let stub = Box::new(WaiterNode::new());
        let ptr = NonNull::new(&*stub as *const _ as *mut _).unwrap();

        // Allocations are aligned, so the low bit (`NUM_FLAG`) of a node
        // pointer is clear and can't be mistaken for a permit count.
        debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);

        let state = SemState::new(permits, &stub);

        Semaphore {
            state: AtomicUsize::new(state.to_usize()),
            head: CausalCell::new(ptr),
            rx_lock: AtomicUsize::new(0),
            stub,
        }
    }

    /// Returns the current number of available permits
    ///
    /// This is a snapshot of an atomic value. Other threads may change it
    /// immediately afterwards.
    pub fn available_permits(&self) -> usize {
        let curr = SemState::load(&self.state, Acquire);
        curr.available_permits()
    }

    /// Poll for a permit
    ///
    /// Tries to take one permit by CAS-ing `state`. If none are available:
    ///
    /// - with `permit == Some(..)`: the permit's waiter node (allocated on
    ///   first use) is registered with the current task and pushed onto the
    ///   waiter queue. If the node is already queued, it is left where it is.
    ///   Either way `NotReady` is returned.
    /// - with `permit == None` (the `try_acquire` path): nothing is queued and
    ///   `NotReady` is returned.
    ///
    /// Returns `Err(AcquireError)` if the semaphore has been closed.
    fn poll_permit(&self, mut permit: Option<&mut Permit>) -> Poll<(), AcquireError> {
        use futures::Async::*;

        // Load the current state
        let mut curr = SemState::load(&self.state, Acquire);

        debug!(" + poll_permit; sem-state = {:?}", curr);

        // Tracks a *mut WaiterNode representing an Arc clone.
        //
        // This avoids having to bump the ref count unless required.
        let mut maybe_strong: Option<NonNull<WaiterNode>> = None;

        macro_rules! undo_strong {
            () => {
                if let Some(waiter) = maybe_strong {
                    // The waiter was cloned, but never got queued.
                    // Before entering `poll_permit`, the waiter was in the
                    // `Idle` state. We must transition the node back to the
                    // idle state.
                    let waiter = unsafe { Arc::from_raw(waiter.as_ptr()) };
                    waiter.revert_to_idle();
                }
            };
        }

        loop {
            let mut next = curr;

            if curr.is_closed() {
                undo_strong!();
                return Err(AcquireError::closed());
            }

            if !next.acquire_permit(&self.stub) {
                debug!(" + poll_permit -- no permits");

                // No permits means `state` is in pointer mode.
                debug_assert!(curr.waiter().is_some());

                if maybe_strong.is_none() {
                    if let Some(ref mut permit) = permit {
                        // Get the Sender's waiter node, or initialize one
                        let waiter = permit
                            .waiter
                            .get_or_insert_with(|| Arc::new(WaiterNode::new()));

                        waiter.register();

                        debug!(" + poll_permit -- to_queued_waiting");

                        if !waiter.to_queued_waiting() {
                            debug!(" + poll_permit; waiter already queued");
                            // The node is already queued, there is no further work
                            // to do.
                            return Ok(NotReady);
                        }

                        // Leak one strong reference. It is owned by the queue
                        // and reclaimed by `pop` (or by `undo_strong!`).
                        maybe_strong = Some(WaiterNode::into_non_null(waiter.clone()));
                    } else {
                        // If no `waiter`, then the task is not registered and there
                        // is no further work to do.
                        return Ok(NotReady);
                    }
                }

                // Make our node the new queue tail in the proposed state.
                next.set_waiter(maybe_strong.unwrap());
            }

            debug!(" + poll_permit -- pre-CAS; next = {:?}", next);

            debug_assert_ne!(curr.0, 0);
            debug_assert_ne!(next.0, 0);

            match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
                Ok(_) => {
                    debug!(" + poll_permit -- CAS ok");
                    match curr.waiter() {
                        Some(prev_waiter) => {
                            let waiter = maybe_strong.unwrap();

                            // Finish pushing: link the previous tail to our
                            // node. Until this store lands, the node is the
                            // tail but not yet reachable from `head`; `pop`
                            // treats that window as "inconsistent" and retries.
                            unsafe {
                                prev_waiter.as_ref().next.store(waiter.as_ptr(), Release);
                            }

                            debug!(" + poll_permit -- waiter pushed");

                            return Ok(NotReady);
                        }
                        None => {
                            debug!(" + poll_permit -- permit acquired");

                            // A permit was taken, so any clone made on an
                            // earlier iteration was never queued.
                            undo_strong!();

                            return Ok(Ready(()));
                        }
                    }
                }
                Err(actual) => {
                    curr = actual;
                }
            }
        }
    }

    /// Close the semaphore. This prevents the semaphore from issuing new
    /// permits and notifies all pending waiters.
    ///
    /// The request is recorded in bit 0 of `rx_lock`. If another thread holds
    /// the lock, that thread observes the bit in `add_permits_locked` and
    /// performs the close on this caller's behalf.
    pub fn close(&self) {
        debug!("+ Semaphore::close");

        // Acquire the `rx_lock`, setting the "closed" flag on the lock.
        // (The log line below says `fetch_add`; the operation is `fetch_or`.)
        let prev = self.rx_lock.fetch_or(1, AcqRel);
        debug!(" + close -- rx_lock.fetch_add(1)");

        if prev != 0 {
            debug!("+ close -- locked; prev = {}", prev);
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(0, true);
    }

    /// Add `n` new permits to the semaphore.
    ///
    /// The permits are deposited into `rx_lock` (as `n << 1`, leaving bit 0
    /// for the close flag). Only the caller that observes a previous value of
    /// zero becomes the lock holder and hands the permits out. Other callers
    /// return immediately and leave their deposit for the holder.
    pub fn add_permits(&self, n: usize) {
        debug!(" + add_permits; n = {}", n);

        if n == 0 {
            return;
        }

        // TODO: Handle overflow. A panic is not sufficient, the process must
        // abort.
        let prev = self.rx_lock.fetch_add(n << 1, AcqRel);
        debug!(" + add_permits; rx_lock.fetch_add(n << 1); n = {}", n);

        if prev != 0 {
            debug!(" + add_permits -- locked; prev = {}", prev);
            // Another thread has the lock and will be responsible for notifying
            // pending waiters.
            return;
        }

        self.add_permits_locked(n, false);
    }

    /// Drains `rx_lock` while holding it.
    ///
    /// Each iteration applies `rem` permits (and, if `closed`, the close
    /// request), then subtracts exactly that work from `rx_lock`. The value
    /// observed by the subtraction reveals permits or a close request that
    /// other threads deposited in the meantime. The loop repeats until no
    /// further work is observed, at which point `rx_lock` has returned to zero
    /// and the lock is released.
    fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
        while rem > 0 || closed {
            debug!(
                " + add_permits_locked -- iter; rem = {}; closed = {:?}",
                rem, closed
            );

            if closed {
                // Mark `state` closed first so `poll_permit` stops queueing.
                SemState::fetch_set_closed(&self.state, AcqRel);
            }

            // Release the permits and notify
            self.add_permits_locked2(rem, closed);

            let n = rem << 1;

            let actual = if closed {
                // Also clear the close bit we just handled.
                let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
                debug!(
                    " + add_permits_locked; rx_lock.fetch_sub(n | 1); n = {}; actual={}",
                    n, actual
                );

                closed = false;
                actual
            } else {
                let actual = self.rx_lock.fetch_sub(n, AcqRel);
                debug!(
                    " + add_permits_locked; rx_lock.fetch_sub(n); n = {}; actual={}",
                    n, actual
                );

                // Pick up a `close` that arrived while we held the lock.
                closed = actual & 1 == 1;
                actual
            };

            // Permits deposited by other threads since the previous iteration.
            rem = (actual >> 1) - rem;
        }

        debug!(" + add_permits; done");
    }

    /// Release a specific amount of permits to the semaphore
    ///
    /// This function is called by `add_permits` after the add lock has been
    /// acquired.
    ///
    /// Hands up to `n` permits to queued waiters or, when `closed`, notifies
    /// every queued waiter of closure. A waiter that had canceled its interest
    /// (`notify` returns `false`) does not consume a permit. Once the queue is
    /// empty, `pop` folds any remaining permits into `state`.
    fn add_permits_locked2(&self, mut n: usize, closed: bool) {
        while n > 0 || closed {
            let waiter = match self.pop(n, closed) {
                Some(waiter) => waiter,
                None => {
                    return;
                }
            };

            debug!(" + release_n -- notify");

            if waiter.notify(closed) {
                // `saturating_sub` because `n` is 0 on the close path.
                n = n.saturating_sub(1);
                debug!(" + release_n -- dec");
            }
        }
    }

    /// Pop a waiter
    ///
    /// `rem` represents the remaining number of times the caller will pop. If
    /// there are no more waiters to pop, `rem` is used to set the available
    /// permits.
    ///
    /// Called only from `add_permits_locked2`, i.e. by the `rx_lock` holder,
    /// which makes it the single consumer of the waiter queue. A returned
    /// `Arc` reclaims the strong reference that `poll_permit` leaked via
    /// `WaiterNode::into_non_null` when the node was pushed. The stub node is
    /// never returned.
    fn pop(&self, rem: usize, closed: bool) -> Option<Arc<WaiterNode>> {
        debug!(" + pop; rem = {}", rem);

        'outer: loop {
            unsafe {
                let mut head = self.head.with(|head| *head);
                let mut next_ptr = head.as_ref().next.load(Acquire);

                let stub = self.stub();

                if head == stub {
                    debug!(" + pop; head == stub");

                    let next = match NonNull::new(next_ptr) {
                        Some(next) => next,
                        None => {
                            // This loop is not part of the standard intrusive mpsc
                            // channel algorithm. This is where we atomically pop
                            // the last task and add `rem` to the remaining capacity.
                            //
                            // This modification to the pop algorithm works because,
                            // at this point, we have not done any work (only done
                            // reading). We have a *pretty* good idea that there is
                            // no concurrent pusher.
                            //
                            // The capacity is then atomically added by doing an
                            // AcqRel CAS on `state`. The `state` cell is the
                            // linchpin of the algorithm.
                            //
                            // By successfully CASing `head` w/ AcqRel, we ensure
                            // that, if any thread was racing and entered a push, we
                            // see that and abort pop, retrying as it is
                            // "inconsistent".
                            let mut curr = SemState::load(&self.state, Acquire);

                            loop {
                                if curr.has_waiter(&self.stub) {
                                    // Inconsistent
                                    debug!(" + pop; inconsistent 1");
                                    yield_now();
                                    continue 'outer;
                                }

                                // When closing the semaphore, nodes are popped
                                // with `rem == 0`. In this case, we are not
                                // adding permits, but notifying waiters of the
                                // semaphore's closed state.
                                if rem == 0 {
                                    debug_assert!(curr.is_closed(), "state = {:?}", curr);
                                    return None;
                                }

                                let mut next = curr;
                                next.release_permits(rem, &self.stub);

                                match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
                                    Ok(_) => return None,
                                    Err(actual) => {
                                        curr = actual;
                                    }
                                }
                            }
                        }
                    };

                    debug!(" + pop; got next waiter");

                    // Skip past the stub.
                    self.head.with_mut(|head| *head = next);
                    head = next;
                    next_ptr = next.as_ref().next.load(Acquire);
                }

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    return Some(Arc::from_raw(head.as_ptr()));
                }

                let state = SemState::load(&self.state, Acquire);

                // This must always be a pointer as the wait list is not empty.
                let tail = state.waiter().unwrap();

                if tail != head {
                    // Inconsistent: a push swapped the tail but has not yet
                    // linked it into the list.
                    debug!(" + pop; inconsistent 2");
                    yield_now();
                    continue 'outer;
                }

                // `head` is the last node. Push the stub behind it so `head`
                // can be popped without leaving the queue empty of nodes.
                self.push_stub(closed);

                next_ptr = head.as_ref().next.load(Acquire);

                if let Some(next) = NonNull::new(next_ptr) {
                    self.head.with_mut(|head| *head = next);

                    return Some(Arc::from_raw(head.as_ptr()));
                }

                // Inconsistent state, loop
                debug!(" + pop; inconsistent 3");
                yield_now();
            }
        }
    }

    /// Re-appends the stub node to the tail of the waiter queue.
    ///
    /// # Safety
    ///
    /// Must only be called by the queue consumer (from `pop`) while the queue
    /// is non-empty. The state must be in pointer mode, and its closed flag
    /// must match `closed` (both are checked only by debug assertions).
    unsafe fn push_stub(&self, closed: bool) {
        let stub = self.stub();

        // Set the next pointer. This does not require an atomic operation as
        // this node is not accessible. The write will be flushed with the next
        // operation
        stub.as_ref().next.store(ptr::null_mut(), Relaxed);

        // Update the tail to point to the new node. We need to see the previous
        // node in order to update the next pointer as well as release `task`
        // to any other threads calling `push`.
        let prev = SemState::new_ptr(stub, closed).swap(&self.state, AcqRel);

        debug_assert_eq!(closed, prev.is_closed());

        // The stub is only pushed when there are pending tasks. Because of
        // this, the state must *always* be in pointer mode.
        let prev = prev.waiter().unwrap();

        // We don't want the *existing* pointer to be a stub.
        debug_assert_ne!(prev, stub);

        // Release `task` to the consume end.
        prev.as_ref().next.store(stub.as_ptr(), Release);
    }

    /// Returns a pointer to the boxed stub node.
    fn stub(&self) -> NonNull<WaiterNode> {
        // SAFETY: a reference into a live `Box` is never null.
        unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
    }
}
532 
533 impl fmt::Debug for Semaphore {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result534     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
535         fmt.debug_struct("Semaphore")
536             .field("state", &SemState::load(&self.state, Relaxed))
537             .field("head", &self.head.with(|ptr| ptr))
538             .field("rx_lock", &self.rx_lock.load(Relaxed))
539             .field("stub", &self.stub)
540             .finish()
541     }
542 }
543 
// SAFETY (review): `head` is a `CausalCell<NonNull<WaiterNode>>`, and
// `NonNull` is neither `Send` nor `Sync`, so these impls must be asserted
// manually. In the visible code, `head` is only dereferenced and written by
// `pop`, which runs under `rx_lock` (acquired in `close` / `add_permits` when
// the previous lock value was zero). The `Debug` impl only takes its address.
// The other fields are atomics and a boxed stub node that is never replaced.
unsafe impl Send for Semaphore {}
unsafe impl Sync for Semaphore {}
546 
547 // ===== impl Permit =====
548 
549 impl Permit {
550     /// Create a new `Permit`.
551     ///
552     /// The permit begins in the "unacquired" state.
553     ///
554     /// # Examples
555     ///
556     /// ```
557     /// use tokio_sync::semaphore::Permit;
558     ///
559     /// let permit = Permit::new();
560     /// assert!(!permit.is_acquired());
561     /// ```
new() -> Permit562     pub fn new() -> Permit {
563         Permit {
564             waiter: None,
565             state: PermitState::Idle,
566         }
567     }
568 
569     /// Returns true if the permit has been acquired
is_acquired(&self) -> bool570     pub fn is_acquired(&self) -> bool {
571         self.state == PermitState::Acquired
572     }
573 
574     /// Try to acquire the permit. If no permits are available, the current task
575     /// is notified once a new permit becomes available.
poll_acquire(&mut self, semaphore: &Semaphore) -> Poll<(), AcquireError>576     pub fn poll_acquire(&mut self, semaphore: &Semaphore) -> Poll<(), AcquireError> {
577         use futures::Async::*;
578 
579         match self.state {
580             PermitState::Idle => {}
581             PermitState::Waiting => {
582                 let waiter = self.waiter.as_ref().unwrap();
583 
584                 if waiter.acquire()? {
585                     self.state = PermitState::Acquired;
586                     return Ok(Ready(()));
587                 } else {
588                     return Ok(NotReady);
589                 }
590             }
591             PermitState::Acquired => {
592                 return Ok(Ready(()));
593             }
594         }
595 
596         match semaphore.poll_permit(Some(self))? {
597             Ready(v) => {
598                 self.state = PermitState::Acquired;
599                 Ok(Ready(v))
600             }
601             NotReady => {
602                 self.state = PermitState::Waiting;
603                 Ok(NotReady)
604             }
605         }
606     }
607 
608     /// Try to acquire the permit.
try_acquire(&mut self, semaphore: &Semaphore) -> Result<(), TryAcquireError>609     pub fn try_acquire(&mut self, semaphore: &Semaphore) -> Result<(), TryAcquireError> {
610         use futures::Async::*;
611 
612         match self.state {
613             PermitState::Idle => {}
614             PermitState::Waiting => {
615                 let waiter = self.waiter.as_ref().unwrap();
616 
617                 if waiter.acquire2().map_err(to_try_acquire)? {
618                     self.state = PermitState::Acquired;
619                     return Ok(());
620                 } else {
621                     return Err(TryAcquireError::no_permits());
622                 }
623             }
624             PermitState::Acquired => {
625                 return Ok(());
626             }
627         }
628 
629         match semaphore.poll_permit(None).map_err(to_try_acquire)? {
630             Ready(()) => {
631                 self.state = PermitState::Acquired;
632                 Ok(())
633             }
634             NotReady => Err(TryAcquireError::no_permits()),
635         }
636     }
637 
638     /// Release a permit back to the semaphore
release(&mut self, semaphore: &Semaphore)639     pub fn release(&mut self, semaphore: &Semaphore) {
640         if self.forget2() {
641             semaphore.add_permits(1);
642         }
643     }
644 
645     /// Forget the permit **without** releasing it back to the semaphore.
646     ///
647     /// After calling `forget`, `poll_acquire` is able to acquire new permit
648     /// from the sempahore.
649     ///
650     /// Repeatedly calling `forget` without associated calls to `add_permit`
651     /// will result in the semaphore losing all permits.
forget(&mut self)652     pub fn forget(&mut self) {
653         self.forget2();
654     }
655 
656     /// Returns `true` if the permit was acquired
forget2(&mut self) -> bool657     fn forget2(&mut self) -> bool {
658         match self.state {
659             PermitState::Idle => false,
660             PermitState::Waiting => {
661                 let ret = self.waiter.as_ref().unwrap().cancel_interest();
662                 self.state = PermitState::Idle;
663                 ret
664             }
665             PermitState::Acquired => {
666                 self.state = PermitState::Idle;
667                 true
668             }
669         }
670     }
671 }
672 
673 // ===== impl AcquireError ====
674 
impl AcquireError {
    /// Constructs the error reported when the semaphore has been closed.
    ///
    /// This is the only constructor: closure is the only cause of an
    /// `AcquireError`.
    fn closed() -> AcquireError {
        AcquireError(())
    }
}
680 
to_try_acquire(_: AcquireError) -> TryAcquireError681 fn to_try_acquire(_: AcquireError) -> TryAcquireError {
682     TryAcquireError::closed()
683 }
684 
685 impl fmt::Display for AcquireError {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result686     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
687         use std::error::Error;
688         write!(fmt, "{}", self.description())
689     }
690 }
691 
impl ::std::error::Error for AcquireError {
    /// Short description of the error. Closure is the only cause.
    fn description(&self) -> &str {
        "semaphore closed"
    }
}
697 
698 // ===== impl TryAcquireError =====
699 
700 impl TryAcquireError {
closed() -> TryAcquireError701     fn closed() -> TryAcquireError {
702         TryAcquireError {
703             kind: ErrorKind::Closed,
704         }
705     }
706 
no_permits() -> TryAcquireError707     fn no_permits() -> TryAcquireError {
708         TryAcquireError {
709             kind: ErrorKind::NoPermits,
710         }
711     }
712 
713     /// Returns true if the error was caused by a closed semaphore.
is_closed(&self) -> bool714     pub fn is_closed(&self) -> bool {
715         match self.kind {
716             ErrorKind::Closed => true,
717             _ => false,
718         }
719     }
720 
721     /// Returns true if the error was caused by calling `try_acquire` on a
722     /// semaphore with no available permits.
is_no_permits(&self) -> bool723     pub fn is_no_permits(&self) -> bool {
724         match self.kind {
725             ErrorKind::NoPermits => true,
726             _ => false,
727         }
728     }
729 }
730 
731 impl fmt::Display for TryAcquireError {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result732     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
733         use std::error::Error;
734         write!(fmt, "{}", self.description())
735     }
736 }
737 
738 impl ::std::error::Error for TryAcquireError {
description(&self) -> &str739     fn description(&self) -> &str {
740         match self.kind {
741             ErrorKind::Closed => "semaphore closed",
742             ErrorKind::NoPermits => "no permits available",
743         }
744     }
745 }
746 
747 // ===== impl WaiterNode =====
748 
749 impl WaiterNode {
new() -> WaiterNode750     fn new() -> WaiterNode {
751         WaiterNode {
752             state: AtomicUsize::new(NodeState::new().to_usize()),
753             task: AtomicTask::new(),
754             next: AtomicPtr::new(ptr::null_mut()),
755         }
756     }
757 
acquire(&self) -> Result<bool, AcquireError>758     fn acquire(&self) -> Result<bool, AcquireError> {
759         if self.acquire2()? {
760             return Ok(true);
761         }
762 
763         self.task.register();
764 
765         self.acquire2()
766     }
767 
acquire2(&self) -> Result<bool, AcquireError>768     fn acquire2(&self) -> Result<bool, AcquireError> {
769         use self::NodeState::*;
770 
771         match Idle.compare_exchange(&self.state, Assigned, AcqRel, Acquire) {
772             Ok(_) => Ok(true),
773             Err(Closed) => Err(AcquireError::closed()),
774             Err(_) => Ok(false),
775         }
776     }
777 
register(&self)778     fn register(&self) {
779         self.task.register()
780     }
781 
782     /// Returns `true` if the permit has been acquired
cancel_interest(&self) -> bool783     fn cancel_interest(&self) -> bool {
784         use self::NodeState::*;
785 
786         match Queued.compare_exchange(&self.state, QueuedWaiting, AcqRel, Acquire) {
787             // Successfully removed interest from the queued node. The permit
788             // has not been assigned to the node.
789             Ok(_) => false,
790             // The semaphore has been closed, there is no further action to
791             // take.
792             Err(Closed) => false,
793             // The permit has been assigned. It must be acquired in order to
794             // be released back to the semaphore.
795             Err(Assigned) => {
796                 match self.acquire2() {
797                     Ok(true) => true,
798                     // Not a reachable state
799                     Ok(false) => panic!(),
800                     // The semaphore has been closed, no further action to take.
801                     Err(_) => false,
802                 }
803             }
804             Err(state) => panic!("unexpected state = {:?}", state),
805         }
806     }
807 
808     /// Transition the state to `QueuedWaiting`.
809     ///
810     /// This step can only happen from `Queued` or from `Idle`.
811     ///
812     /// Returns `true` if transitioning into a queued state.
to_queued_waiting(&self) -> bool813     fn to_queued_waiting(&self) -> bool {
814         use self::NodeState::*;
815 
816         let mut curr = NodeState::load(&self.state, Acquire);
817 
818         loop {
819             debug_assert!(curr == Idle || curr == Queued, "actual = {:?}", curr);
820             let next = QueuedWaiting;
821 
822             match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
823                 Ok(_) => {
824                     if curr.is_queued() {
825                         return false;
826                     } else {
827                         // Transitioned to queued, reset next pointer
828                         self.next.store(ptr::null_mut(), Relaxed);
829                         return true;
830                     }
831                 }
832                 Err(actual) => {
833                     curr = actual;
834                 }
835             }
836         }
837     }
838 
839     /// Notify the waiter
840     ///
841     /// Returns `true` if the waiter accepts the notification
notify(&self, closed: bool) -> bool842     fn notify(&self, closed: bool) -> bool {
843         use self::NodeState::*;
844 
845         // Assume QueuedWaiting state
846         let mut curr = QueuedWaiting;
847 
848         loop {
849             let next = match curr {
850                 Queued => Idle,
851                 QueuedWaiting => {
852                     if closed {
853                         Closed
854                     } else {
855                         Assigned
856                     }
857                 }
858                 actual => panic!("actual = {:?}", actual),
859             };
860 
861             match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
862                 Ok(_) => match curr {
863                     QueuedWaiting => {
864                         debug!(" + notify -- task notified");
865                         self.task.notify();
866                         return true;
867                     }
868                     other => {
869                         debug!(" + notify -- not notified; state = {:?}", other);
870                         return false;
871                     }
872                 },
873                 Err(actual) => curr = actual,
874             }
875         }
876     }
877 
revert_to_idle(&self)878     fn revert_to_idle(&self) {
879         use self::NodeState::Idle;
880 
881         // There are no other handles to the node
882         NodeState::store(&self.state, Idle, Relaxed);
883     }
884 
into_non_null(arc: Arc<WaiterNode>) -> NonNull<WaiterNode>885     fn into_non_null(arc: Arc<WaiterNode>) -> NonNull<WaiterNode> {
886         let ptr = Arc::into_raw(arc);
887         unsafe { NonNull::new_unchecked(ptr as *mut _) }
888     }
889 }
890 
891 // ===== impl State =====
892 
/// Number of low bits reserved for flags (`NUM_FLAG` and `CLOSED_FLAG`).
///
/// When the state holds a permit count, the count is stored shifted left by
/// this amount so that both flag bits remain free.
const NUM_SHIFT: usize = 2;

/// Set when the state holds a permit count rather than a waiter pointer.
///
/// Waiter pointers are assumed to be aligned so that their two low bits
/// (this flag and `CLOSED_FLAG`) are always zero, i.e. at least 4-byte
/// alignment.
const NUM_FLAG: usize = 1 << 0;

/// Set once the semaphore has been closed. It may be combined with either a
/// permit count or a waiter pointer.
const CLOSED_FLAG: usize = 1 << 1;

/// Largest permit count that still fits once shifted past the flag bits.
const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
907 
908 impl SemState {
909     /// Returns a new default `State` value.
new(permits: usize, stub: &WaiterNode) -> SemState910     fn new(permits: usize, stub: &WaiterNode) -> SemState {
911         assert!(permits <= MAX_PERMITS);
912 
913         if permits > 0 {
914             SemState((permits << NUM_SHIFT) | NUM_FLAG)
915         } else {
916             SemState(stub as *const _ as usize)
917         }
918     }
919 
920     /// Returns a `State` tracking `ptr` as the tail of the queue.
new_ptr(tail: NonNull<WaiterNode>, closed: bool) -> SemState921     fn new_ptr(tail: NonNull<WaiterNode>, closed: bool) -> SemState {
922         let mut val = tail.as_ptr() as usize;
923 
924         if closed {
925             val |= CLOSED_FLAG;
926         }
927 
928         SemState(val)
929     }
930 
931     /// Returns the amount of remaining capacity
available_permits(&self) -> usize932     fn available_permits(&self) -> usize {
933         if !self.has_available_permits() {
934             return 0;
935         }
936 
937         self.0 >> NUM_SHIFT
938     }
939 
940     /// Returns true if the state has permits that can be claimed by a waiter.
has_available_permits(&self) -> bool941     fn has_available_permits(&self) -> bool {
942         self.0 & NUM_FLAG == NUM_FLAG
943     }
944 
has_waiter(&self, stub: &WaiterNode) -> bool945     fn has_waiter(&self, stub: &WaiterNode) -> bool {
946         !self.has_available_permits() && !self.is_stub(stub)
947     }
948 
949     /// Try to acquire a permit
950     ///
951     /// # Return
952     ///
953     /// Returns `true` if the permit was acquired, `false` otherwise. If `false`
954     /// is returned, it can be assumed that `State` represents the head pointer
955     /// in the mpsc channel.
acquire_permit(&mut self, stub: &WaiterNode) -> bool956     fn acquire_permit(&mut self, stub: &WaiterNode) -> bool {
957         if !self.has_available_permits() {
958             return false;
959         }
960 
961         debug_assert!(self.waiter().is_none());
962 
963         self.0 -= 1 << NUM_SHIFT;
964 
965         if self.0 == NUM_FLAG {
966             // Set the state to the stub pointer.
967             self.0 = stub as *const _ as usize;
968         }
969 
970         true
971     }
972 
973     /// Release permits
974     ///
975     /// Returns `true` if the permits were accepted.
release_permits(&mut self, permits: usize, stub: &WaiterNode)976     fn release_permits(&mut self, permits: usize, stub: &WaiterNode) {
977         debug_assert!(permits > 0);
978 
979         if self.is_stub(stub) {
980             self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
981             return;
982         }
983 
984         debug_assert!(self.has_available_permits());
985 
986         self.0 += permits << NUM_SHIFT;
987     }
988 
is_waiter(&self) -> bool989     fn is_waiter(&self) -> bool {
990         self.0 & NUM_FLAG == 0
991     }
992 
993     /// Returns the waiter, if one is set.
waiter(&self) -> Option<NonNull<WaiterNode>>994     fn waiter(&self) -> Option<NonNull<WaiterNode>> {
995         if self.is_waiter() {
996             let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
997 
998             Some(waiter)
999         } else {
1000             None
1001         }
1002     }
1003 
1004     /// Assumes `self` represents a pointer
as_ptr(&self) -> *mut WaiterNode1005     fn as_ptr(&self) -> *mut WaiterNode {
1006         (self.0 & !CLOSED_FLAG) as *mut WaiterNode
1007     }
1008 
1009     /// Set to a pointer to a waiter.
1010     ///
1011     /// This can only be done from the full state.
set_waiter(&mut self, waiter: NonNull<WaiterNode>)1012     fn set_waiter(&mut self, waiter: NonNull<WaiterNode>) {
1013         let waiter = waiter.as_ptr() as usize;
1014         debug_assert!(waiter & NUM_FLAG == 0);
1015         debug_assert!(!self.is_closed());
1016 
1017         self.0 = waiter;
1018     }
1019 
is_stub(&self, stub: &WaiterNode) -> bool1020     fn is_stub(&self, stub: &WaiterNode) -> bool {
1021         self.as_ptr() as usize == stub as *const _ as usize
1022     }
1023 
1024     /// Load the state from an AtomicUsize.
load(cell: &AtomicUsize, ordering: Ordering) -> SemState1025     fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1026         let value = cell.load(ordering);
1027         debug!(" + SemState::load; value = {}", value);
1028         SemState(value)
1029     }
1030 
1031     /// Swap the values
swap(&self, cell: &AtomicUsize, ordering: Ordering) -> SemState1032     fn swap(&self, cell: &AtomicUsize, ordering: Ordering) -> SemState {
1033         let prev = SemState(cell.swap(self.to_usize(), ordering));
1034         debug_assert_eq!(prev.is_closed(), self.is_closed());
1035         prev
1036     }
1037 
1038     /// Compare and exchange the current value into the provided cell
compare_exchange( &self, cell: &AtomicUsize, prev: SemState, success: Ordering, failure: Ordering, ) -> Result<SemState, SemState>1039     fn compare_exchange(
1040         &self,
1041         cell: &AtomicUsize,
1042         prev: SemState,
1043         success: Ordering,
1044         failure: Ordering,
1045     ) -> Result<SemState, SemState> {
1046         debug_assert_eq!(prev.is_closed(), self.is_closed());
1047 
1048         let res = cell.compare_exchange(prev.to_usize(), self.to_usize(), success, failure);
1049 
1050         debug!(
1051             " + SemState::compare_exchange; prev = {}; next = {}; result = {:?}",
1052             prev.to_usize(),
1053             self.to_usize(),
1054             res
1055         );
1056 
1057         res.map(SemState).map_err(SemState)
1058     }
1059 
fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState1060     fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
1061         let value = cell.fetch_or(CLOSED_FLAG, ordering);
1062         SemState(value)
1063     }
1064 
is_closed(&self) -> bool1065     fn is_closed(&self) -> bool {
1066         self.0 & CLOSED_FLAG == CLOSED_FLAG
1067     }
1068 
1069     /// Converts the state into a `usize` representation.
to_usize(&self) -> usize1070     fn to_usize(&self) -> usize {
1071         self.0
1072     }
1073 }
1074 
1075 impl fmt::Debug for SemState {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result1076     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
1077         let mut fmt = fmt.debug_struct("SemState");
1078 
1079         if self.is_waiter() {
1080             fmt.field("state", &"<waiter>");
1081         } else {
1082             fmt.field("permits", &self.available_permits());
1083         }
1084 
1085         fmt.finish()
1086     }
1087 }
1088 
1089 // ===== impl NodeState =====
1090 
1091 impl NodeState {
new() -> NodeState1092     fn new() -> NodeState {
1093         NodeState::Idle
1094     }
1095 
from_usize(value: usize) -> NodeState1096     fn from_usize(value: usize) -> NodeState {
1097         use self::NodeState::*;
1098 
1099         match value {
1100             0 => Idle,
1101             1 => Queued,
1102             2 => QueuedWaiting,
1103             3 => Assigned,
1104             4 => Closed,
1105             _ => panic!(),
1106         }
1107     }
1108 
load(cell: &AtomicUsize, ordering: Ordering) -> NodeState1109     fn load(cell: &AtomicUsize, ordering: Ordering) -> NodeState {
1110         NodeState::from_usize(cell.load(ordering))
1111     }
1112 
1113     /// Store a value
store(cell: &AtomicUsize, value: NodeState, ordering: Ordering)1114     fn store(cell: &AtomicUsize, value: NodeState, ordering: Ordering) {
1115         cell.store(value.to_usize(), ordering);
1116     }
1117 
compare_exchange( &self, cell: &AtomicUsize, prev: NodeState, success: Ordering, failure: Ordering, ) -> Result<NodeState, NodeState>1118     fn compare_exchange(
1119         &self,
1120         cell: &AtomicUsize,
1121         prev: NodeState,
1122         success: Ordering,
1123         failure: Ordering,
1124     ) -> Result<NodeState, NodeState> {
1125         cell.compare_exchange(prev.to_usize(), self.to_usize(), success, failure)
1126             .map(NodeState::from_usize)
1127             .map_err(NodeState::from_usize)
1128     }
1129 
1130     /// Returns `true` if `self` represents a queued state.
is_queued(&self) -> bool1131     fn is_queued(&self) -> bool {
1132         use self::NodeState::*;
1133 
1134         match *self {
1135             Queued | QueuedWaiting => true,
1136             _ => false,
1137         }
1138     }
1139 
to_usize(&self) -> usize1140     fn to_usize(&self) -> usize {
1141         *self as usize
1142     }
1143 }
1144