1 use futures_core::ready;
2 use futures_sink::Sink;
3 use std::pin::Pin;
4 use std::sync::Arc;
5 use std::task::{Context, Poll};
6 use tokio::sync::mpsc::{error::SendError, Sender};
7
8 use super::ReusableBoxFuture;
9
10 // This implementation was chosen over something based on permits because to get a
11 // `tokio::sync::mpsc::Permit` out of the `inner` future, you must transmute the
12 // lifetime on the permit to `'static`.
13
/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// A send is started with `start_send` and then driven to completion with
/// `poll_send_done`. At most one send is in progress at a time.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
#[derive(Debug)]
pub struct PollSender<T> {
    /// is none if closed
    ///
    /// Set to `None` by `close_this_sender`, and by `poll_send_done` when an
    /// in-progress send fails. The `Arc` lets an in-flight send future keep
    /// its own handle even after this field is cleared.
    sender: Option<Arc<Sender<T>>>,
    /// `true` between a successful `start_send` and the `poll_send_done` call
    /// that observes the send finishing (or an `abort_send`). `inner` is only
    /// polled while this is `true`.
    is_sending: bool,
    /// The future performing the current send. Replaced via `set` on each
    /// `start_send`/`abort_send` instead of allocating a new box.
    inner: ReusableBoxFuture<Result<(), SendError<T>>>,
}
24
25 // By reusing the same async fn for both Some and None, we make sure every
26 // future passed to ReusableBoxFuture has the same underlying type, and hence
27 // the same size and alignment.
make_future<T>(data: Option<(Arc<Sender<T>>, T)>) -> Result<(), SendError<T>>28 async fn make_future<T>(data: Option<(Arc<Sender<T>>, T)>) -> Result<(), SendError<T>> {
29 match data {
30 Some((sender, value)) => sender.send(value).await,
31 None => unreachable!(
32 "This future should not be pollable, as is_sending should be set to false."
33 ),
34 }
35 }
36
37 impl<T: Send + 'static> PollSender<T> {
38 /// Create a new `PollSender`.
new(sender: Sender<T>) -> Self39 pub fn new(sender: Sender<T>) -> Self {
40 Self {
41 sender: Some(Arc::new(sender)),
42 is_sending: false,
43 inner: ReusableBoxFuture::new(make_future(None)),
44 }
45 }
46
47 /// Start sending a new item.
48 ///
49 /// This method panics if a send is currently in progress. To ensure that no
50 /// send is in progress, call `poll_send_done` first until it returns
51 /// `Poll::Ready`.
52 ///
53 /// If this method returns an error, that indicates that the channel is
54 /// closed. Note that this method is not guaranteed to return an error if
55 /// the channel is closed, but in that case the error would be reported by
56 /// the first call to `poll_send_done`.
start_send(&mut self, value: T) -> Result<(), SendError<T>>57 pub fn start_send(&mut self, value: T) -> Result<(), SendError<T>> {
58 if self.is_sending {
59 panic!("start_send called while not ready.");
60 }
61 match self.sender.clone() {
62 Some(sender) => {
63 self.inner.set(make_future(Some((sender, value))));
64 self.is_sending = true;
65 Ok(())
66 }
67 None => Err(SendError(value)),
68 }
69 }
70
71 /// If a send is in progress, poll for its completion. If no send is in progress,
72 /// this method returns `Poll::Ready(Ok(()))`.
73 ///
74 /// This method can return the following values:
75 ///
76 /// - `Poll::Ready(Ok(()))` if the in-progress send has been completed, or there is
77 /// no send in progress (even if the channel is closed).
78 /// - `Poll::Ready(Err(err))` if the in-progress send failed because the channel has
79 /// been closed.
80 /// - `Poll::Pending` if a send is in progress, but it could not complete now.
81 ///
82 /// When this method returns `Poll::Pending`, the current task is scheduled
83 /// to receive a wakeup when the message is sent, or when the entire channel
84 /// is closed (but not if just this sender is closed by
85 /// `close_this_sender`). Note that on multiple calls to `poll_send_done`,
86 /// only the `Waker` from the `Context` passed to the most recent call is
87 /// scheduled to receive a wakeup.
88 ///
89 /// If this method returns `Poll::Ready`, then `start_send` is guaranteed to
90 /// not panic.
poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>>91 pub fn poll_send_done(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
92 if !self.is_sending {
93 return Poll::Ready(Ok(()));
94 }
95
96 let result = self.inner.poll(cx);
97 if result.is_ready() {
98 self.is_sending = false;
99 }
100 if let Poll::Ready(Err(_)) = &result {
101 self.sender = None;
102 }
103 result
104 }
105
106 /// Check whether the channel is ready to send more messages now.
107 ///
108 /// If this method returns `true`, then `start_send` is guaranteed to not
109 /// panic.
110 ///
111 /// If the channel is closed, this method returns `true`.
is_ready(&self) -> bool112 pub fn is_ready(&self) -> bool {
113 !self.is_sending
114 }
115
116 /// Check whether the channel has been closed.
is_closed(&self) -> bool117 pub fn is_closed(&self) -> bool {
118 match &self.sender {
119 Some(sender) => sender.is_closed(),
120 None => true,
121 }
122 }
123
124 /// Clone the underlying `Sender`.
125 ///
126 /// If this method returns `None`, then the channel is closed. (But it is
127 /// not guaranteed to return `None` if the channel is closed.)
clone_inner(&self) -> Option<Sender<T>>128 pub fn clone_inner(&self) -> Option<Sender<T>> {
129 self.sender.as_ref().map(|sender| (&**sender).clone())
130 }
131
132 /// Access the underlying `Sender`.
133 ///
134 /// If this method returns `None`, then the channel is closed. (But it is
135 /// not guaranteed to return `None` if the channel is closed.)
inner_ref(&self) -> Option<&Sender<T>>136 pub fn inner_ref(&self) -> Option<&Sender<T>> {
137 self.sender.as_deref()
138 }
139
140 // This operation is supported because it is required by the Sink trait.
141 /// Close this sender. No more messages can be sent from this sender.
142 ///
143 /// Note that this only closes the channel from the view-point of this
144 /// sender. The channel remains open until all senders have gone away, or
145 /// until the [`Receiver`] closes the channel.
146 ///
147 /// If there is a send in progress when this method is called, that send is
148 /// unaffected by this operation, and `poll_send_done` can still be called
149 /// to complete that send.
150 ///
151 /// [`Receiver`]: tokio::sync::mpsc::Receiver
close_this_sender(&mut self)152 pub fn close_this_sender(&mut self) {
153 self.sender = None;
154 }
155
156 /// Abort the current in-progress send, if any.
157 ///
158 /// Returns `true` if a send was aborted.
abort_send(&mut self) -> bool159 pub fn abort_send(&mut self) -> bool {
160 if self.is_sending {
161 self.inner.set(make_future(None));
162 self.is_sending = false;
163 true
164 } else {
165 false
166 }
167 }
168 }
169
170 impl<T> Clone for PollSender<T> {
171 /// Clones this `PollSender`. The resulting clone will not have any
172 /// in-progress send operations, even if the current `PollSender` does.
clone(&self) -> PollSender<T>173 fn clone(&self) -> PollSender<T> {
174 Self {
175 sender: self.sender.clone(),
176 is_sending: false,
177 inner: ReusableBoxFuture::new(async { unreachable!() }),
178 }
179 }
180 }
181
182 impl<T: Send + 'static> Sink<T> for PollSender<T> {
183 type Error = SendError<T>;
184
185 /// This is equivalent to calling [`poll_send_done`].
186 ///
187 /// [`poll_send_done`]: PollSender::poll_send_done
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>188 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 Pin::into_inner(self).poll_send_done(cx)
190 }
191
192 /// This is equivalent to calling [`poll_send_done`].
193 ///
194 /// [`poll_send_done`]: PollSender::poll_send_done
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>195 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196 Pin::into_inner(self).poll_send_done(cx)
197 }
198
199 /// This is equivalent to calling [`start_send`].
200 ///
201 /// [`start_send`]: PollSender::start_send
start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error>202 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
203 Pin::into_inner(self).start_send(item)
204 }
205
206 /// This method will first flush the `PollSender`, and then close it by
207 /// calling [`close_this_sender`].
208 ///
209 /// If a send fails while flushing because the [`Receiver`] has gone away,
210 /// then this function returns an error. The channel is still successfully
211 /// closed in this situation.
212 ///
213 /// [`close_this_sender`]: PollSender::close_this_sender
214 /// [`Receiver`]: tokio::sync::mpsc::Receiver
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>215 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 ready!(self.as_mut().poll_flush(cx))?;
217
218 Pin::into_inner(self).close_this_sender();
219 Poll::Ready(Ok(()))
220 }
221 }
222