1 mod handshake;
2 
3 pub(crate) use handshake::{IoSession, MidHandshake};
4 use rustls::Session;
5 use std::io::{self, IoSlice, Read, Write};
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9 
/// Lifecycle state of a TLS stream, tracking which halves are still open.
///
/// `Stream` is the normal open state. The `*Shutdown` variants record which
/// direction(s) have been closed; `FullyShutdown` means both have.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase. Carries a `usize` and a byte buffer whose exact
    /// meaning is defined by the early-data code paths (NOTE(review): confirm
    /// against the callers that construct this variant).
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both directions are open.
    Stream,
    /// The read half has been shut down; writes are still allowed.
    ReadShutdown,
    /// The write half has been shut down; reads are still allowed.
    WriteShutdown,
    /// Both halves have been shut down.
    FullyShutdown,
}

impl TlsState {
    /// Marks the read half as closed.
    ///
    /// If the write half is already closed the result is `FullyShutdown`;
    /// otherwise it is `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        let next = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
        *self = next;
    }

    /// Marks the write half as closed.
    ///
    /// If the read half is already closed the result is `FullyShutdown`;
    /// otherwise it is `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        let next = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
        *self = next;
    }

    /// Returns `false` once the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        match *self {
            TlsState::WriteShutdown | TlsState::FullyShutdown => false,
            _ => true,
        }
    }

    /// Returns `false` once the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        match *self {
            TlsState::ReadShutdown | TlsState::FullyShutdown => false,
            _ => true,
        }
    }

    /// Returns `true` while the stream is in the early-data phase.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        match self {
            TlsState::EarlyData(..) => true,
            _ => false,
        }
    }

    /// Without the `early-data` feature there is no early-data phase, so this
    /// is always `false`.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
