1 //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls). 2 3 pub mod client; 4 mod common; 5 pub mod server; 6 7 use common::{MidHandshake, Stream, TlsState}; 8 use futures_core::future::FusedFuture; 9 use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session}; 10 use std::future::Future; 11 use std::io; 12 use std::pin::Pin; 13 use std::sync::Arc; 14 use std::task::{Context, Poll}; 15 use tokio::io::{AsyncRead, AsyncWrite}; 16 use webpki::DNSNameRef; 17 18 pub use rustls; 19 pub use webpki; 20 21 /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. 22 #[derive(Clone)] 23 pub struct TlsConnector { 24 inner: Arc<ClientConfig>, 25 #[cfg(feature = "early-data")] 26 early_data: bool, 27 } 28 29 /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. 30 #[derive(Clone)] 31 pub struct TlsAcceptor { 32 inner: Arc<ServerConfig>, 33 } 34 35 impl From<Arc<ClientConfig>> for TlsConnector { from(inner: Arc<ClientConfig>) -> TlsConnector36 fn from(inner: Arc<ClientConfig>) -> TlsConnector { 37 TlsConnector { 38 inner, 39 #[cfg(feature = "early-data")] 40 early_data: false, 41 } 42 } 43 } 44 45 impl From<Arc<ServerConfig>> for TlsAcceptor { from(inner: Arc<ServerConfig>) -> TlsAcceptor46 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor { 47 TlsAcceptor { inner } 48 } 49 } 50 51 impl TlsConnector { 52 /// Enable 0-RTT. 53 /// 54 /// If you want to use 0-RTT, 55 /// You must also set `ClientConfig.enable_early_data` to `true`. 56 #[cfg(feature = "early-data")] early_data(mut self, flag: bool) -> TlsConnector57 pub fn early_data(mut self, flag: bool) -> TlsConnector { 58 self.early_data = flag; 59 self 60 } 61 62 #[inline] connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> where IO: AsyncRead + AsyncWrite + Unpin,63 pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> 64 where 65 IO: AsyncRead + AsyncWrite + Unpin, 66 { 67 self.connect_with(domain, stream, |_| ()) 68 } 69 connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO> where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientSession),70 pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO> 71 where 72 IO: AsyncRead + AsyncWrite + Unpin, 73 F: FnOnce(&mut ClientSession), 74 { 75 let mut session = ClientSession::new(&self.inner, domain); 76 f(&mut session); 77 78 Connect(MidHandshake::Handshaking(client::TlsStream { 79 io: stream, 80 81 #[cfg(not(feature = "early-data"))] 82 state: TlsState::Stream, 83 84 #[cfg(feature = "early-data")] 85 state: if self.early_data && session.early_data().is_some() { 86 TlsState::EarlyData(0, Vec::new()) 87 } else { 88 TlsState::Stream 89 }, 90 91 session, 92 })) 93 } 94 } 95 96 impl TlsAcceptor { 97 #[inline] accept<IO>(&self, stream: IO) -> Accept<IO> where IO: AsyncRead + AsyncWrite + Unpin,98 pub fn accept<IO>(&self, stream: IO) -> Accept<IO> 99 where 100 IO: AsyncRead + AsyncWrite + Unpin, 101 { 102 self.accept_with(stream, |_| ()) 103 } 104 accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ServerSession),105 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> 106 where 107 IO: AsyncRead + AsyncWrite + Unpin, 108 F: FnOnce(&mut ServerSession), 109 { 110 let mut session = ServerSession::new(&self.inner); 111 f(&mut session); 112 113 Accept(MidHandshake::Handshaking(server::TlsStream { 114 session, 115 io: stream, 116 state: TlsState::Stream, 117 })) 118 } 119 } 120 121 /// Future returned from `TlsConnector::connect` which will resolve 122 /// once the connection handshake has finished. 123 pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>); 124 125 /// Future returned from `TlsAcceptor::accept` which will resolve 126 /// once the accept handshake has finished. 127 pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>); 128 129 /// Like [Connect], but returns `IO` on failure. 130 pub struct FailableConnect<IO>(MidHandshake<client::TlsStream<IO>>); 131 132 /// Like [Accept], but returns `IO` on failure. 133 pub struct FailableAccept<IO>(MidHandshake<server::TlsStream<IO>>); 134 135 impl<IO> Connect<IO> { 136 #[inline] into_failable(self) -> FailableConnect<IO>137 pub fn into_failable(self) -> FailableConnect<IO> { 138 FailableConnect(self.0) 139 } 140 } 141 142 impl<IO> Accept<IO> { 143 #[inline] into_failable(self) -> FailableAccept<IO>144 pub fn into_failable(self) -> FailableAccept<IO> { 145 FailableAccept(self.0) 146 } 147 } 148 149 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> { 150 type Output = io::Result<client::TlsStream<IO>>; 151 152 #[inline] poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>153 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 154 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) 155 } 156 } 157 158 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Connect<IO> { 159 #[inline] is_terminated(&self) -> bool160 fn is_terminated(&self) -> bool { 161 self.0.is_terminated() 162 } 163 } 164 165 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> { 166 type Output = io::Result<server::TlsStream<IO>>; 167 168 #[inline] poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>169 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 170 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) 171 } 172 } 173 174 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Accept<IO> { 175 #[inline] is_terminated(&self) -> bool176 fn is_terminated(&self) -> bool { 177 self.0.is_terminated() 178 } 179 } 180 181 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableConnect<IO> { 182 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>; 183 184 #[inline] poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>185 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 186 Pin::new(&mut self.0).poll(cx) 187 } 188 } 189 190 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableConnect<IO> { 191 #[inline] is_terminated(&self) -> bool192 fn is_terminated(&self) -> bool { 193 self.0.is_terminated() 194 } 195 } 196 197 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableAccept<IO> { 198 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>; 199 200 #[inline] poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>201 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 202 Pin::new(&mut self.0).poll(cx) 203 } 204 } 205 206 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableAccept<IO> { 207 #[inline] is_terminated(&self) -> bool208 fn is_terminated(&self) -> bool { 209 self.0.is_terminated() 210 } 211 } 212 213 /// Unified TLS stream type 214 /// 215 /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use 216 /// a single type to keep both client- and server-initiated TLS-encrypted connections. 217 pub enum TlsStream<T> { 218 Client(client::TlsStream<T>), 219 Server(server::TlsStream<T>), 220 } 221 222 impl<T> TlsStream<T> { get_ref(&self) -> (&T, &dyn Session)223 pub fn get_ref(&self) -> (&T, &dyn Session) { 224 use TlsStream::*; 225 match self { 226 Client(io) => { 227 let (io, session) = io.get_ref(); 228 (io, &*session) 229 } 230 Server(io) => { 231 let (io, session) = io.get_ref(); 232 (io, &*session) 233 } 234 } 235 } 236 get_mut(&mut self) -> (&mut T, &mut dyn Session)237 pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) { 238 use TlsStream::*; 239 match self { 240 Client(io) => { 241 let (io, session) = io.get_mut(); 242 (io, &mut *session) 243 } 244 Server(io) => { 245 let (io, session) = io.get_mut(); 246 (io, &mut *session) 247 } 248 } 249 } 250 } 251 252 impl<T> From<client::TlsStream<T>> for TlsStream<T> { from(s: client::TlsStream<T>) -> Self253 fn from(s: client::TlsStream<T>) -> Self { 254 Self::Client(s) 255 } 256 } 257 258 impl<T> From<server::TlsStream<T>> for TlsStream<T> { from(s: server::TlsStream<T>) -> Self259 fn from(s: server::TlsStream<T>) -> Self { 260 Self::Server(s) 261 } 262 } 263 264 impl<T> AsyncRead for TlsStream<T> 265 where 266 T: AsyncRead + AsyncWrite + Unpin, 267 { 268 #[inline] poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>269 fn poll_read( 270 self: Pin<&mut Self>, 271 cx: &mut Context<'_>, 272 buf: &mut [u8], 273 ) -> Poll<io::Result<usize>> { 274 match self.get_mut() { 275 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf), 276 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf), 277 } 278 } 279 } 280 281 impl<T> AsyncWrite for TlsStream<T> 282 where 283 T: AsyncRead + AsyncWrite + Unpin, 284 { 285 #[inline] poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>286 fn poll_write( 287 self: Pin<&mut Self>, 288 cx: &mut Context<'_>, 289 buf: &[u8], 290 ) -> Poll<io::Result<usize>> { 291 match self.get_mut() { 292 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf), 293 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf), 294 } 295 } 296 297 #[inline] poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>298 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 299 match self.get_mut() { 300 TlsStream::Client(x) => Pin::new(x).poll_flush(cx), 301 TlsStream::Server(x) => Pin::new(x).poll_flush(cx), 302 } 303 } 304 305 #[inline] poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>306 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 307 match self.get_mut() { 308 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx), 309 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx), 310 } 311 } 312 } 313