1 // Copied from hyperium/hyper-tls#62e3376/src/stream.rs 2 use std::fmt; 3 use std::io; 4 use std::pin::Pin; 5 use std::task::{Context, Poll}; 6 7 use hyper::client::connect::{Connected, Connection}; 8 9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 10 use tokio_rustls::client::TlsStream; 11 12 /// A stream that might be protected with TLS. 13 pub enum MaybeHttpsStream<T> { 14 /// A stream over plain text. 15 Http(T), 16 /// A stream protected with TLS. 17 Https(TlsStream<T>), 18 } 19 20 impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsStream<T> { connected(&self) -> Connected21 fn connected(&self) -> Connected { 22 match self { 23 MaybeHttpsStream::Http(s) => s.connected(), 24 MaybeHttpsStream::Https(s) => { 25 let (tcp, tls) = s.get_ref(); 26 if tls.alpn_protocol() == Some(b"h2") { 27 tcp.connected().negotiated_h2() 28 } else { 29 tcp.connected() 30 } 31 } 32 } 33 } 34 } 35 36 impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result37 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 38 match *self { 39 MaybeHttpsStream::Http(..) => f.pad("Http(..)"), 40 MaybeHttpsStream::Https(..) => f.pad("Https(..)"), 41 } 42 } 43 } 44 45 impl<T> From<T> for MaybeHttpsStream<T> { from(inner: T) -> Self46 fn from(inner: T) -> Self { 47 MaybeHttpsStream::Http(inner) 48 } 49 } 50 51 impl<T> From<TlsStream<T>> for MaybeHttpsStream<T> { from(inner: TlsStream<T>) -> Self52 fn from(inner: TlsStream<T>) -> Self { 53 MaybeHttpsStream::Https(inner) 54 } 55 } 56 57 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeHttpsStream<T> { 58 #[inline] poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll<Result<(), io::Error>>59 fn poll_read( 60 self: Pin<&mut Self>, 61 cx: &mut Context, 62 buf: &mut ReadBuf<'_>, 63 ) -> Poll<Result<(), io::Error>> { 64 match Pin::get_mut(self) { 65 MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf), 66 MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf), 67 } 68 } 69 } 70 71 impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for MaybeHttpsStream<T> { 72 #[inline] poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>73 fn poll_write( 74 self: Pin<&mut Self>, 75 cx: &mut Context<'_>, 76 buf: &[u8], 77 ) -> Poll<Result<usize, io::Error>> { 78 match Pin::get_mut(self) { 79 MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf), 80 MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf), 81 } 82 } 83 84 #[inline] poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>85 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 86 match Pin::get_mut(self) { 87 MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx), 88 MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx), 89 } 90 } 91 92 #[inline] poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>93 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 94 match Pin::get_mut(self) { 95 MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx), 96 MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx), 97 } 98 } 99 } 100