1 use crate::loom::cell::UnsafeCell;
2 use crate::loom::future::AtomicWaker;
3 use crate::loom::sync::atomic::AtomicUsize;
4 use crate::loom::sync::Arc;
5 use crate::sync::mpsc::error::{ClosedError, TryRecvError};
6 use crate::sync::mpsc::{error, list};
7 
8 use std::fmt;
9 use std::process;
10 use std::sync::atomic::Ordering::{AcqRel, Relaxed};
11 use std::task::Poll::{Pending, Ready};
12 use std::task::{Context, Poll};
13 
/// Channel sender
pub(crate) struct Tx<T, S: Semaphore> {
    /// Shared channel state, also referenced by the `Rx` handle and every
    /// cloned `Tx` handle.
    inner: Arc<Chan<T, S>>,
    /// Capacity reserved from the channel's semaphore for this sender.
    permit: S::Permit,
}
19 
20 impl<T, S: Semaphore> fmt::Debug for Tx<T, S>
21 where
22     S::Permit: fmt::Debug,
23     S: fmt::Debug,
24 {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result25     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
26         fmt.debug_struct("Tx")
27             .field("inner", &self.inner)
28             .field("permit", &self.permit)
29             .finish()
30     }
31 }
32 
/// Channel receiver
pub(crate) struct Rx<T, S: Semaphore> {
    /// Shared channel state; the `rx_fields` inside are only ever touched
    /// through this handle.
    inner: Arc<Chan<T, S>>,
}
37 
38 impl<T, S: Semaphore> fmt::Debug for Rx<T, S>
39 where
40     S: fmt::Debug,
41 {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result42     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
43         fmt.debug_struct("Rx").field("inner", &self.inner).finish()
44     }
45 }
46 
/// Internal error for a failed `try_send` on the channel.
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum TrySendError {
    /// The semaphore was closed; no further sends are possible.
    Closed,
    /// No capacity was available (bounded channel at capacity).
    Full,
}
52 
53 impl<T> From<(T, TrySendError)> for error::SendError<T> {
from(src: (T, TrySendError)) -> error::SendError<T>54     fn from(src: (T, TrySendError)) -> error::SendError<T> {
55         match src.1 {
56             TrySendError::Closed => error::SendError(src.0),
57             TrySendError::Full => unreachable!(),
58         }
59     }
60 }
61 
62 impl<T> From<(T, TrySendError)> for error::TrySendError<T> {
from(src: (T, TrySendError)) -> error::TrySendError<T>63     fn from(src: (T, TrySendError)) -> error::TrySendError<T> {
64         match src.1 {
65             TrySendError::Closed => error::TrySendError::Closed(src.0),
66             TrySendError::Full => error::TrySendError::Full(src.0),
67         }
68     }
69 }
70 
/// Flow-control strategy used by the channel. The bounded channel uses a
/// real semaphore; the unbounded channel uses an `AtomicUsize` counter (see
/// the impls at the bottom of this file).
pub(crate) trait Semaphore {
    /// Per-sender handle to capacity reserved from this semaphore.
    type Permit;

    /// Creates a new, unacquired permit.
    fn new_permit() -> Self::Permit;

    /// The permit is dropped without a value being sent. In this case, the
    /// permit must be returned to the semaphore.
    ///
    /// # Return
    ///
    /// Returns true if the permit was acquired.
    fn drop_permit(&self, permit: &mut Self::Permit) -> bool;

    /// Returns `true` when no capacity is currently checked out.
    fn is_idle(&self) -> bool;

    /// Returns one unit of capacity to the semaphore.
    fn add_permit(&self);

    /// Attempts to acquire capacity into `permit`, registering the task for
    /// wakeup when none is available.
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        permit: &mut Self::Permit,
    ) -> Poll<Result<(), ClosedError>>;

    /// Attempts to acquire capacity into `permit` without waiting.
    fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>;

    /// A value was sent into the channel and the permit held by `tx` is
    /// dropped. In this case, the permit should not immediately be returned to
    /// the semaphore. Instead, the permit is returned to the semaphore once
    /// the sent value is read by the rx handle.
    fn forget(&self, permit: &mut Self::Permit);

    /// Closes the semaphore; pending and future acquires fail.
    fn close(&self);
}
104 
/// Channel state shared between the `Tx` and `Rx` handles.
struct Chan<T, S> {
    /// Handle to the push half of the lock-free list.
    tx: list::Tx<T>,

    /// Coordinates access to channel's capacity.
    semaphore: S,

