1 //! Hyper SSL support via OpenSSL.
2 #![warn(missing_docs)]
3 #![doc(html_root_url = "https://docs.rs/hyper-openssl/0.8")]
4 
5 use crate::cache::{SessionCache, SessionKey};
6 use antidote::Mutex;
7 use bytes::{Buf, BufMut};
8 use http::uri::Scheme;
9 use hyper::client::connect::{Connected, Connection};
10 #[cfg(feature = "runtime")]
11 use hyper::client::HttpConnector;
12 use hyper::service::Service;
13 use hyper::Uri;
14 use once_cell::sync::OnceCell;
15 use openssl::error::ErrorStack;
16 use openssl::ex_data::Index;
17 #[cfg(feature = "runtime")]
18 use openssl::ssl::SslMethod;
19 use openssl::ssl::{
20     ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslSessionCacheMode,
21 };
22 use std::error::Error;
23 use std::fmt::Debug;
24 use std::future::Future;
25 use std::io;
26 use std::mem::MaybeUninit;
27 use std::pin::Pin;
28 use std::sync::Arc;
29 use std::task::{Context, Poll};
30 use tokio::io::{AsyncRead, AsyncWrite};
31 use tokio_openssl::SslStream;
32 
33 mod cache;
34 #[cfg(test)]
35 mod test;
36 
key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack>37 fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
38     static IDX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();
39     IDX.get_or_try_init(|| Ssl::new_ex_index()).map(|v| *v)
40 }
41 
42 #[derive(Clone)]
43 struct Inner {
44     ssl: SslConnector,
45     cache: Arc<Mutex<SessionCache>>,
46     callback: Option<
47         Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>,
48     >,
49 }
50 
51 impl Inner {
setup_ssl(&self, uri: &Uri, host: &str) -> Result<ConnectConfiguration, ErrorStack>52     fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<ConnectConfiguration, ErrorStack> {
53         let mut conf = self.ssl.configure()?;
54 
55         if let Some(ref callback) = self.callback {
56             callback(&mut conf, uri)?;
57         }
58 
59         let key = SessionKey {
60             host: host.to_string(),
61             port: uri.port_u16().unwrap_or(443),
62         };
63 
64         if let Some(session) = self.cache.lock().get(&key) {
65             unsafe {
66                 conf.set_session(&session)?;
67             }
68         }
69 
70         let idx = key_index()?;
71         conf.set_ex_data(idx, key);
72 
73         Ok(conf)
74     }
75 }
76 
/// A Connector using OpenSSL to support `http` and `https` schemes.
///
/// The wrapped connector `T` establishes the transport connection. URIs whose scheme is
/// `https` are then wrapped in a TLS session; all other URIs are returned unwrapped.
/// Clones share the same TLS connector, session cache and configuration callback.
#[derive(Clone)]
pub struct HttpsConnector<T> {
    /// Underlying connector used to open the transport connection.
    http: T,
    /// Shared TLS state: OpenSSL connector, session cache and optional callback.
    inner: Inner,
}
83 
#[cfg(feature = "runtime")]
impl HttpsConnector<HttpConnector> {
    /// Creates a new `HttpsConnector` using default settings.
    ///
    /// The Hyper `HttpConnector` is used to perform the TCP socket connection. ALPN is configured to support both
    /// HTTP/2 and HTTP/1.1.
    ///
    /// Requires the `runtime` Cargo feature.
    ///
    /// # Errors
    ///
    /// Returns an error if OpenSSL fails to create the connector builder or to configure ALPN.
    pub fn new() -> Result<HttpsConnector<HttpConnector>, ErrorStack> {
        let mut http = HttpConnector::new();
        // Let `https` URIs through to this connector; TLS is layered on top afterwards.
        http.enforce_http(false);

        // `ssl` is only mutated when ALPN support is compiled in (OpenSSL 1.0.2+), so silence
        // `unused_mut` for older builds instead of using a self-assignment hack.
        #[cfg_attr(not(ossl102), allow(unused_mut))]
        let mut ssl = SslConnector::builder(SslMethod::tls())?;

        #[cfg(ossl102)]
        ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?;

        HttpsConnector::with_connector(http, ssl)
    }
}
105 
106 impl<S, T> HttpsConnector<S>
107 where
108     S: Service<Uri, Response = T> + Send,
109     S::Error: Into<Box<dyn Error + Send + Sync>>,
110     S::Future: Unpin + Send + 'static,
111     T: AsyncRead + AsyncWrite + Connection + Unpin + Debug + Sync + Send + 'static,
112 {
113     /// Creates a new `HttpsConnector`.
114     ///
115     /// The session cache configuration of `ssl` will be overwritten.
with_connector( http: S, mut ssl: SslConnectorBuilder, ) -> Result<HttpsConnector<S>, ErrorStack>116     pub fn with_connector(
117         http: S,
118         mut ssl: SslConnectorBuilder,
119     ) -> Result<HttpsConnector<S>, ErrorStack> {
120         let cache = Arc::new(Mutex::new(SessionCache::new()));
121 
122         ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT);
123 
124         ssl.set_new_session_callback({
125             let cache = cache.clone();
126             move |ssl, session| {
127                 if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
128                     cache.lock().insert(key.clone(), session);
129                 }
130             }
131         });
132 
133         ssl.set_remove_session_callback({
134             let cache = cache.clone();
135             move |_, session| cache.lock().remove(session)
136         });
137 
138         Ok(HttpsConnector {
139             http,
140             inner: Inner {
141                 ssl: ssl.build(),
142                 cache,
143                 callback: None,
144             },
145         })
146     }
147 
148     /// Registers a callback which can customize the configuration of each connection.
set_callback<F>(&mut self, callback: F) where F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,149     pub fn set_callback<F>(&mut self, callback: F)
150     where
151         F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
152     {
153         self.inner.callback = Some(Arc::new(callback));
154     }
155 }
156 
157 impl<S> Service<Uri> for HttpsConnector<S>
158 where
159     S: Service<Uri> + Send,
160     S::Error: Into<Box<dyn Error + Send + Sync>>,
161     S::Future: Unpin + Send + 'static,
162     S::Response: AsyncRead + AsyncWrite + Connection + Unpin + Debug + Sync + Send + 'static,
163 {
164     type Response = MaybeHttpsStream<S::Response>;
165     type Error = Box<dyn Error + Sync + Send>;
166     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
167 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>168     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169         self.http.poll_ready(cx).map_err(Into::into)
170     }
171 
call(&mut self, uri: Uri) -> Self::Future172     fn call(&mut self, uri: Uri) -> Self::Future {
173         let tls_setup = if uri.scheme() == Some(&Scheme::HTTPS) {
174             Some((self.inner.clone(), uri.clone()))
175         } else {
176             None
177         };
178 
179         let connect = self.http.call(uri);
180 
181         let f = async {
182             let conn = connect.await.map_err(Into::into)?;
183 
184             let (inner, uri) = match tls_setup {
185                 Some((inner, uri)) => (inner, uri),
186                 None => return Ok(MaybeHttpsStream::Http(conn)),
187             };
188 
189             let host = uri.host().ok_or_else(|| "URI missing host")?;
190 
191             let config = inner.setup_ssl(&uri, host)?;
192             let stream = tokio_openssl::connect(config, host, conn).await?;
193 
194             Ok(MaybeHttpsStream::Https(stream))
195         };
196 
197         Box::pin(f)
198     }
199 }
200 
/// A stream which may be wrapped with TLS.
///
/// `HttpsConnector` yields `Http` for URIs whose scheme is not `https`, and `Https` once a
/// TLS handshake has completed. The I/O trait impls forward to whichever variant is active.
pub enum MaybeHttpsStream<T> {
    /// A raw HTTP stream.
    Http(T),
    /// An SSL-wrapped HTTP stream.
    Https(SslStream<T>),
}
208 
209 impl<T> AsyncRead for MaybeHttpsStream<T>
210 where
211     T: AsyncRead + AsyncWrite + Unpin,
212 {
prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool213     unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
214         match &*self {
215             MaybeHttpsStream::Http(s) => s.prepare_uninitialized_buffer(buf),
216             MaybeHttpsStream::Https(s) => s.prepare_uninitialized_buffer(buf),
217         }
218     }
219 
poll_read( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>220     fn poll_read(
221         mut self: Pin<&mut Self>,
222         ctx: &mut Context<'_>,
223         buf: &mut [u8],
224     ) -> Poll<io::Result<usize>> {
225         match &mut *self {
226             MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(ctx, buf),
227             MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(ctx, buf),
228         }
229     }
230 
poll_read_buf<B>( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut B, ) -> Poll<io::Result<usize>> where B: BufMut,231     fn poll_read_buf<B>(
232         mut self: Pin<&mut Self>,
233         ctx: &mut Context<'_>,
234         buf: &mut B,
235     ) -> Poll<io::Result<usize>>
236     where
237         B: BufMut,
238     {
239         match &mut *self {
240             MaybeHttpsStream::Http(s) => Pin::new(s).poll_read_buf(ctx, buf),
241             MaybeHttpsStream::Https(s) => Pin::new(s).poll_read_buf(ctx, buf),
242         }
243     }
244 }
245 
246 impl<T> AsyncWrite for MaybeHttpsStream<T>
247 where
248     T: AsyncRead + AsyncWrite + Unpin,
249 {
poll_write( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>250     fn poll_write(
251         mut self: Pin<&mut Self>,
252         ctx: &mut Context<'_>,
253         buf: &[u8],
254     ) -> Poll<io::Result<usize>> {
255         match &mut *self {
256             MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(ctx, buf),
257             MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(ctx, buf),
258         }
259     }
260 
poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>>261     fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
262         match &mut *self {
263             MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(ctx),
264             MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(ctx),
265         }
266     }
267 
poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>>268     fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
269         match &mut *self {
270             MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(ctx),
271             MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(ctx),
272         }
273     }
274 
poll_write_buf<B>( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut B, ) -> Poll<io::Result<usize>> where B: Buf,275     fn poll_write_buf<B>(
276         mut self: Pin<&mut Self>,
277         ctx: &mut Context<'_>,
278         buf: &mut B,
279     ) -> Poll<io::Result<usize>>
280     where
281         B: Buf,
282     {
283         match &mut *self {
284             MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_buf(ctx, buf),
285             MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_buf(ctx, buf),
286         }
287     }
288 }
289 
290 impl<T> Connection for MaybeHttpsStream<T>
291 where
292     T: Connection,
293 {
connected(&self) -> Connected294     fn connected(&self) -> Connected {
295         match self {
296             MaybeHttpsStream::Http(s) => s.connected(),
297             MaybeHttpsStream::Https(s) => {
298                 let mut connected = s.get_ref().connected();
299                 // Avoid unused_mut warnings on OpenSSL 1.0.1
300                 connected = connected;
301                 #[cfg(ossl102)]
302                 {
303                     if s.ssl().selected_alpn_protocol() == Some(b"h2") {
304                         connected = connected.negotiated_h2();
305                     }
306                 }
307                 connected
308             }
309         }
310     }
311 }
312