1 use futures_util::future;
2 use tokio::sync::{mpsc, oneshot};
3 
4 use crate::common::{task, Future, Pin, Poll};
5 
6 pub type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>;
7 pub type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;
8 
channel<T, U>() -> (Sender<T, U>, Receiver<T, U>)9 pub fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
10     let (tx, rx) = mpsc::unbounded_channel();
11     let (giver, taker) = want::new();
12     let tx = Sender {
13         buffered_once: false,
14         giver,
15         inner: tx,
16     };
17     let rx = Receiver { inner: rx, taker };
18     (tx, rx)
19 }
20 
/// A bounded sender of requests and callbacks for when responses are ready.
///
/// While the inner sender is unbounded, the Giver is used to determine
/// if the Receiver is ready for another request. Created by [`channel`];
/// it can be turned into a cloneable [`UnboundedSender`] via `unbound`.
pub struct Sender<T, U> {
    /// One message is always allowed, even if the Receiver hasn't asked
    /// for it yet. This flag is set by the first successful send; after
    /// that, every send needs the Giver's permission.
    buffered_once: bool,
    /// The Giver helps watch that the Receiver side has been polled
    /// when the queue is empty. This helps us know when a request and
    /// response have been fully processed, and a connection is ready
    /// for more.
    giver: want::Giver,
    /// Actually bounded by the Giver, plus `buffered_once`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
38 
/// An unbounded version of [`Sender`].
///
/// Cannot poll the Giver, but can still use it to determine if the Receiver
/// has been dropped. However, this version can be cloned.
pub struct UnboundedSender<T, U> {
    /// Used by `is_ready` and `is_closed` to detect that the Receiver is
    /// gone, since mpsc::UnboundedSender cannot be checked.
    giver: want::SharedGiver,
    /// The underlying queue; sends through it are never gated on the
    /// Receiver asking for more.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
48 
49 impl<T, U> Sender<T, U> {
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>>50     pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
51         self.giver
52             .poll_want(cx)
53             .map_err(|_| crate::Error::new_closed())
54     }
55 
is_ready(&self) -> bool56     pub fn is_ready(&self) -> bool {
57         self.giver.is_wanting()
58     }
59 
is_closed(&self) -> bool60     pub fn is_closed(&self) -> bool {
61         self.giver.is_canceled()
62     }
63 
can_send(&mut self) -> bool64     fn can_send(&mut self) -> bool {
65         if self.giver.give() || !self.buffered_once {
66             // If the receiver is ready *now*, then of course we can send.
67             //
68             // If the receiver isn't ready yet, but we don't have anything
69             // in the channel yet, then allow one message.
70             self.buffered_once = true;
71             true
72         } else {
73             false
74         }
75     }
76 
try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T>77     pub fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
78         if !self.can_send() {
79             return Err(val);
80         }
81         let (tx, rx) = oneshot::channel();
82         self.inner
83             .send(Envelope(Some((val, Callback::Retry(tx)))))
84             .map(move |_| rx)
85             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
86     }
87 
send(&mut self, val: T) -> Result<Promise<U>, T>88     pub fn send(&mut self, val: T) -> Result<Promise<U>, T> {
89         if !self.can_send() {
90             return Err(val);
91         }
92         let (tx, rx) = oneshot::channel();
93         self.inner
94             .send(Envelope(Some((val, Callback::NoRetry(tx)))))
95             .map(move |_| rx)
96             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
97     }
98 
unbound(self) -> UnboundedSender<T, U>99     pub fn unbound(self) -> UnboundedSender<T, U> {
100         UnboundedSender {
101             giver: self.giver.shared(),
102             inner: self.inner,
103         }
104     }
105 }
106 
107 impl<T, U> UnboundedSender<T, U> {
is_ready(&self) -> bool108     pub fn is_ready(&self) -> bool {
109         !self.giver.is_canceled()
110     }
111 
is_closed(&self) -> bool112     pub fn is_closed(&self) -> bool {
113         self.giver.is_canceled()
114     }
115 
try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T>116     pub fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
117         let (tx, rx) = oneshot::channel();
118         self.inner
119             .send(Envelope(Some((val, Callback::Retry(tx)))))
120             .map(move |_| rx)
121             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
122     }
123 }
124 
125 impl<T, U> Clone for UnboundedSender<T, U> {
clone(&self) -> Self126     fn clone(&self) -> Self {
127         UnboundedSender {
128             giver: self.giver.clone(),
129             inner: self.inner.clone(),
130         }
131     }
132 }
133 
/// The receiving half created by [`channel`].
///
/// Yields queued requests together with their callbacks, and uses the Taker
/// to signal the sending side when it wants more (or has been closed).
pub struct Receiver<T, U> {
    /// Queue of enveloped requests pushed by the sender halves.
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    /// Signals "want" when the queue is empty and "cancel" on close/drop.
    taker: want::Taker,
}
138 
139 impl<T, U> Receiver<T, U> {
poll_next( &mut self, cx: &mut task::Context<'_>, ) -> Poll<Option<(T, Callback<T, U>)>>140     pub(crate) fn poll_next(
141         &mut self,
142         cx: &mut task::Context<'_>,
143     ) -> Poll<Option<(T, Callback<T, U>)>> {
144         match self.inner.poll_recv(cx) {
145             Poll::Ready(item) => {
146                 Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
147             }
148             Poll::Pending => {
149                 self.taker.want();
150                 Poll::Pending
151             }
152         }
153     }
154 
close(&mut self)155     pub(crate) fn close(&mut self) {
156         self.taker.cancel();
157         self.inner.close();
158     }
159 
try_recv(&mut self) -> Option<(T, Callback<T, U>)>160     pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> {
161         match self.inner.try_recv() {
162             Ok(mut env) => env.0.take(),
163             Err(_) => None,
164         }
165     }
166 }
167 
168 impl<T, U> Drop for Receiver<T, U> {
drop(&mut self)169     fn drop(&mut self) {
170         // Notify the giver about the closure first, before dropping
171         // the mpsc::Receiver.
172         self.taker.cancel();
173     }
174 }
175 
176 struct Envelope<T, U>(Option<(T, Callback<T, U>)>);
177 
178 impl<T, U> Drop for Envelope<T, U> {
drop(&mut self)179     fn drop(&mut self) {
180         if let Some((val, cb)) = self.0.take() {
181             cb.send(Err((
182                 crate::Error::new_canceled().with("connection closed"),
183                 Some(val),
184             )));
185         }
186     }
187 }
188 
/// Where the outcome of a single request is delivered.
///
/// Each variant wraps the sending half of a oneshot channel whose receiving
/// half was returned to the caller as a promise (`try_send` returns the
/// `Retry` kind, `Sender::send` the `NoRetry` kind).
pub enum Callback<T, U> {
    /// The error side may carry the original request back (`Option<T>`),
    /// e.g. so the caller can retry it.
    Retry(oneshot::Sender<Result<U, (crate::Error, Option<T>)>>),
    /// Only the error is delivered; any request carried with it is discarded.
    NoRetry(oneshot::Sender<Result<U, crate::Error>>),
}
193 
194 impl<T, U> Callback<T, U> {
is_canceled(&self) -> bool195     pub(crate) fn is_canceled(&self) -> bool {
196         match *self {
197             Callback::Retry(ref tx) => tx.is_closed(),
198             Callback::NoRetry(ref tx) => tx.is_closed(),
199         }
200     }
201 
poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()>202     pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> {
203         match *self {
204             Callback::Retry(ref mut tx) => tx.poll_closed(cx),
205             Callback::NoRetry(ref mut tx) => tx.poll_closed(cx),
206         }
207     }
208 
send(self, val: Result<U, (crate::Error, Option<T>)>)209     pub(crate) fn send(self, val: Result<U, (crate::Error, Option<T>)>) {
210         match self {
211             Callback::Retry(tx) => {
212                 let _ = tx.send(val);
213             }
214             Callback::NoRetry(tx) => {
215                 let _ = tx.send(val.map_err(|e| e.0));
216             }
217         }
218     }
219 
send_when( self, mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin, ) -> impl Future<Output = ()>220     pub(crate) fn send_when(
221         self,
222         mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin,
223     ) -> impl Future<Output = ()> {
224         let mut cb = Some(self);
225 
226         // "select" on this callback being canceled, and the future completing
227         future::poll_fn(move |cx| {
228             match Pin::new(&mut when).poll(cx) {
229                 Poll::Ready(Ok(res)) => {
230                     cb.take().expect("polled after complete").send(Ok(res));
231                     Poll::Ready(())
232                 }
233                 Poll::Pending => {
234                     // check if the callback is canceled
235                     ready!(cb.as_mut().unwrap().poll_canceled(cx));
236                     trace!("send_when canceled");
237                     Poll::Ready(())
238                 }
239                 Poll::Ready(Err(err)) => {
240                     cb.take().expect("polled after complete").send(Err(err));
241                     Poll::Ready(())
242                 }
243             }
244         })
245     }
246 }
247 
#[cfg(test)]
mod tests {
    #[cfg(feature = "nightly")]
    extern crate test;

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{channel, Callback, Receiver};

