1 use futures_core::ready;
2 use futures_sink::Sink;
3 use std::pin::Pin;
4 use std::sync::Arc;
5 use std::task::{Context, Poll};
6 use tokio::sync::mpsc::{error::SendError, Sender};
7 
8 use super::ReusableBoxFuture;
9 
10 // This implementation was chosen over something based on permits because to get a
11 // `tokio::sync::mpsc::Permit` out of the `inner` future, you must transmute the
12 // lifetime on the permit to `'static`.
13 
/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// A message is handed over with [`start_send`] and the resulting send is
/// then driven to completion with [`poll_send_done`]. Only one send can be in
/// progress at a time; `start_send` panics if the previous one has not
/// finished yet.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
/// [`start_send`]: PollSender::start_send
/// [`poll_send_done`]: PollSender::poll_send_done
#[derive(Debug)]
pub struct PollSender<T> {
    /// The shared handle to the underlying channel sender.
    ///
    /// This is `None` once this `PollSender` has been closed, either through
    /// `close_this_sender` or because a send completed with an error.
    sender: Option<Arc<Sender<T>>>,
    /// Whether `inner` currently holds a send that has not completed yet.
    ///
    /// Set by a successful `start_send`; cleared when `poll_send_done`
    /// observes completion or when `abort_send` discards the send.
    is_sending: bool,
    /// Reusable storage for the send future. It is only polled while
    /// `is_sending` is `true`.
    inner: ReusableBoxFuture<Result<(), SendError<T>>>,
}
24 
25 // By reusing the same async fn for both Some and None, we make sure every
26 // future passed to ReusableBoxFuture has the same underlying type, and hence
27 // the same size and alignment.
make_future<T>(data: Option<(Arc<Sender<T>>, T)>) -> Result<(), SendError<T>>28 async fn make_future<T>(data: Option<(Arc<Sender<T>>, T)>) -> Result<(), SendError<T>> {
29     match data {
30         Some((sender, value)) => sender.send(value).await,
31         None => unreachable!(
32             "This future should not be pollable, as is_sending should be set to false."
33         ),
34     }
35 }
36 
37 impl<T: Send + 'static> PollSender<T> {
38     /// Create a new `PollSender`.
new(sender: Sender<T>) -> Self39     pub fn new(sender: Sender<T>) -> Self {
40         Self {
41             sender: Some(Arc::new(sender)),
42             is_sending: false,
43             inner: ReusableBoxFuture::new(make_future(None)),
44         }
45     }
46 
47     /// Start sending a new item.
48     ///
49     /// This method panics if a send is currently in progress. To ensure that no
50     /// send is in progress, call `poll_send_done` first until it returns
51     /// `Poll::Ready`.
52     ///
53     /// If this method returns an error, that indicates that the channel is
54     /// closed. Note that this method is not guaranteed to return an error if
55     /// the channel is closed, but in that case the error would be reported by
56     /// the first call to `poll_send_done`.
start_send(&mut self, value: T) -> Result<(), SendError<T>>57     pub fn start_send(&mut self, value: T) -> Result<(), SendError<T>> {
58         if self.is_sending {
59             panic!("start_send called while not ready.");
60         }
61         match self.sender.clone() {
62             Some(sender) => {
63                 self.inner.set(make_future(Some((sender, value))));
64                 self.is_sending = true;
65                 Ok(())
66             }
67             None => Err(SendError(value)),
68         }
69     }
70 
71     /// If a send is in progress, poll for its completion. If no send is in progress,
72     /// this method returns `Poll::Ready(Ok(()))`.
73     ///
74     /// This method can return the following values:
75     ///
76     ///  - `Poll::Ready(Ok(()))` if the in-progress send has been completed, or there is
77     ///    no send in progress (even if the channel is closed).
78     ///  - `Poll::Ready(Err(err))` if the in-progress send failed because the channel has
79     ///    been closed.
80     ///  - `Poll::Pending` if a send is in progress, but it could not complete now.
81     ///
82     /// When this method returns `Poll::Pending`, the current task is scheduled
83     /// to receive a wakeup when the message is sent, or when the entire channel
84     /// is closed (but not if just this sender is closed by
85     /// `close_this_sender`). Note that on multiple calls to `poll_send_done`,
86     /// only the `Waker` from the `Context` passed to the most recent call is
87     /// scheduled to receive a wakeup.
88     ///
89     /// If this method returns `Poll::Ready`, then `start_send` is guaranteed to
90     /// not panic.
poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>>91     pub fn poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
92         if !self.is_sending {
93             return Poll::Ready(Ok(()));
94         }
95 
96         let result = self.inner.poll(cx);
97         if result.is_ready() {
98             self.is_sending = false;
99         }
100         if let Poll::Ready(Err(_)) = &result {
101             self.sender = None;
102         }
103         result
104     }
105 
106     /// Check whether the channel is ready to send more messages now.
107     ///
108     /// If this method returns `true`, then `start_send` is guaranteed to not
109     /// panic.
110     ///
111     /// If the channel is closed, this method returns `true`.
is_ready(&self) -> bool112     pub fn is_ready(&self) -> bool {
113         !self.is_sending
114     }
115 
116     /// Check whether the channel has been closed.
is_closed(&self) -> bool117     pub fn is_closed(&self) -> bool {
118         match &self.sender {
119             Some(sender) => sender.is_closed(),
120             None => true,
121         }
122     }
123 
124     /// Clone the underlying `Sender`.
125     ///
126     /// If this method returns `None`, then the channel is closed. (But it is
127     /// not guaranteed to return `None` if the channel is closed.)
clone_inner(&self) -> Option<Sender<T>>128     pub fn clone_inner(&self) -> Option<Sender<T>> {
129         self.sender.as_ref().map(|sender| (&**sender).clone())
130     }
131 
132     /// Access the underlying `Sender`.
133     ///
134     /// If this method returns `None`, then the channel is closed. (But it is
135     /// not guaranteed to return `None` if the channel is closed.)
inner_ref(&self) -> Option<&Sender<T>>136     pub fn inner_ref(&self) -> Option<&Sender<T>> {
137         self.sender.as_deref()
138     }
139 
140     // This operation is supported because it is required by the Sink trait.
141     /// Close this sender. No more messages can be sent from this sender.
142     ///
143     /// Note that this only closes the channel from the view-point of this
144     /// sender. The channel remains open until all senders have gone away, or
145     /// until the [`Receiver`] closes the channel.
146     ///
147     /// If there is a send in progress when this method is called, that send is
148     /// unaffected by this operation, and `poll_send_done` can still be called
149     /// to complete that send.
150     ///
151     /// [`Receiver`]: tokio::sync::mpsc::Receiver
close_this_sender(&mut self)152     pub fn close_this_sender(&mut self) {
153         self.sender = None;
154     }
155 
156     /// Abort the current in-progress send, if any.
157     ///
158     /// Returns `true` if a send was aborted.
abort_send(&mut self) -> bool159     pub fn abort_send(&mut self) -> bool {
160         if self.is_sending {
161             self.inner.set(make_future(None));
162             self.is_sending = false;
163             true
164         } else {
165             false
166         }
167     }
168 }
169 
170 impl<T> Clone for PollSender<T> {
171     /// Clones this `PollSender`. The resulting clone will not have any
172     /// in-progress send operations, even if the current `PollSender` does.
clone(&self) -> PollSender<T>173     fn clone(&self) -> PollSender<T> {
174         Self {
175             sender: self.sender.clone(),
176             is_sending: false,
177             inner: ReusableBoxFuture::new(async { unreachable!() }),
178         }
179     }
180 }
181 
182 impl<T: Send + 'static> Sink<T> for PollSender<T> {
183     type Error = SendError<T>;
184 
185     /// This is equivalent to calling [`poll_send_done`].
186     ///
187     /// [`poll_send_done`]: PollSender::poll_send_done
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>188     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189         Pin::into_inner(self).poll_send_done(cx)
190     }
191 
192     /// This is equivalent to calling [`poll_send_done`].
193     ///
194     /// [`poll_send_done`]: PollSender::poll_send_done
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>195     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196         Pin::into_inner(self).poll_send_done(cx)
197     }
198 
199     /// This is equivalent to calling [`start_send`].
200     ///
201     /// [`start_send`]: PollSender::start_send
start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error>202     fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
203         Pin::into_inner(self).start_send(item)
204     }
205 
206     /// This method will first flush the `PollSender`, and then close it by
207     /// calling [`close_this_sender`].
208     ///
209     /// If a send fails while flushing because the [`Receiver`] has gone away,
210     /// then this function returns an error. The channel is still successfully
211     /// closed in this situation.
212     ///
213     /// [`close_this_sender`]: PollSender::close_this_sender
214     /// [`Receiver`]: tokio::sync::mpsc::Receiver
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>215     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216         ready!(self.as_mut().poll_flush(cx))?;
217 
218         Pin::into_inner(self).close_this_sender();
219         Poll::Ready(Ok(()))
220     }
221 }
222