1 use std::fmt;
2 use std::future::Future;
3 use std::pin::Pin;
4 use std::task::{Context, Poll};
5 
6 use hyper::{client::connect::HttpConnector, service::Service, Uri};
7 use tokio::io::{AsyncRead, AsyncWrite};
8 use tokio_tls::TlsConnector;
9 
10 use crate::stream::MaybeHttpsStream;
11 
/// Type-erased, thread-safe error used for everything this connector reports.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
13 
14 /// A Connector for the `https` scheme.
15 #[derive(Clone)]
16 pub struct HttpsConnector<T> {
17     force_https: bool,
18     http: T,
19     tls: TlsConnector,
20 }
21 
22 impl HttpsConnector<HttpConnector> {
23     /// Construct a new HttpsConnector.
24     ///
25     /// This uses hyper's default `HttpConnector`, and default `TlsConnector`.
26     /// If you wish to use something besides the defaults, use `From::from`.
27     ///
28     /// # Note
29     ///
30     /// By default this connector will use plain HTTP if the URL provded uses
31     /// the HTTP scheme (eg: http://example.com/).
32     ///
33     /// If you would like to force the use of HTTPS then call https_only(true)
34     /// on the returned connector.
35     ///
36     /// # Panics
37     ///
38     /// This will panic if the underlying TLS context could not be created.
39     ///
40     /// To handle that error yourself, you can use the `HttpsConnector::from`
41     /// constructor after trying to make a `TlsConnector`.
new() -> Self42     pub fn new() -> Self {
43         native_tls::TlsConnector::new()
44             .map(|tls| HttpsConnector::new_(tls.into()))
45             .unwrap_or_else(|e| panic!("HttpsConnector::new() failure: {}", e))
46     }
47 
new_(tls: TlsConnector) -> Self48     fn new_(tls: TlsConnector) -> Self {
49         let mut http = HttpConnector::new();
50         http.enforce_http(false);
51         HttpsConnector::from((http, tls))
52     }
53 }
54 
55 impl<T: Default> Default for HttpsConnector<T> {
default() -> Self56     fn default() -> Self {
57         Self::new_with_connector(Default::default())
58     }
59 }
60 
61 impl<T> HttpsConnector<T> {
62     /// Force the use of HTTPS when connecting.
63     ///
64     /// If a URL is not `https` when connecting, an error is returned.
https_only(&mut self, enable: bool)65     pub fn https_only(&mut self, enable: bool) {
66         self.force_https = enable;
67     }
68 
69     /// With connector constructor
70     ///
new_with_connector(http: T) -> Self71     pub fn new_with_connector(http: T) -> Self {
72         native_tls::TlsConnector::new()
73             .map(|tls| HttpsConnector::from((http, tls.into())))
74             .unwrap_or_else(|e| panic!("HttpsConnector::new_with_connector(<connector>) failure: {}", e))
75     }
76 }
77 
78 impl<T> From<(T, TlsConnector)> for HttpsConnector<T> {
from(args: (T, TlsConnector)) -> HttpsConnector<T>79     fn from(args: (T, TlsConnector)) -> HttpsConnector<T> {
80         HttpsConnector {
81             force_https: false,
82             http: args.0,
83             tls: args.1,
84         }
85     }
86 }
87 
88 impl<T: fmt::Debug> fmt::Debug for HttpsConnector<T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result89     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
90         f.debug_struct("HttpsConnector")
91             .field("force_https", &self.force_https)
92             .field("http", &self.http)
93             .finish()
94     }
95 }
96 
97 impl<T> Service<Uri> for HttpsConnector<T>
98 where
99     T: Service<Uri>,
100     T::Response: AsyncRead + AsyncWrite + Send + Unpin,
101     T::Future: Send + 'static,
102     T::Error: Into<BoxError>,
103 {
104     type Response = MaybeHttpsStream<T::Response>;
105     type Error = BoxError;
106     type Future = HttpsConnecting<T::Response>;
107 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>108     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109         match self.http.poll_ready(cx) {
110             Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
111             Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
112             Poll::Pending => Poll::Pending,
113         }
114     }
115 
call(&mut self, dst: Uri) -> Self::Future116     fn call(&mut self, dst: Uri) -> Self::Future {
117         let is_https = dst.scheme_str() == Some("https");
118         // Early abort if HTTPS is forced but can't be used
119         if !is_https && self.force_https {
120             return err(ForceHttpsButUriNotHttps.into());
121         }
122 
123         let host = dst.host().unwrap_or("").trim_matches(|c| c == '[' || c == ']').to_owned();
124         let connecting = self.http.call(dst);
125         let tls = self.tls.clone();
126         let fut = async move {
127             let tcp = connecting.await.map_err(Into::into)?;
128             let maybe = if is_https {
129                 let tls = tls
130                     .connect(&host, tcp)
131                     .await?;
132                 MaybeHttpsStream::Https(tls)
133             } else {
134                 MaybeHttpsStream::Http(tcp)
135             };
136             Ok(maybe)
137         };
138         HttpsConnecting(Box::pin(fut))
139     }
140 }
141 
err<T>(e: BoxError) -> HttpsConnecting<T>142 fn err<T>(e: BoxError) -> HttpsConnecting<T> {
143     HttpsConnecting(Box::pin(async { Err(e) }))
144 }
145 
146 type BoxedFut<T> =
147     Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T>, BoxError>> + Send>>;
148 
149 /// A Future representing work to connect to a URL, and a TLS handshake.
150 pub struct HttpsConnecting<T>(BoxedFut<T>);
151 
152 impl<T: AsyncRead + AsyncWrite + Unpin> Future for HttpsConnecting<T> {
153     type Output = Result<MaybeHttpsStream<T>, BoxError>;
154 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>155     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156         Pin::new(&mut self.0).poll(cx)
157     }
158 }
159 
160 impl<T> fmt::Debug for HttpsConnecting<T> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result161     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
162         f.pad("HttpsConnecting")
163     }
164 }
165 
// ===== Custom Errors =====

/// Error produced when `https_only(true)` is set but the requested URI's
/// scheme is not `https`.
#[derive(Debug)]
struct ForceHttpsButUriNotHttps;

impl fmt::Display for ForceHttpsButUriNotHttps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "https required but URI was not https";
        f.write_str(MESSAGE)
    }
}

impl std::error::Error for ForceHttpsButUriNotHttps {}
178