1 use crate::common::{Stream, TlsState};
2 use rustls::Session;
3 use std::future::Future;
4 use std::pin::Pin;
5 use std::task::{Context, Poll};
6 use std::{io, mem};
7 use tokio::io::{AsyncRead, AsyncWrite};
8 
9 pub(crate) trait IoSession {
10     type Io;
11     type Session;
12 
skip_handshake(&self) -> bool13     fn skip_handshake(&self) -> bool;
get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session)14     fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
into_io(self) -> Self::Io15     fn into_io(self) -> Self::Io;
16 }
17 
/// State of a TLS handshake future.
///
/// While the handshake is still being driven the connection is owned by the
/// `Handshaking` variant. `End` is the placeholder left behind once the
/// connection has been moved out (either handed back to the caller or
/// temporarily taken during a poll).
pub(crate) enum MidHandshake<S> {
    /// Handshake not finished yet; owns the connection being driven.
    Handshaking(S),
    /// The connection has been taken out of the future.
    End,
}
22 
23 impl<IS> Future for MidHandshake<IS>
24 where
25     IS: IoSession + Unpin,
26     IS::Io: AsyncRead + AsyncWrite + Unpin,
27     IS::Session: Session + Unpin,
28 {
29     type Output = Result<IS, (io::Error, IS::Io)>;
30 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>31     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
32         let this = self.get_mut();
33 
34         let mut stream =
35             if let MidHandshake::Handshaking(stream) = mem::replace(this, MidHandshake::End) {
36                 stream
37             } else {
38                 panic!("unexpected polling after handshake")
39             };
40 
41         if !stream.skip_handshake() {
42             let (state, io, session) = stream.get_mut();
43             let mut tls_stream = Stream::new(io, session).set_eof(!state.readable());
44 
45             macro_rules! try_poll {
46                 ( $e:expr ) => {
47                     match $e {
48                         Poll::Ready(Ok(_)) => (),
49                         Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
50                         Poll::Pending => {
51                             *this = MidHandshake::Handshaking(stream);
52                             return Poll::Pending;
53                         }
54                     }
55                 };
56             }
57 
58             while tls_stream.session.is_handshaking() {
59                 try_poll!(tls_stream.handshake(cx));
60             }
61 
62             while tls_stream.session.wants_write() {
63                 try_poll!(tls_stream.write_io(cx));
64             }
65         }
66 
67         Poll::Ready(Ok(stream))
68     }
69 }
70