1 use crate::codec::Codec;
2 use crate::frame::Ping;
3 use crate::proto::{self, PingPayload};
4 
5 use bytes::Buf;
6 use futures_util::task::AtomicWaker;
7 use std::io;
8 use std::sync::atomic::{AtomicUsize, Ordering};
9 use std::sync::Arc;
10 use std::task::{Context, Poll};
11 use tokio::io::AsyncWrite;
12 
/// Acknowledges ping requests from the remote.
///
/// Also tracks a PING initiated by this side (in this file only
/// `ping_shutdown` sets one) and, once `take_user_pings` has been called, the
/// connection-side half of the user-ping state.
#[derive(Debug)]
pub(crate) struct PingPong {
    /// PING we have queued or sent and are awaiting an ACK for.
    pending_ping: Option<PendingPing>,
    /// Payload of a received PING whose ACK has not been buffered yet.
    pending_pong: Option<PingPayload>,
    /// Connection-side user-ping state; `Some` after `take_user_pings`.
    user_pings: Option<UserPingsRx>,
}
20 
/// User-facing handle used to send a PING and wait for its PONG.
///
/// Shares its state with the connection-side `UserPingsRx`.
#[derive(Debug)]
pub(crate) struct UserPings(Arc<UserPingsInner>);
23 
/// Connection-side half of the user-ping state, owned by `PingPong`.
///
/// Dropping it marks the shared state as closed and wakes any pong waiter.
#[derive(Debug)]
struct UserPingsRx(Arc<UserPingsInner>);
26 
/// State shared between `UserPings` and `UserPingsRx`.
#[derive(Debug)]
struct UserPingsInner {
    /// Current user-ping state; always one of the `USER_STATE_*` constants.
    state: AtomicUsize,
    /// Task to wake up the main `Connection`.
    ping_task: AtomicWaker,
    /// Task to wake up `share::PingPong::poll_pong`.
    pong_task: AtomicWaker,
}
35 
/// A PING initiated by this side of the connection.
#[derive(Debug)]
struct PendingPing {
    /// Payload compared against incoming PING ACKs to recognize the reply.
    payload: PingPayload,
    /// Set once the PING frame has been buffered into the codec.
    sent: bool,
}
41 
/// Status returned from `PingPong::recv_ping`.
#[derive(Debug)]
pub(crate) enum ReceivedPing {
    /// A non-ACK PING was received; its payload is queued to be echoed back.
    MustAck,
    /// An ACK that was a user-ping PONG, or one we cannot match to a PING.
    Unknown,
    /// The ACK for our pending shutdown PING arrived.
    Shutdown,
}
49 
// States of `UserPingsInner::state`. The normal cycle is
// EMPTY -> PENDING_PING -> PENDING_PONG -> RECEIVED_PONG -> EMPTY;
// CLOSED is stored unconditionally when the connection side is dropped.