    /// Receiver waker. Notified when a value is pushed into the channel.
    rx_waker: AtomicWaker,

    /// Tracks the number of outstanding sender handles.
    ///
    /// When this drops to zero, the send half of the channel is closed.
    tx_count: AtomicUsize,

    /// Only accessed by `Rx` handle.
    rx_fields: UnsafeCell<RxFields<T>>,
}
123 
124 impl<T, S> fmt::Debug for Chan<T, S>
125 where
126     S: fmt::Debug,
127 {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result128     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
129         fmt.debug_struct("Chan")
130             .field("tx", &self.tx)
131             .field("semaphore", &self.semaphore)
132             .field("rx_waker", &self.rx_waker)
133             .field("tx_count", &self.tx_count)
134             .field("rx_fields", &"...")
135             .finish()
136     }
137 }
138 
/// Fields only accessed by `Rx` handle.
struct RxFields<T> {
    /// Channel receiver (pop half of the lock-free list). This field is only
    /// accessed by the `Receiver` type.
    list: list::Rx<T>,

    /// `true` if `Rx::close` is called.
    rx_closed: bool,
}
147 
148 impl<T> fmt::Debug for RxFields<T> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result149     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
150         fmt.debug_struct("RxFields")
151             .field("list", &self.list)
152             .field("rx_closed", &self.rx_closed)
153             .finish()
154     }
155 }
156 
// SAFETY: `Chan` may move across threads when `T` and `S` can: `rx_fields`
// is declared as only accessed by the single `Rx` handle, and the remaining
// fields are atomics / lock-free structures.
unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
// SAFETY: shared (`&Chan`) access only mutates `rx_fields` through the lone
// `Rx` handle. NOTE(review): this relies on `list::Tx` and `AtomicWaker`
// being safe for concurrent shared use — confirm in their modules.
unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
159 
channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) where S: Semaphore,160 pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>)
161 where
162     S: Semaphore,
163 {
164     let (tx, rx) = list::channel();
165 
166     let chan = Arc::new(Chan {
167         tx,
168         semaphore,
169         rx_waker: AtomicWaker::new(),
170         tx_count: AtomicUsize::new(1),
171         rx_fields: UnsafeCell::new(RxFields {
172             list: rx,
173             rx_closed: false,
174         }),
175     });
176 
177     (Tx::new(chan.clone()), Rx::new(chan))
178 }
179 
180 // ===== impl Tx =====
181 
182 impl<T, S> Tx<T, S>
183 where
184     S: Semaphore,
185 {
new(chan: Arc<Chan<T, S>>) -> Tx<T, S>186     fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
187         Tx {
188             inner: chan,
189             permit: S::new_permit(),
190         }
191     }
192 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>>193     pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> {
194         self.inner.semaphore.poll_acquire(cx, &mut self.permit)
195     }
196 
disarm(&mut self)197     pub(crate) fn disarm(&mut self) {
198         // TODO: should this error if not acquired?
199         self.inner.semaphore.drop_permit(&mut self.permit);
200     }
201 
202     /// Send a message and notify the receiver.
try_send(&mut self, value: T) -> Result<(), (T, TrySendError)>203     pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> {
204         self.inner.try_send(value, &mut self.permit)
205     }
206 }
207 
impl<T> Tx<T, (crate::sync::semaphore_ll::Semaphore, usize)> {
    /// Returns `true` if this (bounded-channel) sender currently holds an
    /// acquired permit, i.e. capacity is reserved for the next send.
    pub(crate) fn is_ready(&self) -> bool {
        self.permit.is_acquired()
    }
}
213 
impl<T> Tx<T, AtomicUsize> {
    /// Sends on the unbounded channel. The unit permit is a placeholder:
    /// no capacity is reserved ahead of time for unbounded sends.
    pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> {
        self.inner.try_send(value, &mut ())
    }
}
219 
impl<T, S> Clone for Tx<T, S>
where
    S: Semaphore,
{
    /// Creates another sender handle to the same channel.
    fn clone(&self) -> Tx<T, S> {
        // Using a Relaxed ordering here is sufficient as the caller holds a
        // strong ref to `self`, preventing a concurrent decrement to zero.
        self.inner.tx_count.fetch_add(1, Relaxed);

        Tx {
            inner: self.inner.clone(),
            // Capacity is not cloned; the new handle must acquire its own.
            permit: S::new_permit(),
        }
    }
}
235 
impl<T, S> Drop for Tx<T, S>
where
    S: Semaphore,
{
    fn drop(&mut self) {
        // Return any capacity this handle still holds; `notify` is `true`
        // if the permit had actually been acquired.
        let notify = self.inner.semaphore.drop_permit(&mut self.permit);

        // If releasing that permit made the channel fully idle, wake the
        // receiver so it can re-check channel state.
        if notify && self.inner.semaphore.is_idle() {
            self.inner.rx_waker.wake();
        }

        // AcqRel synchronizes this decrement with decrements from other
        // handles; only the last sender (previous count == 1) proceeds.
        if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
            return;
        }

        // Close the list, which sends a `Close` message
        self.inner.tx.close();

        // Notify the receiver
        self.inner.rx_waker.wake();
    }
}
258 
259 // ===== impl Rx =====
260 
impl<T, S> Rx<T, S>
where
    S: Semaphore,
{
    /// Wraps the shared channel state in the (single) receiver handle.
    fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
        Rx { inner: chan }
    }

    /// Closes the receive half: marks the receiver closed, then closes the
    /// semaphore so senders can no longer acquire capacity. Idempotent.
    pub(crate) fn close(&mut self) {
        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: `rx_fields` is only accessed by the `Rx` handle, and
            // we hold `&mut self`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            if rx_fields.rx_closed {
                return;
            }

            rx_fields.rx_closed = true;
        });

        self.inner.semaphore.close();
    }

    /// Receive the next value
    pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        use super::block::Read::*;

        // Keep track of task budget
        let coop = ready!(crate::coop::poll_proceed(cx));

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: `rx_fields` is only accessed by the `Rx` handle.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Pops once from the list; returns out of `recv` on a value or
            // on closure, falls through when the list is currently empty.
            macro_rules! try_recv {
                () => {
                    match rx_fields.list.pop(&self.inner.tx) {
                        Some(Value(value)) => {
                            // Return the value's capacity to the semaphore so
                            // a blocked sender may proceed.
                            self.inner.semaphore.add_permit();
                            coop.made_progress();
                            return Ready(Some(value));
                        }
                        Some(Closed) => {
                            // TODO: This check may not be required as it most
                            // likely can only return `true` at this point. A
                            // channel is closed when all tx handles are
                            // dropped. Dropping a tx handle releases memory,
                            // which ensures that if dropping the tx handle is
                            // visible, then all messages sent are also visible.
                            assert!(self.inner.semaphore.is_idle());
                            coop.made_progress();
                            return Ready(None);
                        }
                        None => {} // fall through
                    }
                };
            }

            try_recv!();

            // Register for wakeup before the second check to avoid losing a
            // notification from a concurrent push.
            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to read
            // and registering the task, so we have to check the channel a
            // second time here.
            try_recv!();

            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                // Receiver was closed and all in-flight values are drained.
                coop.made_progress();
                Ready(None)
            } else {
                Pending
            }
        })
    }

    /// Receives the next value without blocking
    pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
        use super::block::Read::*;
        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: `rx_fields` is only accessed by the `Rx` handle.
            let rx_fields = unsafe { &mut *rx_fields_ptr };
            match rx_fields.list.pop(&self.inner.tx) {
                Some(Value(value)) => {
                    self.inner.semaphore.add_permit();
                    Ok(value)
                }
                Some(Closed) => Err(TryRecvError::Closed),
                None => Err(TryRecvError::Empty),
            }
        })
    }
}
351 
impl<T, S> Drop for Rx<T, S>
where
    S: Semaphore,
{
    fn drop(&mut self) {
        use super::block::Read::Value;

        // Prevent further sends, then drain values already in the channel,
        // returning their capacity to the semaphore as they are dropped.
        self.close();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: only the `Rx` handle touches `rx_fields`, and we hold
            // `&mut self` while dropping.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) {
                self.inner.semaphore.add_permit();
            }
        })
    }
}
370 
371 // ===== impl Chan =====
372 
impl<T, S> Chan<T, S>
where
    S: Semaphore,
{
    /// Sends `value`, first ensuring `permit` holds capacity. On failure the
    /// value is handed back to the caller along with the error.
    fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> {
        if let Err(e) = self.semaphore.try_acquire(permit) {
            return Err((value, e));
        }

        // Push the value
        self.tx.push(value);

        // Notify the rx task
        self.rx_waker.wake();

        // Release the permit: capacity travels with the value and is
        // returned via `add_permit` when the receiver pops it.
        self.semaphore.forget(permit);

        Ok(())
    }
}
394 
impl<T, S> Drop for Chan<T, S> {
    fn drop(&mut self) {
        use super::block::Read::Value;

        // Safety: the only owner of the rx fields is Chan, and being
        // inside its own Drop means we're the last ones to touch it.
        self.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Drop any values still queued, then release the list's blocks.
            while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {}
            unsafe { rx_fields.list.free_blocks() };
        });
    }
}
409 
410 use crate::sync::semaphore_ll::TryAcquireError;
411 
412 impl From<TryAcquireError> for TrySendError {
from(src: TryAcquireError) -> TrySendError413     fn from(src: TryAcquireError) -> TrySendError {
414         if src.is_closed() {
415             TrySendError::Closed
416         } else if src.is_no_permits() {
417             TrySendError::Full
418         } else {
419             unreachable!();
420         }
421     }
422 }
423 
424 // ===== impl Semaphore for (::Semaphore, capacity) =====
425 
426 use crate::sync::semaphore_ll::Permit;
427 
// Bounded-channel flow control: `self.0` is the semaphore, `self.1` is the
// channel capacity (the total number of permits).
impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) {
    type Permit = Permit;

    fn new_permit() -> Permit {
        Permit::new()
    }

    fn drop_permit(&self, permit: &mut Permit) -> bool {
        // Report whether capacity was actually held, then return it.
        let ret = permit.is_acquired();
        permit.release(1, &self.0);
        ret
    }

    fn add_permit(&self) {
        self.0.add_permits(1)
    }

    fn is_idle(&self) -> bool {
        // Idle when every permit is back in the semaphore: no value in
        // flight and no sender holding reserved capacity.
        self.0.available_permits() == self.1
    }

    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        permit: &mut Permit,
    ) -> Poll<Result<(), ClosedError>> {
        // Keep track of task budget
        let coop = ready!(crate::coop::poll_proceed(cx));

        permit
            .poll_acquire(cx, 1, &self.0)
            .map_err(|_| ClosedError::new())
            .map(move |r| {
                coop.made_progress();
                r
            })
    }

    fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> {
        // `?` converts `TryAcquireError` via the `From` impl above.
        permit.try_acquire(1, &self.0)?;
        Ok(())
    }

    fn forget(&self, permit: &mut Self::Permit) {
        // Capacity travels with the sent value; `Rx` returns it via
        // `add_permit` once the value is popped.
        permit.forget(1);
    }

    fn close(&self) {
        self.0.close();
    }
}
479 
480 // ===== impl Semaphore for AtomicUsize =====
481 
482 use std::sync::atomic::Ordering::{Acquire, Release};
483 use std::usize;
484 
// Unbounded-channel flow control: channel state packed into one word.
// Bit 0 is the "closed" flag (set by `close`); the remaining bits count
// in-flight messages, adjusted in steps of 2.
impl Semaphore for AtomicUsize {
    type Permit = ();