    /// Dummy request type used by the tests.
    #[derive(Debug)]
    struct Custom(i32);

    // Test-only adapter so a `Receiver` can be awaited / polled like a future.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_next(cx)
        }
    }

    /// Helper to check if the future is ready after polling once.
    ///
    /// Always completes immediately: `Some(())` if the inner future was ready
    /// (its output is dropped), `None` if it was still pending.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    // Dropping the Receiver while a request is still queued must complete
    // that request's promise with a Canceled error that hands the request back.
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let _ = pretty_env_logger::try_init();

        let (mut tx, mut rx) = channel::<Custom, ()>();

        // Poll the empty receiver once so it registers want. (The very first
        // send would also be accepted via the one-message buffer allowance.)
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled")
            .expect_err("promise should error");
        match (err.0.kind(), err.1) {
            (&crate::error::Kind::Canceled, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}", e),
        }
    }

    // The bounded Sender accepts one unsolicited message, then only sends
    // when the Receiver has signaled want by polling an empty queue.
    #[tokio::test]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // one is allowed to buffer, second is rejected
        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Even though 1 has been popped, only 1 could be buffered for the
        // lifetime of the channel.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        // Polling the now-empty queue makes the receiver signal want.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    // The unbounded sender ignores want entirely and only fails once the
    // Receiver is gone.
    #[test]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }

    // Benchmark: one send followed by draining the receiver until it is empty.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_throughput(b: &mut test::Bencher) {
        use crate::{Body, Request, Response};

        let mut rt = tokio::runtime::Builder::new()
            .enable_all()
            .basic_scheduler()
            .build()
            .unwrap();
        let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();

        b.iter(move || {
            let _ = tx.send(Request::default()).unwrap();
            rt.block_on(async {
                loop {
                    let poll_once = PollOnce(&mut rx);
                    let opt = poll_once.await;
                    if opt.is_none() {
                        break;
                    }
                }
            });
        })
    }

    // Benchmark: cost of polling an empty receiver (the "want" path).
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_not_ready(b: &mut test::Bencher) {
        let mut rt = tokio::runtime::Builder::new()
            .enable_all()
            .basic_scheduler()
            .build()
            .unwrap();
        let (_tx, mut rx) = channel::<i32, ()>();
        b.iter(move || {
            rt.block_on(async {
                let poll_once = PollOnce(&mut rx);
                assert!(poll_once.await.is_none());
            });
        })
    }

    // Benchmark: cost of canceling the Taker.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_cancel(b: &mut test::Bencher) {
        let (_tx, mut rx) = channel::<i32, ()>();

        b.iter(move || {
            rx.taker.cancel();
        })
    }
}
395