/// Idle: no user PING is in flight.
const USER_STATE_EMPTY: usize = 0;
/// `send_ping` succeeded; the connection has not written the PING frame yet.
const USER_STATE_PENDING_PING: usize = 1;
/// The PING frame has been buffered; waiting for the remote's ACK.
const USER_STATE_PENDING_PONG: usize = 2;
/// The ACK arrived; waiting for the user to observe it via `poll_pong`.
const USER_STATE_RECEIVED_PONG: usize = 3;
/// The connection side is gone; no further pings can be exchanged.
const USER_STATE_CLOSED: usize = 4;
60 
61 // ===== impl PingPong =====
62 
63 impl PingPong {
new() -> Self64     pub(crate) fn new() -> Self {
65         PingPong {
66             pending_ping: None,
67             pending_pong: None,
68             user_pings: None,
69         }
70     }
71 
72     /// Can only be called once. If called a second time, returns `None`.
take_user_pings(&mut self) -> Option<UserPings>73     pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
74         if self.user_pings.is_some() {
75             return None;
76         }
77 
78         let user_pings = Arc::new(UserPingsInner {
79             state: AtomicUsize::new(USER_STATE_EMPTY),
80             ping_task: AtomicWaker::new(),
81             pong_task: AtomicWaker::new(),
82         });
83         self.user_pings = Some(UserPingsRx(user_pings.clone()));
84         Some(UserPings(user_pings))
85     }
86 
ping_shutdown(&mut self)87     pub(crate) fn ping_shutdown(&mut self) {
88         assert!(self.pending_ping.is_none());
89 
90         self.pending_ping = Some(PendingPing {
91             payload: Ping::SHUTDOWN,
92             sent: false,
93         });
94     }
95 
96     /// Process a ping
recv_ping(&mut self, ping: Ping) -> ReceivedPing97     pub(crate) fn recv_ping(&mut self, ping: Ping) -> ReceivedPing {
98         // The caller should always check that `send_pongs` returns ready before
99         // calling `recv_ping`.
100         assert!(self.pending_pong.is_none());
101 
102         if ping.is_ack() {
103             if let Some(pending) = self.pending_ping.take() {
104                 if &pending.payload == ping.payload() {
105                     assert_eq!(
106                         &pending.payload,
107                         &Ping::SHUTDOWN,
108                         "pending_ping should be for shutdown",
109                     );
110                     tracing::trace!("recv PING SHUTDOWN ack");
111                     return ReceivedPing::Shutdown;
112                 }
113 
114                 // if not the payload we expected, put it back.
115                 self.pending_ping = Some(pending);
116             }
117 
118             if let Some(ref users) = self.user_pings {
119                 if ping.payload() == &Ping::USER && users.receive_pong() {
120                     tracing::trace!("recv PING USER ack");
121                     return ReceivedPing::Unknown;
122                 }
123             }
124 
125             // else we were acked a ping we didn't send?
126             // The spec doesn't require us to do anything about this,
127             // so for resiliency, just ignore it for now.
128             tracing::warn!("recv PING ack that we never sent: {:?}", ping);
129             ReceivedPing::Unknown
130         } else {
131             // Save the ping's payload to be sent as an acknowledgement.
132             self.pending_pong = Some(ping.into_payload());
133             ReceivedPing::MustAck
134         }
135     }
136 
137     /// Send any pending pongs.
send_pending_pong<T, B>( &mut self, cx: &mut Context, dst: &mut Codec<T, B>, ) -> Poll<io::Result<()>> where T: AsyncWrite + Unpin, B: Buf,138     pub(crate) fn send_pending_pong<T, B>(
139         &mut self,
140         cx: &mut Context,
141         dst: &mut Codec<T, B>,
142     ) -> Poll<io::Result<()>>
143     where
144         T: AsyncWrite + Unpin,
145         B: Buf,
146     {
147         if let Some(pong) = self.pending_pong.take() {
148             if !dst.poll_ready(cx)?.is_ready() {
149                 self.pending_pong = Some(pong);
150                 return Poll::Pending;
151             }
152 
153             dst.buffer(Ping::pong(pong).into())
154                 .expect("invalid pong frame");
155         }
156 
157         Poll::Ready(Ok(()))
158     }
159 
160     /// Send any pending pings.
send_pending_ping<T, B>( &mut self, cx: &mut Context, dst: &mut Codec<T, B>, ) -> Poll<io::Result<()>> where T: AsyncWrite + Unpin, B: Buf,161     pub(crate) fn send_pending_ping<T, B>(
162         &mut self,
163         cx: &mut Context,
164         dst: &mut Codec<T, B>,
165     ) -> Poll<io::Result<()>>
166     where
167         T: AsyncWrite + Unpin,
168         B: Buf,
169     {
170         if let Some(ref mut ping) = self.pending_ping {
171             if !ping.sent {
172                 if !dst.poll_ready(cx)?.is_ready() {
173                     return Poll::Pending;
174                 }
175 
176                 dst.buffer(Ping::new(ping.payload).into())
177                     .expect("invalid ping frame");
178                 ping.sent = true;
179             }
180         } else if let Some(ref users) = self.user_pings {
181             if users.0.state.load(Ordering::Acquire) == USER_STATE_PENDING_PING {
182                 if !dst.poll_ready(cx)?.is_ready() {
183                     return Poll::Pending;
184                 }
185 
186                 dst.buffer(Ping::new(Ping::USER).into())
187                     .expect("invalid ping frame");
188                 users
189                     .0
190                     .state
191                     .store(USER_STATE_PENDING_PONG, Ordering::Release);
192             } else {
193                 users.0.ping_task.register(cx.waker());
194             }
195         }
196 
197         Poll::Ready(Ok(()))
198     }
199 }
200 
201 impl ReceivedPing {
is_shutdown(&self) -> bool202     pub(crate) fn is_shutdown(&self) -> bool {
203         match *self {
204             ReceivedPing::Shutdown => true,
205             _ => false,
206         }
207     }
208 }
209 
210 // ===== impl UserPings =====
211 
212 impl UserPings {
send_ping(&self) -> Result<(), Option<proto::Error>>213     pub(crate) fn send_ping(&self) -> Result<(), Option<proto::Error>> {
214         let prev = self.0.state.compare_and_swap(
215             USER_STATE_EMPTY,        // current
216             USER_STATE_PENDING_PING, // new
217             Ordering::AcqRel,
218         );
219 
220         match prev {
221             USER_STATE_EMPTY => {
222                 self.0.ping_task.wake();
223                 Ok(())
224             }
225             USER_STATE_CLOSED => Err(Some(broken_pipe().into())),
226             _ => {
227                 // Was already pending, user error!
228                 Err(None)
229             }
230         }
231     }
232 
poll_pong(&self, cx: &mut Context) -> Poll<Result<(), proto::Error>>233     pub(crate) fn poll_pong(&self, cx: &mut Context) -> Poll<Result<(), proto::Error>> {
234         // Must register before checking state, in case state were to change
235         // before we could register, and then the ping would just be lost.
236         self.0.pong_task.register(cx.waker());
237         let prev = self.0.state.compare_and_swap(
238             USER_STATE_RECEIVED_PONG, // current
239             USER_STATE_EMPTY,         // new
240             Ordering::AcqRel,
241         );
242 
243         match prev {
244             USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())),
245             USER_STATE_CLOSED => Poll::Ready(Err(broken_pipe().into())),
246             _ => Poll::Pending,
247         }
248     }
249 }
250 
251 // ===== impl UserPingsRx =====
252 
253 impl UserPingsRx {
receive_pong(&self) -> bool254     fn receive_pong(&self) -> bool {
255         let prev = self.0.state.compare_and_swap(
256             USER_STATE_PENDING_PONG,  // current
257             USER_STATE_RECEIVED_PONG, // new
258             Ordering::AcqRel,
259         );
260 
261         if prev == USER_STATE_PENDING_PONG {
262             self.0.pong_task.wake();
263             true
264         } else {
265             false
266         }
267     }
268 }
269 
270 impl Drop for UserPingsRx {
drop(&mut self)271     fn drop(&mut self) {
272         self.0.state.store(USER_STATE_CLOSED, Ordering::Release);
273         self.0.pong_task.wake();
274     }
275 }
276 
/// Builds the `BrokenPipe` I/O error reported to user-ping callers once the
/// connection side has been dropped.
fn broken_pipe() -> io::Error {
    io::Error::from(io::ErrorKind::BrokenPipe)
}
280