1 mod handshake; 2 3 pub(crate) use handshake::{IoSession, MidHandshake}; 4 use rustls::Session; 5 use std::io::{self, IoSlice, Read, Write}; 6 use std::pin::Pin; 7 use std::task::{Context, Poll}; 8 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 9 10 #[derive(Debug)] 11 pub enum TlsState { 12 #[cfg(feature = "early-data")] 13 EarlyData(usize, Vec<u8>), 14 Stream, 15 ReadShutdown, 16 WriteShutdown, 17 FullyShutdown, 18 } 19 20 impl TlsState { 21 #[inline] shutdown_read(&mut self)22 pub fn shutdown_read(&mut self) { 23 match *self { 24 TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, 25 _ => *self = TlsState::ReadShutdown, 26 } 27 } 28 29 #[inline] shutdown_write(&mut self)30 pub fn shutdown_write(&mut self) { 31 match *self { 32 TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, 33 _ => *self = TlsState::WriteShutdown, 34 } 35 } 36 37 #[inline] writeable(&self) -> bool38 pub fn writeable(&self) -> bool { 39 !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown) 40 } 41 42 #[inline] readable(&self) -> bool43 pub fn readable(&self) -> bool { 44 !matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown) 45 } 46 47 #[inline] 48 #[cfg(feature = "early-data")] is_early_data(&self) -> bool49 pub fn is_early_data(&self) -> bool { 50 matches!(self, TlsState::EarlyData(..)) 51 } 52 53 #[inline] 54 #[cfg(not(feature = "early-data"))] is_early_data(&self) -> bool55 pub const fn is_early_data(&self) -> bool { 56 false 57 } 58 } 59 60 pub struct Stream<'a, IO, S> { 61 pub io: &'a mut IO, 62 pub session: &'a mut S, 63 pub eof: bool, 64 } 65 66 impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { new(io: &'a mut IO, session: &'a mut S) -> Self67 pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { 68 Stream { 69 io, 70 session, 71 // The state so far is only used to detect EOF, so either Stream 72 // or EarlyData state should both be all right. 73 eof: false, 74 } 75 } 76 set_eof(mut self, eof: bool) -> Self77 pub fn set_eof(mut self, eof: bool) -> Self { 78 self.eof = eof; 79 self 80 } 81 as_mut_pin(&mut self) -> Pin<&mut Self>82 pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { 83 Pin::new(self) 84 } 85 read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>>86 pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { 87 struct Reader<'a, 'b, T> { 88 io: &'a mut T, 89 cx: &'a mut Context<'b>, 90 } 91 92 impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { 93 #[inline] 94 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 95 let mut buf = ReadBuf::new(buf); 96 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) { 97 Poll::Ready(Ok(())) => Ok(buf.filled().len()), 98 Poll::Ready(Err(err)) => Err(err), 99 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), 100 } 101 } 102 } 103 104 let mut reader = Reader { io: self.io, cx }; 105 106 let n = match self.session.read_tls(&mut reader) { 107 Ok(n) => n, 108 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, 109 Err(err) => return Poll::Ready(Err(err)), 110 }; 111 112 self.session.process_new_packets().map_err(|err| { 113 // In case we have an alert to send describing this error, 114 // try a last-gasp write -- but don't predate the primary 115 // error. 116 let _ = self.write_io(cx); 117 118 io::Error::new(io::ErrorKind::InvalidData, err) 119 })?; 120 121 Poll::Ready(Ok(n)) 122 } 123 write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>>124 pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { 125 struct Writer<'a, 'b, T> { 126 io: &'a mut T, 127 cx: &'a mut Context<'b>, 128 } 129 130 impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> { 131 #[inline] 132 fn poll_with<U>( 133 &mut self, 134 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>, 135 ) -> io::Result<U> { 136 match f(Pin::new(&mut self.io), self.cx) { 137 Poll::Ready(result) => result, 138 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), 139 } 140 } 141 } 142 143 impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { 144 #[inline] 145 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 146 self.poll_with(|io, cx| io.poll_write(cx, buf)) 147 } 148 149 #[inline] 150 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { 151 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) 152 } 153 154 fn flush(&mut self) -> io::Result<()> { 155 self.poll_with(|io, cx| io.poll_flush(cx)) 156 } 157 } 158 159 let mut writer = Writer { io: self.io, cx }; 160 161 match self.session.write_tls(&mut writer) { 162 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, 163 result => Poll::Ready(result), 164 } 165 } 166 handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>>167 pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> { 168 let mut wrlen = 0; 169 let mut rdlen = 0; 170 171 loop { 172 let mut write_would_block = false; 173 let mut read_would_block = false; 174 175 while self.session.wants_write() { 176 match self.write_io(cx) { 177 Poll::Ready(Ok(n)) => wrlen += n, 178 Poll::Pending => { 179 write_would_block = true; 180 break; 181 } 182 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 183 } 184 } 185 186 while !self.eof && self.session.wants_read() { 187 match self.read_io(cx) { 188 Poll::Ready(Ok(0)) => self.eof = true, 189 Poll::Ready(Ok(n)) => rdlen += n, 190 Poll::Pending => { 191 read_would_block = true; 192 break; 193 } 194 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 195 } 196 } 197 198 return match (self.eof, self.session.is_handshaking()) { 199 (true, true) => { 200 let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); 201 Poll::Ready(Err(err)) 202 } 203 (_, false) => Poll::Ready(Ok((rdlen, wrlen))), 204 (_, true) if write_would_block || read_would_block => { 205 if rdlen != 0 || wrlen != 0 { 206 Poll::Ready(Ok((rdlen, wrlen))) 207 } else { 208 Poll::Pending 209 } 210 } 211 (..) => continue, 212 }; 213 } 214 } 215 } 216 217 impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>218 fn poll_read( 219 mut self: Pin<&mut Self>, 220 cx: &mut Context<'_>, 221 buf: &mut ReadBuf<'_>, 222 ) -> Poll<io::Result<()>> { 223 let prev = buf.remaining(); 224 225 while buf.remaining() != 0 { 226 let mut would_block = false; 227 228 // read a packet 229 while self.session.wants_read() { 230 match self.read_io(cx) { 231 Poll::Ready(Ok(0)) => { 232 self.eof = true; 233 break; 234 } 235 Poll::Ready(Ok(_)) => (), 236 Poll::Pending => { 237 would_block = true; 238 break; 239 } 240 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 241 } 242 } 243 244 return match self.session.read(buf.initialize_unfilled()) { 245 Ok(0) if prev == buf.remaining() && would_block => Poll::Pending, 246 Ok(n) => { 247 buf.advance(n); 248 249 if self.eof || would_block { 250 break; 251 } else { 252 continue; 253 } 254 } 255 Err(ref err) 256 if err.kind() == io::ErrorKind::ConnectionAborted 257 && prev != buf.remaining() => 258 { 259 break 260 } 261 Err(err) => Poll::Ready(Err(err)), 262 }; 263 } 264 265 Poll::Ready(Ok(())) 266 } 267 } 268 269 impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll<io::Result<usize>>270 fn poll_write( 271 mut self: Pin<&mut Self>, 272 cx: &mut Context, 273 buf: &[u8], 274 ) -> Poll<io::Result<usize>> { 275 let mut pos = 0; 276 277 while pos != buf.len() { 278 let mut would_block = false; 279 280 match self.session.write(&buf[pos..]) { 281 Ok(n) => pos += n, 282 Err(err) => return Poll::Ready(Err(err)), 283 }; 284 285 while self.session.wants_write() { 286 match self.write_io(cx) { 287 Poll::Ready(Ok(0)) | Poll::Pending => { 288 would_block = true; 289 break; 290 } 291 Poll::Ready(Ok(_)) => (), 292 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 293 } 294 } 295 296 return match (pos, would_block) { 297 (0, true) => Poll::Pending, 298 (n, true) => Poll::Ready(Ok(n)), 299 (_, false) => continue, 300 }; 301 } 302 303 Poll::Ready(Ok(pos)) 304 } 305 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>>306 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { 307 self.session.flush()?; 308 while self.session.wants_write() { 309 ready!(self.write_io(cx))?; 310 } 311 Pin::new(&mut self.io).poll_flush(cx) 312 } 313 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>314 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 315 while self.session.wants_write() { 316 ready!(self.write_io(cx))?; 317 } 318 Pin::new(&mut self.io).poll_shutdown(cx) 319 } 320 } 321 322 #[cfg(test)] 323 mod test_stream; 324