1 use crate::common::{Stream, TlsState}; 2 use rustls::Session; 3 use std::future::Future; 4 use std::pin::Pin; 5 use std::task::{Context, Poll}; 6 use std::{io, mem}; 7 use tokio::io::{AsyncRead, AsyncWrite}; 8 9 pub(crate) trait IoSession { 10 type Io; 11 type Session; 12 skip_handshake(&self) -> bool13 fn skip_handshake(&self) -> bool; get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session)14 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); into_io(self) -> Self::Io15 fn into_io(self) -> Self::Io; 16 } 17 18 pub(crate) enum MidHandshake<IS> { 19 Handshaking(IS), 20 End, 21 } 22 23 impl<IS> Future for MidHandshake<IS> 24 where 25 IS: IoSession + Unpin, 26 IS::Io: AsyncRead + AsyncWrite + Unpin, 27 IS::Session: Session + Unpin, 28 { 29 type Output = Result<IS, (io::Error, IS::Io)>; 30 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>31 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 32 let this = self.get_mut(); 33 34 let mut stream = 35 if let MidHandshake::Handshaking(stream) = mem::replace(this, MidHandshake::End) { 36 stream 37 } else { 38 panic!("unexpected polling after handshake") 39 }; 40 41 if !stream.skip_handshake() { 42 let (state, io, session) = stream.get_mut(); 43 let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); 44 45 macro_rules! try_poll { 46 ( $e:expr ) => { 47 match $e { 48 Poll::Ready(Ok(_)) => (), 49 Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), 50 Poll::Pending => { 51 *this = MidHandshake::Handshaking(stream); 52 return Poll::Pending; 53 } 54 } 55 }; 56 } 57 58 while tls_stream.session.is_handshaking() { 59 try_poll!(tls_stream.handshake(cx)); 60 } 61 62 while tls_stream.session.wants_write() { 63 try_poll!(tls_stream.write_io(cx)); 64 } 65 } 66 67 Poll::Ready(Ok(stream)) 68 } 69 } 70