1 // Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2 //
3 // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4 // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5 // http://opensource.org/licenses/MIT>, at your option. This file may not be
6 // copied, modified, or distributed except according to those terms.
7 
8 use std::net::SocketAddr;
9 use std::pin::Pin;
10 use std::sync::{Arc, Mutex};
11 use std::task::{Context, Poll};
12 use std::time::Duration;
13 
14 use futures::{Future, FutureExt, TryFutureExt};
15 use tokio;
16 use tokio::net::TcpStream as TokioTcpStream;
17 use tokio::net::UdpSocket as TokioUdpSocket;
18 
19 use proto;
20 #[cfg(feature = "mdns")]
21 use proto::multicast::{MdnsClientStream, MdnsQueryType};
22 use proto::op::NoopMessageFinalizer;
23 use proto::tcp::TcpClientStream;
24 use proto::udp::{UdpClientStream, UdpResponse};
25 use proto::xfer::{
26     self, BufDnsRequestStreamHandle, DnsExchange, DnsHandle, DnsMultiplexer,
27     DnsMultiplexerSerialResponse, DnsRequest, DnsResponse,
28 };
29 #[cfg(feature = "dns-over-https")]
30 use trust_dns_https::{self, HttpsClientResponse};
31 
32 #[cfg(feature = "dns-over-rustls")]
33 use crate::config::TlsClientConfig;
34 use crate::config::{NameServerConfig, Protocol, ResolverOpts};
35 
36 /// A type to allow for custom ConnectionProviders. Needed mainly for mocking purposes.
37 pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
38     type ConnHandle;
39 
40     /// The returned handle should
new_connection(&self, config: &NameServerConfig, options: &ResolverOpts) -> Self::ConnHandle41     fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
42         -> Self::ConnHandle;
43 }
44 
45 /// Standard connection implements the default mechanism for creating new Connections
46 #[derive(Clone)]
47 pub struct StandardConnection;
48 
49 impl ConnectionProvider for StandardConnection {
50     type ConnHandle = ConnectionHandle;
51 
52     /// Constructs an initial constructor for the ConnectionHandle to be used to establish a
53     ///   future connection.
new_connection( &self, config: &NameServerConfig, options: &ResolverOpts, ) -> Self::ConnHandle54     fn new_connection(
55         &self,
56         config: &NameServerConfig,
57         options: &ResolverOpts,
58     ) -> Self::ConnHandle {
59         let dns_handle = match config.protocol {
60             Protocol::Udp => ConnectionHandleInner::Connect(Some(ConnectionHandleConnect::Udp {
61                 socket_addr: config.socket_addr,
62                 timeout: options.timeout,
63             })),
64             Protocol::Tcp => ConnectionHandleInner::Connect(Some(ConnectionHandleConnect::Tcp {
65                 socket_addr: config.socket_addr,
66                 timeout: options.timeout,
67             })),
68             #[cfg(feature = "dns-over-tls")]
69             Protocol::Tls => ConnectionHandleInner::Connect(Some(ConnectionHandleConnect::Tls {
70                 socket_addr: config.socket_addr,
71                 timeout: options.timeout,
72                 tls_dns_name: config.tls_dns_name.clone().unwrap_or_default(),
73                 #[cfg(feature = "dns-over-rustls")]
74                 client_config: config.tls_config.clone(),
75             })),
76             #[cfg(feature = "dns-over-https")]
77             Protocol::Https => {
78                 ConnectionHandleInner::Connect(Some(ConnectionHandleConnect::Https {
79                     socket_addr: config.socket_addr,
80                     timeout: options.timeout,
81                     tls_dns_name: config.tls_dns_name.clone().unwrap_or_default(),
82                     #[cfg(feature = "dns-over-rustls")]
83                     client_config: config.tls_config.clone(),
84                 }))
85             }
86             #[cfg(feature = "mdns")]
87             Protocol::Mdns => ConnectionHandleInner::Connect(Some(ConnectionHandleConnect::Mdns {
88                 socket_addr: config.socket_addr,
89                 timeout: options.timeout,
90             })),
91         };
92 
93         ConnectionHandle(Arc::new(Mutex::new(dns_handle)))
94     }
95 }
96 
/// The parameters required to establish each kind of connection supported by the Resolver.
///
/// A value of this type describes a connection that has not been made yet.
#[derive(Debug)]
pub(crate) enum ConnectionHandleConnect {
    /// Parameters for a UDP connection.
    Udp {
        socket_addr: SocketAddr,
        timeout: Duration,
    },
    /// Parameters for a TCP connection.
    Tcp {
        socket_addr: SocketAddr,
        timeout: Duration,
    },
    /// Parameters for a TLS connection; `tls_dns_name` is the DNS name passed to the TLS stream.
    #[cfg(feature = "dns-over-tls")]
    Tls {
        socket_addr: SocketAddr,
        timeout: Duration,
        tls_dns_name: String,
        #[cfg(feature = "dns-over-rustls")]
        client_config: Option<TlsClientConfig>,
    },
    /// Parameters for an HTTPS connection; `tls_dns_name` is the DNS name passed to the stream.
    #[cfg(feature = "dns-over-https")]
    Https {
        socket_addr: SocketAddr,
        timeout: Duration,
        tls_dns_name: String,
        #[cfg(feature = "dns-over-rustls")]
        client_config: Option<TlsClientConfig>,
    },
    /// Parameters for an mDNS connection.
    #[cfg(feature = "mdns")]
    Mdns {
        socket_addr: SocketAddr,
        timeout: Duration,
    },
}
130 
// TODO: rather than spawning here, return the background process, and remove Background indirection.
132 impl ConnectionHandleConnect {
133     /// Establishes the connection, this is allowed to perform network operations,
134     ///   such as tokio::spawns of background tasks, etc.
connect(self) -> Result<ConnectionHandleConnected, proto::error::ProtoError>135     fn connect(self) -> Result<ConnectionHandleConnected, proto::error::ProtoError> {
136         use self::ConnectionHandleConnect::*;
137 
138         debug!("connecting: {:?}", self);
139         match self {
140             Udp {
141                 socket_addr,
142                 timeout,
143             } => {
144                 let stream = UdpClientStream::<TokioUdpSocket>::with_timeout(socket_addr, timeout);
145                 let (stream, handle) = DnsExchange::connect(stream);
146 
147                 let stream = stream
148                     .and_then(|stream| stream)
149                     .map_err(|e| {
150                         debug!("udp connection shutting down: {}", e);
151                     })
152                     .map(|_| ());
153                 let handle = BufDnsRequestStreamHandle::new(handle);
154 
155                 tokio::spawn(stream.boxed());
156                 Ok(ConnectionHandleConnected::Udp(handle))
157             }
158             Tcp {
159                 socket_addr,
160                 timeout,
161             } => {
162                 let (stream, handle) =
163                     TcpClientStream::<TokioTcpStream>::with_timeout(socket_addr, timeout);
164                 // TODO: need config for Signer...
165                 let dns_conn = DnsMultiplexer::with_timeout(
166                     Box::new(stream),
167                     handle,
168                     timeout,
169                     NoopMessageFinalizer::new(),
170                 );
171 
172                 let (stream, handle) = DnsExchange::connect(dns_conn);
173                 let stream = stream
174                     .and_then(|stream| stream)
175                     .map_err(|e| {
176                         debug!("tcp connection shutting down: {}", e);
177                     })
178                     .map(|_| ());
179                 let handle = BufDnsRequestStreamHandle::new(handle);
180 
181                 tokio::spawn(stream.boxed());
182                 Ok(ConnectionHandleConnected::Tcp(handle))
183             }
184             #[cfg(feature = "dns-over-tls")]
185             Tls {
186                 socket_addr,
187                 timeout,
188                 tls_dns_name,
189                 #[cfg(feature = "dns-over-rustls")]
190                 client_config,
191             } => {
192                 #[cfg(feature = "dns-over-rustls")]
193                 let (stream, handle) =
194                     { crate::tls::new_tls_stream(socket_addr, tls_dns_name, client_config) };
195                 #[cfg(not(feature = "dns-over-rustls"))]
196                 let (stream, handle) = { crate::tls::new_tls_stream(socket_addr, tls_dns_name) };
197 
198                 let dns_conn = DnsMultiplexer::with_timeout(
199                     stream,
200                     Box::new(handle),
201                     timeout,
202                     NoopMessageFinalizer::new(),
203                 );
204 
205                 let (stream, handle) = DnsExchange::connect(dns_conn);
206                 let stream = stream
207                     .and_then(|stream| stream)
208                     .map_err(|e| {
209                         debug!("tls connection shutting down: {}", e);
210                     })
211                     .map(|_| ());
212                 let handle = BufDnsRequestStreamHandle::new(handle);
213 
214                 tokio::spawn(Box::pin(stream));
215                 Ok(ConnectionHandleConnected::Tcp(handle))
216             }
217             #[cfg(feature = "dns-over-https")]
218             Https {
219                 socket_addr,
220                 // TODO: https needs timeout!
221                 timeout: _t,
222                 tls_dns_name,
223                 client_config,
224             } => {
225                 let (stream, handle) =
226                     crate::https::new_https_stream(socket_addr, tls_dns_name, client_config);
227 
228                 let stream = stream
229                     .and_then(|stream| stream)
230                     .map_err(|e| {
231                         debug!("https connection shutting down: {}", e);
232                     })
233                     .map(|_| ());
234 
235                 tokio::spawn(Box::pin(stream));
236                 Ok(ConnectionHandleConnected::Https(handle))
237             }
238             #[cfg(feature = "mdns")]
239             Mdns {
240                 socket_addr,
241                 timeout,
242             } => {
243                 let (stream, handle) =
244                     MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
245                 // TODO: need config for Signer...
246                 let dns_conn = DnsMultiplexer::with_timeout(
247                     stream,
248                     handle,
249                     timeout,
250                     NoopMessageFinalizer::new(),
251                 );
252 
253                 let (stream, handle) = DnsExchange::connect(dns_conn);
254                 let stream = stream
255                     .and_then(|stream| stream)
256                     .map_err(|e| {
257                         debug!("mdns connection shutting down: {}", e);
258                     })
259                     .map(|_| ());
260                 let handle = BufDnsRequestStreamHandle::new(handle);
261 
262                 tokio::spawn(Box::pin(stream));
263                 Ok(ConnectionHandleConnected::Tcp(handle))
264             }
265         }
266     }
267 }
268 
269 /// A representation of an established connection
270 #[derive(Clone)]
271 enum ConnectionHandleConnected {
272     Udp(xfer::BufDnsRequestStreamHandle<UdpResponse>),
273     Tcp(xfer::BufDnsRequestStreamHandle<DnsMultiplexerSerialResponse>),
274     #[cfg(feature = "dns-over-https")]
275     Https(xfer::BufDnsRequestStreamHandle<HttpsClientResponse>),
276 }
277 
278 impl DnsHandle for ConnectionHandleConnected {
279     type Response = ConnectionHandleResponseInner;
280 
send<R: Into<DnsRequest> + Unpin + Send + 'static>( &mut self, request: R, ) -> ConnectionHandleResponseInner281     fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(
282         &mut self,
283         request: R,
284     ) -> ConnectionHandleResponseInner {
285         match self {
286             ConnectionHandleConnected::Udp(ref mut conn) => {
287                 ConnectionHandleResponseInner::Udp(conn.send(request))
288             }
289             ConnectionHandleConnected::Tcp(ref mut conn) => {
290                 ConnectionHandleResponseInner::Tcp(conn.send(request))
291             }
292             #[cfg(feature = "dns-over-https")]
293             ConnectionHandleConnected::Https(ref mut https) => {
294                 ConnectionHandleResponseInner::Https(https.send(request))
295             }
296         }
297     }
298 }
299 
300 /// Allows us to wrap a connection that is either pending or already connected
301 enum ConnectionHandleInner {
302     //    Connect(Option<ConnectionHandleConnect>),
303     Connect(Option<ConnectionHandleConnect>),
304     Connected(ConnectionHandleConnected),
305 }
306 
307 impl ConnectionHandleInner {
send<R: Into<DnsRequest> + Unpin + Send + 'static>( &mut self, request: R, ) -> ConnectionHandleResponseInner308     fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(
309         &mut self,
310         request: R,
311     ) -> ConnectionHandleResponseInner {
312         loop {
313             let connected: Result<ConnectionHandleConnected, proto::error::ProtoError> = match self
314             {
315                 // still need to connect, drop through
316                 ConnectionHandleInner::Connect(conn) => {
317                     conn.take().expect("already connected?").connect()
318                 }
319                 ConnectionHandleInner::Connected(conn) => return conn.send(request),
320             };
321 
322             match connected {
323                 Ok(connected) => *self = ConnectionHandleInner::Connected(connected),
324                 Err(e) => return ConnectionHandleResponseInner::ProtoError(Some(e)),
325             };
326             // continue to return on send...
327         }
328     }
329 }
330 
331 /// ConnectionHandle is used for sending DNS requests to a specific upstream DNS resolver
332 #[derive(Clone)]
333 pub struct ConnectionHandle(Arc<Mutex<ConnectionHandleInner>>);
334 
335 impl DnsHandle for ConnectionHandle {
336     type Response = ConnectionHandleResponse;
337 
send<R: Into<DnsRequest>>(&mut self, request: R) -> ConnectionHandleResponse338     fn send<R: Into<DnsRequest>>(&mut self, request: R) -> ConnectionHandleResponse {
339         ConnectionHandleResponse(ConnectionHandleResponseInner::ConnectAndRequest {
340             conn: self.clone(),
341             request: Some(request.into()),
342         })
343     }
344 }
345 
346 /// A wrapper type to switch over a connection that still needs to be made, or is already established
347 #[must_use = "futures do nothing unless polled"]
348 enum ConnectionHandleResponseInner {
349     ConnectAndRequest {
350         conn: ConnectionHandle,
351         request: Option<DnsRequest>,
352     },
353     Udp(xfer::OneshotDnsResponseReceiver<UdpResponse>),
354     Tcp(xfer::OneshotDnsResponseReceiver<DnsMultiplexerSerialResponse>),
355     #[cfg(feature = "dns-over-https")]
356     Https(xfer::OneshotDnsResponseReceiver<HttpsClientResponse>),
357     ProtoError(Option<proto::error::ProtoError>),
358 }
359 
360 impl Future for ConnectionHandleResponseInner {
361     type Output = Result<DnsResponse, proto::error::ProtoError>;
362 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>363     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
364         use self::ConnectionHandleResponseInner::*;
365 
366         trace!("polling response inner");
367         loop {
368             *self = match *self {
369                 // we still need to check the connection
370                 ConnectAndRequest {
371                     ref conn,
372                     ref mut request,
373                 } => match conn.0.lock() {
374                     Ok(mut c) => c.send(request.take().expect("already sent request?")),
375                     Err(e) => ProtoError(Some(proto::error::ProtoError::from(e))),
376                 },
377                 Udp(ref mut resp) => return resp.poll_unpin(cx),
378                 Tcp(ref mut resp) => return resp.poll_unpin(cx),
379                 #[cfg(feature = "dns-over-https")]
380                 Https(ref mut https) => return https.poll_unpin(cx),
381                 ProtoError(ref mut e) => {
382                     return Poll::Ready(Err(e
383                         .take()
384                         .expect("futures cannot be polled once complete")));
385                 }
386             };
387 
388             // ok, connected, loop around and use poll the actual send request
389         }
390     }
391 }
392 
393 /// A future response from a DNS request.
394 #[must_use = "futures do nothing unless polled"]
395 pub struct ConnectionHandleResponse(ConnectionHandleResponseInner);
396 
397 impl Future for ConnectionHandleResponse {
398     type Output = Result<DnsResponse, proto::error::ProtoError>;
399 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>400     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
401         self.0.poll_unpin(cx)
402     }
403 }
404