1 // Copied from hyperium/hyper-tls#62e3376/src/stream.rs
2 use std::fmt;
3 use std::io;
4 use std::pin::Pin;
5 use std::task::{Context, Poll};
6 
7 use hyper::client::connect::{Connected, Connection};
8 
9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10 use tokio_rustls::client::TlsStream;
11 
/// A client connection that may or may not be protected with TLS.
///
/// The I/O traits implemented for this type dispatch to whichever variant is
/// active, so callers can treat plain and encrypted connections uniformly.
pub enum MaybeHttpsStream<T> {
    /// A stream over plain text.
    Http(T),
    /// A stream protected with TLS: a `tokio_rustls` client-side `TlsStream`
    /// layered over the same transport type `T`.
    Https(TlsStream<T>),
}
19 
20 impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsStream<T> {
connected(&self) -> Connected21     fn connected(&self) -> Connected {
22         match self {
23             MaybeHttpsStream::Http(s) => s.connected(),
24             MaybeHttpsStream::Https(s) => {
25                 let (tcp, tls) = s.get_ref();
26                 if tls.alpn_protocol() == Some(b"h2") {
27                     tcp.connected().negotiated_h2()
28                 } else {
29                     tcp.connected()
30                 }
31             }
32         }
33     }
34 }
35 
36 impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result37     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38         match *self {
39             MaybeHttpsStream::Http(..) => f.pad("Http(..)"),
40             MaybeHttpsStream::Https(..) => f.pad("Https(..)"),
41         }
42     }
43 }
44 
45 impl<T> From<T> for MaybeHttpsStream<T> {
from(inner: T) -> Self46     fn from(inner: T) -> Self {
47         MaybeHttpsStream::Http(inner)
48     }
49 }
50 
51 impl<T> From<TlsStream<T>> for MaybeHttpsStream<T> {
from(inner: TlsStream<T>) -> Self52     fn from(inner: TlsStream<T>) -> Self {
53         MaybeHttpsStream::Https(inner)
54     }
55 }
56 
57 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeHttpsStream<T> {
58     #[inline]
poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll<Result<(), io::Error>>59     fn poll_read(
60         self: Pin<&mut Self>,
61         cx: &mut Context,
62         buf: &mut ReadBuf<'_>,
63     ) -> Poll<Result<(), io::Error>> {
64         match Pin::get_mut(self) {
65             MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf),
66             MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf),
67         }
68     }
69 }
70 
71 impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for MaybeHttpsStream<T> {
72     #[inline]
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>73     fn poll_write(
74         self: Pin<&mut Self>,
75         cx: &mut Context<'_>,
76         buf: &[u8],
77     ) -> Poll<Result<usize, io::Error>> {
78         match Pin::get_mut(self) {
79             MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf),
80             MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf),
81         }
82     }
83 
84     #[inline]
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>85     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
86         match Pin::get_mut(self) {
87             MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx),
88             MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx),
89         }
90     }
91 
92     #[inline]
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>93     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
94         match Pin::get_mut(self) {
95             MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx),
96             MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx),
97         }
98     }
99 }
100