    fn new_permit() {}

    fn drop_permit(&self, _permit: &mut ()) -> bool {
        // Unbounded sends never hold capacity, so there is nothing to return.
        false
    }

    fn add_permit(&self) {
        // A value was consumed by the receiver: decrement the message count.
        let prev = self.fetch_sub(2, Release);

        if prev >> 1 == 0 {
            // Something went wrong: the count underflowed.
            process::abort();
        }
    }

    fn is_idle(&self) -> bool {
        // Idle when the message count (bits above the closed flag) is zero.
        self.load(Acquire) >> 1 == 0
    }

    fn poll_acquire(
        &self,
        _cx: &mut Context<'_>,
        permit: &mut (),
    ) -> Poll<Result<(), ClosedError>> {
        // Unbounded acquisition never waits; it only fails if closed.
        Ready(self.try_acquire(permit).map_err(|_| ClosedError::new()))
    }

    fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> {
        let mut curr = self.load(Acquire);

        loop {
            if curr & 1 == 1 {
                // The closed flag is set.
                return Err(TrySendError::Closed);
            }

            if curr == usize::MAX ^ 1 {
                // Overflowed the ref count. There is no safe way to recover, so
                // abort the process. In practice, this should never happen.
                process::abort()
            }

            match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) {
                Ok(_) => return Ok(()),
                Err(actual) => {
                    // Lost a race with another sender or a close; retry with
                    // the freshly observed value.
                    curr = actual;
                }
            }
        }
    }

    fn forget(&self, _permit: &mut ()) {}

    fn close(&self) {
        self.fetch_or(1, Release);
    }
}
544