1 use std::fs::File;
2 use std::future::Future;
3 use std::io::{self, BufReader, Cursor, Read};
4 use std::net::SocketAddr;
5 use std::path::{Path, PathBuf};
6 use std::pin::Pin;
7 use std::sync::Arc;
8 use std::task::{Context, Poll};
9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10 
11 use futures::ready;
12 use hyper::server::accept::Accept;
13 use hyper::server::conn::{AddrIncoming, AddrStream};
14 
15 use crate::transport::Transport;
16 use tokio_rustls::rustls::{
17     AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth,
18     RootCertStore, ServerConfig, TLSError,
19 };
20 
21 /// Represents errors that can occur building the TlsConfig
22 #[derive(Debug)]
23 pub(crate) enum TlsConfigError {
24     Io(io::Error),
25     /// An Error parsing the Certificate
26     CertParseError,
27     /// An Error parsing a Pkcs8 key
28     Pkcs8ParseError,
29     /// An Error parsing a Rsa key
30     RsaParseError,
31     /// An error from an empty key
32     EmptyKey,
33     /// An error from an invalid key
34     InvalidKey(TLSError),
35 }
36 
37 impl std::fmt::Display for TlsConfigError {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result38     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39         match self {
40             TlsConfigError::Io(err) => err.fmt(f),
41             TlsConfigError::CertParseError => write!(f, "certificate parse error"),
42             TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"),
43             TlsConfigError::RsaParseError => write!(f, "rsa parse error"),
44             TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
45             TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {}", err),
46         }
47     }
48 }
49 
50 impl std::error::Error for TlsConfigError {}
51 
/// Tls client authentication configuration.
///
/// The `Optional` and `Required` variants carry a reader that yields the PEM-encoded
/// trust-anchor certificates; the reader is only consumed when the server configuration
/// is built.
pub(crate) enum TlsClientAuth {
    /// No client auth.
    Off,
    /// Allow any anonymous or authenticated client; presented client certificates are
    /// checked against the trust anchors read from the contained reader.
    Optional(Box<dyn Read + Send + Sync>),
    /// Allow only authenticated clients whose certificates are checked against the trust
    /// anchors read from the contained reader.
    Required(Box<dyn Read + Send + Sync>),
}
61 
/// Builder to set the configuration for the Tls server.
///
/// The certificate, key and trust-anchor inputs are stored as readers and are only read
/// when `build` is called.
pub(crate) struct TlsConfigBuilder {
    /// Source of the PEM-encoded certificate chain; `io::empty()` until `cert` or
    /// `cert_path` is called.
    cert: Box<dyn Read + Send + Sync>,
    /// Source of the PEM-encoded private key (PKCS#8 or RSA); `io::empty()` until `key`
    /// or `key_path` is called.
    key: Box<dyn Read + Send + Sync>,
    /// Client authentication mode, together with its trust-anchor source when enabled.
    client_auth: TlsClientAuth,
    /// DER-encoded OCSP response passed to rustls with the certificate; empty by default.
    ocsp_resp: Vec<u8>,
}
69 
70 impl std::fmt::Debug for TlsConfigBuilder {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result71     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
72         f.debug_struct("TlsConfigBuilder").finish()
73     }
74 }
75 
76 impl TlsConfigBuilder {
77     /// Create a new TlsConfigBuilder
new() -> TlsConfigBuilder78     pub(crate) fn new() -> TlsConfigBuilder {
79         TlsConfigBuilder {
80             key: Box::new(io::empty()),
81             cert: Box::new(io::empty()),
82             client_auth: TlsClientAuth::Off,
83             ocsp_resp: Vec::new(),
84         }
85     }
86 
87     /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open
key_path(mut self, path: impl AsRef<Path>) -> Self88     pub(crate) fn key_path(mut self, path: impl AsRef<Path>) -> Self {
89         self.key = Box::new(LazyFile {
90             path: path.as_ref().into(),
91             file: None,
92         });
93         self
94     }
95 
96     /// sets the Tls key via bytes slice
key(mut self, key: &[u8]) -> Self97     pub(crate) fn key(mut self, key: &[u8]) -> Self {
98         self.key = Box::new(Cursor::new(Vec::from(key)));
99         self
100     }
101 
102     /// Specify the file path for the TLS certificate to use.
cert_path(mut self, path: impl AsRef<Path>) -> Self103     pub(crate) fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
104         self.cert = Box::new(LazyFile {
105             path: path.as_ref().into(),
106             file: None,
107         });
108         self
109     }
110 
111     /// sets the Tls certificate via bytes slice
cert(mut self, cert: &[u8]) -> Self112     pub(crate) fn cert(mut self, cert: &[u8]) -> Self {
113         self.cert = Box::new(Cursor::new(Vec::from(cert)));
114         self
115     }
116 
117     /// Sets the trust anchor for optional Tls client authentication via file path.
118     ///
119     /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
120     /// of the `client_auth_` methods, then client authentication is disabled by default.
client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self121     pub(crate) fn client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self {
122         let file = Box::new(LazyFile {
123             path: path.as_ref().into(),
124             file: None,
125         });
126         self.client_auth = TlsClientAuth::Optional(file);
127         self
128     }
129 
130     /// Sets the trust anchor for optional Tls client authentication via bytes slice.
131     ///
132     /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
133     /// of the `client_auth_` methods, then client authentication is disabled by default.
client_auth_optional(mut self, trust_anchor: &[u8]) -> Self134     pub(crate) fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self {
135         let cursor = Box::new(Cursor::new(Vec::from(trust_anchor)));
136         self.client_auth = TlsClientAuth::Optional(cursor);
137         self
138     }
139 
140     /// Sets the trust anchor for required Tls client authentication via file path.
141     ///
142     /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
143     /// `client_auth_` methods, then client authentication is disabled by default.
client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self144     pub(crate) fn client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self {
145         let file = Box::new(LazyFile {
146             path: path.as_ref().into(),
147             file: None,
148         });
149         self.client_auth = TlsClientAuth::Required(file);
150         self
151     }
152 
153     /// Sets the trust anchor for required Tls client authentication via bytes slice.
154     ///
155     /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
156     /// `client_auth_` methods, then client authentication is disabled by default.
client_auth_required(mut self, trust_anchor: &[u8]) -> Self157     pub(crate) fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self {
158         let cursor = Box::new(Cursor::new(Vec::from(trust_anchor)));
159         self.client_auth = TlsClientAuth::Required(cursor);
160         self
161     }
162 
163     /// sets the DER-encoded OCSP response
ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self164     pub(crate) fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self {
165         self.ocsp_resp = Vec::from(ocsp_resp);
166         self
167     }
168 
build(mut self) -> Result<ServerConfig, TlsConfigError>169     pub(crate) fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
170         let mut cert_rdr = BufReader::new(self.cert);
171         let cert = tokio_rustls::rustls::internal::pemfile::certs(&mut cert_rdr)
172             .map_err(|()| TlsConfigError::CertParseError)?;
173 
174         let key = {
175             // convert it to Vec<u8> to allow reading it again if key is RSA
176             let mut key_vec = Vec::new();
177             self.key
178                 .read_to_end(&mut key_vec)
179                 .map_err(TlsConfigError::Io)?;
180 
181             if key_vec.is_empty() {
182                 return Err(TlsConfigError::EmptyKey);
183             }
184 
185             let mut pkcs8 = tokio_rustls::rustls::internal::pemfile::pkcs8_private_keys(
186                 &mut key_vec.as_slice(),
187             )
188             .map_err(|()| TlsConfigError::Pkcs8ParseError)?;
189 
190             if !pkcs8.is_empty() {
191                 pkcs8.remove(0)
192             } else {
193                 let mut rsa = tokio_rustls::rustls::internal::pemfile::rsa_private_keys(
194                     &mut key_vec.as_slice(),
195                 )
196                 .map_err(|()| TlsConfigError::RsaParseError)?;
197 
198                 if !rsa.is_empty() {
199                     rsa.remove(0)
200                 } else {
201                     return Err(TlsConfigError::EmptyKey);
202                 }
203             }
204         };
205 
206         fn read_trust_anchor(
207             trust_anchor: Box<dyn Read + Send + Sync>,
208         ) -> Result<RootCertStore, TlsConfigError> {
209             let mut reader = BufReader::new(trust_anchor);
210             let mut store = RootCertStore::empty();
211             if let Ok((0, _)) | Err(()) = store.add_pem_file(&mut reader) {
212                 Err(TlsConfigError::CertParseError)
213             } else {
214                 Ok(store)
215             }
216         }
217 
218         let client_auth = match self.client_auth {
219             TlsClientAuth::Off => NoClientAuth::new(),
220             TlsClientAuth::Optional(trust_anchor) => {
221                 AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
222             }
223             TlsClientAuth::Required(trust_anchor) => {
224                 AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
225             }
226         };
227 
228         let mut config = ServerConfig::new(client_auth);
229         config
230             .set_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new())
231             .map_err(|err| TlsConfigError::InvalidKey(err))?;
232         config.set_protocols(&["h2".into(), "http/1.1".into()]);
233         Ok(config)
234     }
235 }
236 
/// A reader over a file that is only opened on the first read.
///
/// This lets builder setters accept a path without touching the filesystem; any open or
/// read error surfaces when the contents are actually consumed.
struct LazyFile {
    /// Path of the file to open on first read.
    path: PathBuf,
    /// The opened file; `None` until an open succeeds.
    file: Option<File>,
}

