1 //! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
2 
3 pub mod client;
4 mod common;
5 pub mod server;
6 
7 use common::{MidHandshake, Stream, TlsState};
8 use futures_core::future::FusedFuture;
9 use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
10 use std::future::Future;
11 use std::io;
12 use std::pin::Pin;
13 use std::sync::Arc;
14 use std::task::{Context, Poll};
15 use tokio::io::{AsyncRead, AsyncWrite};
16 use webpki::DNSNameRef;
17 
18 pub use rustls;
19 pub use webpki;
20 
21 /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
22 #[derive(Clone)]
23 pub struct TlsConnector {
24     inner: Arc<ClientConfig>,
25     #[cfg(feature = "early-data")]
26     early_data: bool,
27 }
28 
29 /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
30 #[derive(Clone)]
31 pub struct TlsAcceptor {
32     inner: Arc<ServerConfig>,
33 }
34 
35 impl From<Arc<ClientConfig>> for TlsConnector {
from(inner: Arc<ClientConfig>) -> TlsConnector36     fn from(inner: Arc<ClientConfig>) -> TlsConnector {
37         TlsConnector {
38             inner,
39             #[cfg(feature = "early-data")]
40             early_data: false,
41         }
42     }
43 }
44 
45 impl From<Arc<ServerConfig>> for TlsAcceptor {
from(inner: Arc<ServerConfig>) -> TlsAcceptor46     fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
47         TlsAcceptor { inner }
48     }
49 }
50 
51 impl TlsConnector {
52     /// Enable 0-RTT.
53     ///
54     /// If you want to use 0-RTT,
55     /// You must also set `ClientConfig.enable_early_data` to `true`.
56     #[cfg(feature = "early-data")]
early_data(mut self, flag: bool) -> TlsConnector57     pub fn early_data(mut self, flag: bool) -> TlsConnector {
58         self.early_data = flag;
59         self
60     }
61 
62     #[inline]
connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO> where IO: AsyncRead + AsyncWrite + Unpin,63     pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
64     where
65         IO: AsyncRead + AsyncWrite + Unpin,
66     {
67         self.connect_with(domain, stream, |_| ())
68     }
69 
connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO> where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientSession),70     pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
71     where
72         IO: AsyncRead + AsyncWrite + Unpin,
73         F: FnOnce(&mut ClientSession),
74     {
75         let mut session = ClientSession::new(&self.inner, domain);
76         f(&mut session);
77 
78         Connect(MidHandshake::Handshaking(client::TlsStream {
79             io: stream,
80 
81             #[cfg(not(feature = "early-data"))]
82             state: TlsState::Stream,
83 
84             #[cfg(feature = "early-data")]
85             state: if self.early_data && session.early_data().is_some() {
86                 TlsState::EarlyData(0, Vec::new())
87             } else {
88                 TlsState::Stream
89             },
90 
91             session,
92         }))
93     }
94 }
95 
96 impl TlsAcceptor {
97     #[inline]
accept<IO>(&self, stream: IO) -> Accept<IO> where IO: AsyncRead + AsyncWrite + Unpin,98     pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
99     where
100         IO: AsyncRead + AsyncWrite + Unpin,
101     {
102         self.accept_with(stream, |_| ())
103     }
104 
accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ServerSession),105     pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
106     where
107         IO: AsyncRead + AsyncWrite + Unpin,
108         F: FnOnce(&mut ServerSession),
109     {
110         let mut session = ServerSession::new(&self.inner);
111         f(&mut session);
112 
113         Accept(MidHandshake::Handshaking(server::TlsStream {
114             session,
115             io: stream,
116             state: TlsState::Stream,
117         }))
118     }
119 }
120 
121 /// Future returned from `TlsConnector::connect` which will resolve
122 /// once the connection handshake has finished.
123 pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
124 
125 /// Future returned from `TlsAcceptor::accept` which will resolve
126 /// once the accept handshake has finished.
127 pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
128 
129 /// Like [Connect], but returns `IO` on failure.
130 pub struct FailableConnect<IO>(MidHandshake<client::TlsStream<IO>>);
131 
132 /// Like [Accept], but returns `IO` on failure.
133 pub struct FailableAccept<IO>(MidHandshake<server::TlsStream<IO>>);
134 
135 impl<IO> Connect<IO> {
136     #[inline]
into_failable(self) -> FailableConnect<IO>137     pub fn into_failable(self) -> FailableConnect<IO> {
138         FailableConnect(self.0)
139     }
140 }
141 
142 impl<IO> Accept<IO> {
143     #[inline]
into_failable(self) -> FailableAccept<IO>144     pub fn into_failable(self) -> FailableAccept<IO> {
145         FailableAccept(self.0)
146     }
147 }
148 
149 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
150     type Output = io::Result<client::TlsStream<IO>>;
151 
152     #[inline]
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>153     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
154         Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
155     }
156 }
157 
158 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Connect<IO> {
159     #[inline]
is_terminated(&self) -> bool160     fn is_terminated(&self) -> bool {
161         self.0.is_terminated()
162     }
163 }
164 
165 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
166     type Output = io::Result<server::TlsStream<IO>>;
167 
168     #[inline]
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>169     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
170         Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
171     }
172 }
173 
174 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for Accept<IO> {
175     #[inline]
is_terminated(&self) -> bool176     fn is_terminated(&self) -> bool {
177         self.0.is_terminated()
178     }
179 }
180 
181 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableConnect<IO> {
182     type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
183 
184     #[inline]
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>185     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
186         Pin::new(&mut self.0).poll(cx)
187     }
188 }
189 
190 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableConnect<IO> {
191     #[inline]
is_terminated(&self) -> bool192     fn is_terminated(&self) -> bool {
193         self.0.is_terminated()
194     }
195 }
196 
197 impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableAccept<IO> {
198     type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
199 
200     #[inline]
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>201     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
202         Pin::new(&mut self.0).poll(cx)
203     }
204 }
205 
206 impl<IO: AsyncRead + AsyncWrite + Unpin> FusedFuture for FailableAccept<IO> {
207     #[inline]
is_terminated(&self) -> bool208     fn is_terminated(&self) -> bool {
209         self.0.is_terminated()
210     }
211 }
212 
213 /// Unified TLS stream type
214 ///
215 /// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
216 /// a single type to keep both client- and server-initiated TLS-encrypted connections.
217 pub enum TlsStream<T> {
218     Client(client::TlsStream<T>),
219     Server(server::TlsStream<T>),
220 }
221 
222 impl<T> TlsStream<T> {
get_ref(&self) -> (&T, &dyn Session)223     pub fn get_ref(&self) -> (&T, &dyn Session) {
224         use TlsStream::*;
225         match self {
226             Client(io) => {
227                 let (io, session) = io.get_ref();
228                 (io, &*session)
229             }
230             Server(io) => {
231                 let (io, session) = io.get_ref();
232                 (io, &*session)
233             }
234         }
235     }
236 
get_mut(&mut self) -> (&mut T, &mut dyn Session)237     pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) {
238         use TlsStream::*;
239         match self {
240             Client(io) => {
241                 let (io, session) = io.get_mut();
242                 (io, &mut *session)
243             }
244             Server(io) => {
245                 let (io, session) = io.get_mut();
246                 (io, &mut *session)
247             }
248         }
249     }
250 }
251 
252 impl<T> From<client::TlsStream<T>> for TlsStream<T> {
from(s: client::TlsStream<T>) -> Self253     fn from(s: client::TlsStream<T>) -> Self {
254         Self::Client(s)
255     }
256 }
257 
258 impl<T> From<server::TlsStream<T>> for TlsStream<T> {
from(s: server::TlsStream<T>) -> Self259     fn from(s: server::TlsStream<T>) -> Self {
260         Self::Server(s)
261     }
262 }
263 
264 impl<T> AsyncRead for TlsStream<T>
265 where
266     T: AsyncRead + AsyncWrite + Unpin,
267 {
268     #[inline]
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>269     fn poll_read(
270         self: Pin<&mut Self>,
271         cx: &mut Context<'_>,
272         buf: &mut [u8],
273     ) -> Poll<io::Result<usize>> {
274         match self.get_mut() {
275             TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
276             TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
277         }
278     }
279 }
280 
281 impl<T> AsyncWrite for TlsStream<T>
282 where
283     T: AsyncRead + AsyncWrite + Unpin,
284 {
285     #[inline]
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>286     fn poll_write(
287         self: Pin<&mut Self>,
288         cx: &mut Context<'_>,
289         buf: &[u8],
290     ) -> Poll<io::Result<usize>> {
291         match self.get_mut() {
292             TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
293             TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
294         }
295     }
296 
297     #[inline]
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>298     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
299         match self.get_mut() {
300             TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
301             TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
302         }
303     }
304 
305     #[inline]
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>306     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
307         match self.get_mut() {
308             TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
309             TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
310         }
311     }
312 }
313