1 use std::fs::File; 2 use std::future::Future; 3 use std::io::{self, BufReader, Cursor, Read}; 4 use std::net::SocketAddr; 5 use std::path::{Path, PathBuf}; 6 use std::pin::Pin; 7 use std::sync::Arc; 8 use std::task::{Context, Poll}; 9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 10 11 use futures::ready; 12 use hyper::server::accept::Accept; 13 use hyper::server::conn::{AddrIncoming, AddrStream}; 14 15 use crate::transport::Transport; 16 use tokio_rustls::rustls::{ 17 AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth, 18 RootCertStore, ServerConfig, TLSError, 19 }; 20 21 /// Represents errors that can occur building the TlsConfig 22 #[derive(Debug)] 23 pub(crate) enum TlsConfigError { 24 Io(io::Error), 25 /// An Error parsing the Certificate 26 CertParseError, 27 /// An Error parsing a Pkcs8 key 28 Pkcs8ParseError, 29 /// An Error parsing a Rsa key 30 RsaParseError, 31 /// An error from an empty key 32 EmptyKey, 33 /// An error from an invalid key 34 InvalidKey(TLSError), 35 } 36 37 impl std::fmt::Display for TlsConfigError { fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 39 match self { 40 TlsConfigError::Io(err) => err.fmt(f), 41 TlsConfigError::CertParseError => write!(f, "certificate parse error"), 42 TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"), 43 TlsConfigError::RsaParseError => write!(f, "rsa parse error"), 44 TlsConfigError::EmptyKey => write!(f, "key contains no private key"), 45 TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {}", err), 46 } 47 } 48 } 49 50 impl std::error::Error for TlsConfigError {} 51 52 /// Tls client authentication configuration. 53 pub(crate) enum TlsClientAuth { 54 /// No client auth. 55 Off, 56 /// Allow any anonymous or authenticated client. 57 Optional(Box<dyn Read + Send + Sync>), 58 /// Allow any authenticated client. 59 Required(Box<dyn Read + Send + Sync>), 60 } 61 62 /// Builder to set the configuration for the Tls server. 63 pub(crate) struct TlsConfigBuilder { 64 cert: Box<dyn Read + Send + Sync>, 65 key: Box<dyn Read + Send + Sync>, 66 client_auth: TlsClientAuth, 67 ocsp_resp: Vec<u8>, 68 } 69 70 impl std::fmt::Debug for TlsConfigBuilder { fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result71 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { 72 f.debug_struct("TlsConfigBuilder").finish() 73 } 74 } 75 76 impl TlsConfigBuilder { 77 /// Create a new TlsConfigBuilder new() -> TlsConfigBuilder78 pub(crate) fn new() -> TlsConfigBuilder { 79 TlsConfigBuilder { 80 key: Box::new(io::empty()), 81 cert: Box::new(io::empty()), 82 client_auth: TlsClientAuth::Off, 83 ocsp_resp: Vec::new(), 84 } 85 } 86 87 /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open key_path(mut self, path: impl AsRef<Path>) -> Self88 pub(crate) fn key_path(mut self, path: impl AsRef<Path>) -> Self { 89 self.key = Box::new(LazyFile { 90 path: path.as_ref().into(), 91 file: None, 92 }); 93 self 94 } 95 96 /// sets the Tls key via bytes slice key(mut self, key: &[u8]) -> Self97 pub(crate) fn key(mut self, key: &[u8]) -> Self { 98 self.key = Box::new(Cursor::new(Vec::from(key))); 99 self 100 } 101 102 /// Specify the file path for the TLS certificate to use. cert_path(mut self, path: impl AsRef<Path>) -> Self103 pub(crate) fn cert_path(mut self, path: impl AsRef<Path>) -> Self { 104 self.cert = Box::new(LazyFile { 105 path: path.as_ref().into(), 106 file: None, 107 }); 108 self 109 } 110 111 /// sets the Tls certificate via bytes slice cert(mut self, cert: &[u8]) -> Self112 pub(crate) fn cert(mut self, cert: &[u8]) -> Self { 113 self.cert = Box::new(Cursor::new(Vec::from(cert))); 114 self 115 } 116 117 /// Sets the trust anchor for optional Tls client authentication via file path. 118 /// 119 /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any 120 /// of the `client_auth_` methods, then client authentication is disabled by default. client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self121 pub(crate) fn client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self { 122 let file = Box::new(LazyFile { 123 path: path.as_ref().into(), 124 file: None, 125 }); 126 self.client_auth = TlsClientAuth::Optional(file); 127 self 128 } 129 130 /// Sets the trust anchor for optional Tls client authentication via bytes slice. 131 /// 132 /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any 133 /// of the `client_auth_` methods, then client authentication is disabled by default. client_auth_optional(mut self, trust_anchor: &[u8]) -> Self134 pub(crate) fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self { 135 let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); 136 self.client_auth = TlsClientAuth::Optional(cursor); 137 self 138 } 139 140 /// Sets the trust anchor for required Tls client authentication via file path. 141 /// 142 /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the 143 /// `client_auth_` methods, then client authentication is disabled by default. client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self144 pub(crate) fn client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self { 145 let file = Box::new(LazyFile { 146 path: path.as_ref().into(), 147 file: None, 148 }); 149 self.client_auth = TlsClientAuth::Required(file); 150 self 151 } 152 153 /// Sets the trust anchor for required Tls client authentication via bytes slice. 154 /// 155 /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the 156 /// `client_auth_` methods, then client authentication is disabled by default. client_auth_required(mut self, trust_anchor: &[u8]) -> Self157 pub(crate) fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self { 158 let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); 159 self.client_auth = TlsClientAuth::Required(cursor); 160 self 161 } 162 163 /// sets the DER-encoded OCSP response ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self164 pub(crate) fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self { 165 self.ocsp_resp = Vec::from(ocsp_resp); 166 self 167 } 168 build(mut self) -> Result<ServerConfig, TlsConfigError>169 pub(crate) fn build(mut self) -> Result<ServerConfig, TlsConfigError> { 170 let mut cert_rdr = BufReader::new(self.cert); 171 let cert = tokio_rustls::rustls::internal::pemfile::certs(&mut cert_rdr) 172 .map_err(|()| TlsConfigError::CertParseError)?; 173 174 let key = { 175 // convert it to Vec<u8> to allow reading it again if key is RSA 176 let mut key_vec = Vec::new(); 177 self.key 178 .read_to_end(&mut key_vec) 179 .map_err(TlsConfigError::Io)?; 180 181 if key_vec.is_empty() { 182 return Err(TlsConfigError::EmptyKey); 183 } 184 185 let mut pkcs8 = tokio_rustls::rustls::internal::pemfile::pkcs8_private_keys( 186 &mut key_vec.as_slice(), 187 ) 188 .map_err(|()| TlsConfigError::Pkcs8ParseError)?; 189 190 if !pkcs8.is_empty() { 191 pkcs8.remove(0) 192 } else { 193 let mut rsa = tokio_rustls::rustls::internal::pemfile::rsa_private_keys( 194 &mut key_vec.as_slice(), 195 ) 196 .map_err(|()| TlsConfigError::RsaParseError)?; 197 198 if !rsa.is_empty() { 199 rsa.remove(0) 200 } else { 201 return Err(TlsConfigError::EmptyKey); 202 } 203 } 204 }; 205 206 fn read_trust_anchor( 207 trust_anchor: Box<dyn Read + Send + Sync>, 208 ) -> Result<RootCertStore, TlsConfigError> { 209 let mut reader = BufReader::new(trust_anchor); 210 let mut store = RootCertStore::empty(); 211 if let Ok((0, _)) | Err(()) = store.add_pem_file(&mut reader) { 212 Err(TlsConfigError::CertParseError) 213 } else { 214 Ok(store) 215 } 216 } 217 218 let client_auth = match self.client_auth { 219 TlsClientAuth::Off => NoClientAuth::new(), 220 TlsClientAuth::Optional(trust_anchor) => { 221 AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) 222 } 223 TlsClientAuth::Required(trust_anchor) => { 224 AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) 225 } 226 }; 227 228 let mut config = ServerConfig::new(client_auth); 229 config 230 .set_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new()) 231 .map_err(|err| TlsConfigError::InvalidKey(err))?; 232 config.set_protocols(&["h2".into(), "http/1.1".into()]); 233 Ok(config) 234 } 235 } 236 237 struct LazyFile { 238 path: PathBuf, 239 file: Option<File>, 240 } 241 242 impl LazyFile { lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize>243 fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 244 if self.file.is_none() { 245 self.file = Some(File::open(&self.path)?); 246 } 247 248 self.file.as_mut().unwrap().read(buf) 249 } 250 } 251 252 impl Read for LazyFile { read(&mut self, buf: &mut [u8]) -> io::Result<usize>253 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 254 self.lazy_read(buf).map_err(|err| { 255 let kind = err.kind(); 256 io::Error::new( 257 kind, 258 format!("error reading file ({:?}): {}", self.path.display(), err), 259 ) 260 }) 261 } 262 } 263 264 impl Transport for TlsStream { remote_addr(&self) -> Option<SocketAddr>265 fn remote_addr(&self) -> Option<SocketAddr> { 266 Some(self.remote_addr) 267 } 268 } 269 270 enum State { 271 Handshaking(tokio_rustls::Accept<AddrStream>), 272 Streaming(tokio_rustls::server::TlsStream<AddrStream>), 273 } 274 275 // tokio_rustls::server::TlsStream doesn't expose constructor methods, 276 // so we have to TlsAcceptor::accept and handshake to have access to it 277 // TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first 278 pub(crate) struct TlsStream { 279 state: State, 280 remote_addr: SocketAddr, 281 } 282 283 impl TlsStream { new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream284 fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream { 285 let remote_addr = stream.remote_addr(); 286 let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); 287 TlsStream { 288 state: State::Handshaking(accept), 289 remote_addr, 290 } 291 } 292 } 293 294 impl AsyncRead for TlsStream { poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf, ) -> Poll<io::Result<()>>295 fn poll_read( 296 self: Pin<&mut Self>, 297 cx: &mut Context, 298 buf: &mut ReadBuf, 299 ) -> Poll<io::Result<()>> { 300 let pin = self.get_mut(); 301 match pin.state { 302 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { 303 Ok(mut stream) => { 304 let result = Pin::new(&mut stream).poll_read(cx, buf); 305 pin.state = State::Streaming(stream); 306 result 307 } 308 Err(err) => Poll::Ready(Err(err)), 309 }, 310 State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), 311 } 312 } 313 } 314 315 impl AsyncWrite for TlsStream { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>316 fn poll_write( 317 self: Pin<&mut Self>, 318 cx: &mut Context<'_>, 319 buf: &[u8], 320 ) -> Poll<io::Result<usize>> { 321 let pin = self.get_mut(); 322 match pin.state { 323 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { 324 Ok(mut stream) => { 325 let result = Pin::new(&mut stream).poll_write(cx, buf); 326 pin.state = State::Streaming(stream); 327 result 328 } 329 Err(err) => Poll::Ready(Err(err)), 330 }, 331 State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), 332 } 333 } 334 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>335 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 336 match self.state { 337 State::Handshaking(_) => Poll::Ready(Ok(())), 338 State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), 339 } 340 } 341 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>342 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 343 match self.state { 344 State::Handshaking(_) => Poll::Ready(Ok(())), 345 State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), 346 } 347 } 348 } 349 350 pub(crate) struct TlsAcceptor { 351 config: Arc<ServerConfig>, 352 incoming: AddrIncoming, 353 } 354 355 impl TlsAcceptor { new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor356 pub(crate) fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor { 357 TlsAcceptor { 358 config: Arc::new(config), 359 incoming, 360 } 361 } 362 } 363 364 impl Accept for TlsAcceptor { 365 type Conn = TlsStream; 366 type Error = io::Error; 367 poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>368 fn poll_accept( 369 self: Pin<&mut Self>, 370 cx: &mut Context<'_>, 371 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { 372 let pin = self.get_mut(); 373 match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { 374 Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), 375 Some(Err(e)) => Poll::Ready(Some(Err(e))), 376 None => Poll::Ready(None), 377 } 378 } 379 } 380 381 #[cfg(test)] 382 mod tests { 383 use super::*; 384 385 #[test] file_cert_key()386 fn file_cert_key() { 387 TlsConfigBuilder::new() 388 .key_path("examples/tls/key.rsa") 389 .cert_path("examples/tls/cert.pem") 390 .build() 391 .unwrap(); 392 } 393 394 #[test] bytes_cert_key()395 fn bytes_cert_key() { 396 let key = include_str!("../examples/tls/key.rsa"); 397 let cert = include_str!("../examples/tls/cert.pem"); 398 399 TlsConfigBuilder::new() 400 .key(key.as_bytes()) 401 .cert(cert.as_bytes()) 402 .build() 403 .unwrap(); 404 } 405 } 406