59 
/// Borrowed pairing of a transport `IO` and a TLS session `S`.
///
/// The `impl` blocks for this type move TLS records between `session` and
/// `io` and expose the decrypted stream through `AsyncRead`/`AsyncWrite`.
pub struct Stream<'a, IO, S> {
    /// Underlying transport that TLS records are read from and written to.
    pub io: &'a mut IO,
    /// The session whose `read_tls`/`write_tls` exchange records with `io`.
    pub session: &'a mut S,
    /// Whether end-of-file has been observed on `io`. It is set when a read of
    /// TLS data from `io` returns 0 bytes, or explicitly by callers via
    /// `set_eof`.
    pub eof: bool,
}
65 
66 impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
new(io: &'a mut IO, session: &'a mut S) -> Self67     pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
68         Stream {
69             io,
70             session,
71             // The state so far is only used to detect EOF, so either Stream
72             // or EarlyData state should both be all right.
73             eof: false,
74         }
75     }
76 
set_eof(mut self, eof: bool) -> Self77     pub fn set_eof(mut self, eof: bool) -> Self {
78         self.eof = eof;
79         self
80     }
81 
as_mut_pin(&mut self) -> Pin<&mut Self>82     pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
83         Pin::new(self)
84     }
85 
read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>>86     pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
87         struct Reader<'a, 'b, T> {
88             io: &'a mut T,
89             cx: &'a mut Context<'b>,
90         }
91 
92         impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
93             #[inline]
94             fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
95                 let mut buf = ReadBuf::new(buf);
96                 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
97                     Poll::Ready(Ok(())) => Ok(buf.filled().len()),
98                     Poll::Ready(Err(err)) => Err(err),
99                     Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
100                 }
101             }
102         }
103 
104         let mut reader = Reader { io: self.io, cx };
105 
106         let n = match self.session.read_tls(&mut reader) {
107             Ok(n) => n,
108             Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
109             Err(err) => return Poll::Ready(Err(err)),
110         };
111 
112         self.session.process_new_packets().map_err(|err| {
113             // In case we have an alert to send describing this error,
114             // try a last-gasp write -- but don't predate the primary
115             // error.
116             let _ = self.write_io(cx);
117 
118             io::Error::new(io::ErrorKind::InvalidData, err)
119         })?;
120 
121         Poll::Ready(Ok(n))
122     }
123 
write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>>124     pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
125         struct Writer<'a, 'b, T> {
126             io: &'a mut T,
127             cx: &'a mut Context<'b>,
128         }
129 
130         impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
131             #[inline]
132             fn poll_with<U>(
133                 &mut self,
134                 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
135             ) -> io::Result<U> {
136                 match f(Pin::new(&mut self.io), self.cx) {
137                     Poll::Ready(result) => result,
138                     Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
139                 }
140             }
141         }
142 
143         impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
144             #[inline]
145             fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
146                 self.poll_with(|io, cx| io.poll_write(cx, buf))
147             }
148 
149             #[inline]
150             fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
151                 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
152             }
153 
154             fn flush(&mut self) -> io::Result<()> {
155                 self.poll_with(|io, cx| io.poll_flush(cx))
156             }
157         }
158 
159         let mut writer = Writer { io: self.io, cx };
160 
161         match self.session.write_tls(&mut writer) {
162             Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
163             result => Poll::Ready(result),
164         }
165     }
166 
handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>>167     pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
168         let mut wrlen = 0;
169         let mut rdlen = 0;
170 
171         loop {
172             let mut write_would_block = false;
173             let mut read_would_block = false;
174 
175             while self.session.wants_write() {
176                 match self.write_io(cx) {
177                     Poll::Ready(Ok(n)) => wrlen += n,
178                     Poll::Pending => {
179                         write_would_block = true;
180                         break;
181                     }
182                     Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
183                 }
184             }
185 
186             while !self.eof && self.session.wants_read() {
187                 match self.read_io(cx) {
188                     Poll::Ready(Ok(0)) => self.eof = true,
189                     Poll::Ready(Ok(n)) => rdlen += n,
190                     Poll::Pending => {
191                         read_would_block = true;
192                         break;
193                     }
194                     Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
195                 }
196             }
197 
198             return match (self.eof, self.session.is_handshaking()) {
199                 (true, true) => {
200                     let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
201                     Poll::Ready(Err(err))
202                 }
203                 (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
204                 (_, true) if write_would_block || read_would_block => {
205                     if rdlen != 0 || wrlen != 0 {
206                         Poll::Ready(Ok((rdlen, wrlen)))
207                     } else {
208                         Poll::Pending
209                     }
210                 }
211                 (..) => continue,
212             };
213         }
214     }
215 }
216 
// Exposes the session's decrypted plaintext as an `AsyncRead`, pulling TLS
// records from the transport on demand.
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
    /// Fills `buf` with decrypted plaintext.
    ///
    /// Repeatedly feeds TLS records from the transport into the session and
    /// copies out plaintext until `buf` is full, the transport would block,
    /// or EOF is seen. Returns `Pending` only if no plaintext was copied and
    /// the transport blocked; the transport's waker is registered via `cx`
    /// inside `read_io`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Capacity at entry; comparing against it tells whether any plaintext
        // has been copied into `buf` during this call.
        let prev = buf.remaining();

        while buf.remaining() != 0 {
            let mut would_block = false;

            // read a packet
            // Feed ciphertext into the session while it asks for more. A
            // zero-byte read from the transport marks EOF.
            while self.session.wants_read() {
                match self.read_io(cx) {
                    Poll::Ready(Ok(0)) => {
                        self.eof = true;
                        break;
                    }
                    Poll::Ready(Ok(_)) => (),
                    Poll::Pending => {
                        would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            // Every arm either returns, `break`s out to the final
            // `Ready(Ok(()))`, or `continue`s the outer loop.
            return match self.session.read(buf.initialize_unfilled()) {
                // Nothing copied this call and the transport blocked.
                Ok(0) if prev == buf.remaining() && would_block => Poll::Pending,
                Ok(n) => {
                    buf.advance(n);

                    // Stop at EOF or when the transport blocked; otherwise
                    // try to pull more records to fill the rest of `buf`.
                    if self.eof || would_block {
                        break;
                    } else {
                        continue;
                    }
                }
                // If some plaintext was already copied this call, return it
                // instead of the `ConnectionAborted` error.
                Err(ref err)
                    if err.kind() == io::ErrorKind::ConnectionAborted
                        && prev != buf.remaining() =>
                {
                    break
                }
                Err(err) => Poll::Ready(Err(err)),
            };
        }

        Poll::Ready(Ok(()))
    }
}
268 
269 impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll<io::Result<usize>>270     fn poll_write(
271         mut self: Pin<&mut Self>,
272         cx: &mut Context,
273         buf: &[u8],
274     ) -> Poll<io::Result<usize>> {
275         let mut pos = 0;
276 
277         while pos != buf.len() {
278             let mut would_block = false;
279 
280             match self.session.write(&buf[pos..]) {
281                 Ok(n) => pos += n,
282                 Err(err) => return Poll::Ready(Err(err)),
283             };
284 
285             while self.session.wants_write() {
286                 match self.write_io(cx) {
287                     Poll::Ready(Ok(0)) | Poll::Pending => {
288                         would_block = true;
289                         break;
290                     }
291                     Poll::Ready(Ok(_)) => (),
292                     Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
293                 }
294             }
295 
296             return match (pos, would_block) {
297                 (0, true) => Poll::Pending,
298                 (n, true) => Poll::Ready(Ok(n)),
299                 (_, false) => continue,
300             };
301         }
302 
303         Poll::Ready(Ok(pos))
304     }
305 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>>306     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
307         self.session.flush()?;
308         while self.session.wants_write() {
309             ready!(self.write_io(cx))?;
310         }
311         Pin::new(&mut self.io).poll_flush(cx)
312     }
313 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>314     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
315         while self.session.wants_write() {
316             ready!(self.write_io(cx))?;
317         }
318         Pin::new(&mut self.io).poll_shutdown(cx)
319     }
320 }
321 
322 #[cfg(test)]
323 mod test_stream;
324