impl LazyFile {
    /// Reads from the file, opening it first if it has not been opened yet.
    ///
    /// If opening fails, `file` stays `None`, so a later call retries the open.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let file = match self.file {
            Some(ref mut file) => file,
            None => {
                let opened = File::open(&self.path)?;
                self.file.get_or_insert(opened)
            }
        };
        file.read(buf)
    }
}

impl Read for LazyFile {
    /// Delegates to `lazy_read`, rewrapping any error with the file path while keeping
    /// the original `io::ErrorKind`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.lazy_read(buf).map_err(|err| {
            let kind = err.kind();
            io::Error::new(
                kind,
                format!("error reading file ({:?}): {}", self.path.display(), err),
            )
        })
    }
}
263 
264 impl Transport for TlsStream {
remote_addr(&self) -> Option<SocketAddr>265     fn remote_addr(&self) -> Option<SocketAddr> {
266         Some(self.remote_addr)
267     }
268 }
269 
270 enum State {
271     Handshaking(tokio_rustls::Accept<AddrStream>),
272     Streaming(tokio_rustls::server::TlsStream<AddrStream>),
273 }
274 
275 // tokio_rustls::server::TlsStream doesn't expose constructor methods,
276 // so we have to TlsAcceptor::accept and handshake to have access to it
277 // TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
278 pub(crate) struct TlsStream {
279     state: State,
280     remote_addr: SocketAddr,
281 }
282 
283 impl TlsStream {
new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream284     fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
285         let remote_addr = stream.remote_addr();
286         let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
287         TlsStream {
288             state: State::Handshaking(accept),
289             remote_addr,
290         }
291     }
292 }
293 
294 impl AsyncRead for TlsStream {
poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf, ) -> Poll<io::Result<()>>295     fn poll_read(
296         self: Pin<&mut Self>,
297         cx: &mut Context,
298         buf: &mut ReadBuf,
299     ) -> Poll<io::Result<()>> {
300         let pin = self.get_mut();
301         match pin.state {
302             State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
303                 Ok(mut stream) => {
304                     let result = Pin::new(&mut stream).poll_read(cx, buf);
305                     pin.state = State::Streaming(stream);
306                     result
307                 }
308                 Err(err) => Poll::Ready(Err(err)),
309             },
310             State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
311         }
312     }
313 }
314 
315 impl AsyncWrite for TlsStream {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>316     fn poll_write(
317         self: Pin<&mut Self>,
318         cx: &mut Context<'_>,
319         buf: &[u8],
320     ) -> Poll<io::Result<usize>> {
321         let pin = self.get_mut();
322         match pin.state {
323             State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
324                 Ok(mut stream) => {
325                     let result = Pin::new(&mut stream).poll_write(cx, buf);
326                     pin.state = State::Streaming(stream);
327                     result
328                 }
329                 Err(err) => Poll::Ready(Err(err)),
330             },
331             State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
332         }
333     }
334 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>335     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
336         match self.state {
337             State::Handshaking(_) => Poll::Ready(Ok(())),
338             State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
339         }
340     }
341 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>342     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
343         match self.state {
344             State::Handshaking(_) => Poll::Ready(Ok(())),
345             State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
346         }
347     }
348 }
349 
/// Accepts connections from an `AddrIncoming` and wraps each one in a `TlsStream`.
pub(crate) struct TlsAcceptor {
    /// Server configuration shared (via `Arc`) with every accepted `TlsStream`.
    config: Arc<ServerConfig>,
    /// Source of the plain, not-yet-encrypted connections.
    incoming: AddrIncoming,
}
354 
355 impl TlsAcceptor {
new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor356     pub(crate) fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
357         TlsAcceptor {
358             config: Arc::new(config),
359             incoming,
360         }
361     }
362 }
363 
364 impl Accept for TlsAcceptor {
365     type Conn = TlsStream;
366     type Error = io::Error;
367 
poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>368     fn poll_accept(
369         self: Pin<&mut Self>,
370         cx: &mut Context<'_>,
371     ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
372         let pin = self.get_mut();
373         match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
374             Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
375             Some(Err(e)) => Poll::Ready(Some(Err(e))),
376             None => Poll::Ready(None),
377         }
378     }
379 }
380 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn file_cert_key() {
        TlsConfigBuilder::new()
            .key_path("examples/tls/key.rsa")
            .cert_path("examples/tls/cert.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key() {
        let key = include_str!("../examples/tls/key.rsa");
        let cert = include_str!("../examples/tls/cert.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    /// A builder whose key was never set reads an empty key and reports `EmptyKey`.
    #[test]
    fn missing_key_is_empty_key_error() {
        let cert = include_str!("../examples/tls/cert.pem");

        let result = TlsConfigBuilder::new().cert(cert.as_bytes()).build();
        assert!(matches!(result, Err(TlsConfigError::EmptyKey)));
    }

    /// A key path that cannot be opened surfaces from `build` as an I/O error.
    #[test]
    fn missing_key_file_is_io_error() {
        let result = TlsConfigBuilder::new()
            .key_path("examples/tls/does-not-exist.rsa")
            .cert_path("examples/tls/cert.pem")
            .build();
        assert!(matches!(result, Err(TlsConfigError::Io(_))));
    }
}
406