1 //! DNS Resolution used by the `HttpConnector`.
2 //!
3 //! This module contains:
4 //!
5 //! - A [`GaiResolver`](GaiResolver) that is the default resolver for the
6 //!   `HttpConnector`.
7 //! - The `Name` type used as an argument to custom resolvers.
8 //!
9 //! # Resolvers are `Service`s
10 //!
11 //! A resolver is just a
12 //! `Service<Name, Response = impl Iterator<Item = IpAddr>>`.
13 //!
14 //! A simple resolver that ignores the name and always returns a specific
15 //! address:
16 //!
17 //! ```rust,ignore
18 //! use std::{convert::Infallible, iter, net::IpAddr};
19 //!
20 //! let resolver = tower::service_fn(|_name| async {
21 //!     Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1])))
22 //! });
23 //! ```
24 use std::error::Error;
25 use std::future::Future;
26 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
27 use std::pin::Pin;
28 use std::str::FromStr;
29 use std::task::{self, Poll};
30 use std::{fmt, io, vec};
31 
32 use tokio::task::JoinHandle;
33 use tower_service::Service;
34 
35 pub(super) use self::sealed::Resolve;
36 
/// A domain name to resolve into IP addresses.
///
/// This is a thin wrapper around the hostname string and is the request type
/// handed to resolver services (`Service<Name>`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Name {
    host: String,
}
42 
43 /// A resolver using blocking `getaddrinfo` calls in a threadpool.
44 #[derive(Clone)]
45 pub struct GaiResolver {
46     _priv: (),
47 }
48 
49 /// An iterator of IP addresses returned from `getaddrinfo`.
50 pub struct GaiAddrs {
51     inner: IpAddrs,
52 }
53 
54 /// A future to resolve a name returned by `GaiResolver`.
55 pub struct GaiFuture {
56     inner: JoinHandle<Result<IpAddrs, io::Error>>,
57 }
58 
59 impl Name {
new(host: String) -> Name60     pub(super) fn new(host: String) -> Name {
61         Name { host }
62     }
63 
64     /// View the hostname as a string slice.
as_str(&self) -> &str65     pub fn as_str(&self) -> &str {
66         &self.host
67     }
68 }
69 
70 impl fmt::Debug for Name {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result71     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72         fmt::Debug::fmt(&self.host, f)
73     }
74 }
75 
76 impl fmt::Display for Name {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result77     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78         fmt::Display::fmt(&self.host, f)
79     }
80 }
81 
82 impl FromStr for Name {
83     type Err = InvalidNameError;
84 
from_str(host: &str) -> Result<Self, Self::Err>85     fn from_str(host: &str) -> Result<Self, Self::Err> {
86         // Possibly add validation later
87         Ok(Name::new(host.to_owned()))
88     }
89 }
90 
91 /// Error indicating a given string was not a valid domain name.
92 #[derive(Debug)]
93 pub struct InvalidNameError(());
94 
95 impl fmt::Display for InvalidNameError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result96     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97         f.write_str("Not a valid domain name")
98     }
99 }
100 
101 impl Error for InvalidNameError {}
102 
103 impl GaiResolver {
104     /// Construct a new `GaiResolver`.
new() -> Self105     pub fn new() -> Self {
106         GaiResolver { _priv: () }
107     }
108 }
109 
110 impl Service<Name> for GaiResolver {
111     type Response = GaiAddrs;
112     type Error = io::Error;
113     type Future = GaiFuture;
114 
poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>>115     fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
116         Poll::Ready(Ok(()))
117     }
118 
call(&mut self, name: Name) -> Self::Future119     fn call(&mut self, name: Name) -> Self::Future {
120         let blocking = tokio::task::spawn_blocking(move || {
121             debug!("resolving host={:?}", name.host);
122             (&*name.host, 0)
123                 .to_socket_addrs()
124                 .map(|i| IpAddrs { iter: i })
125         });
126 
127         GaiFuture { inner: blocking }
128     }
129 }
130 
131 impl fmt::Debug for GaiResolver {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result132     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133         f.pad("GaiResolver")
134     }
135 }
136 
137 impl Future for GaiFuture {
138     type Output = Result<GaiAddrs, io::Error>;
139 
poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>140     fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
141         Pin::new(&mut self.inner).poll(cx).map(|res| match res {
142             Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }),
143             Ok(Err(err)) => Err(err),
144             Err(join_err) => panic!("gai background task failed: {:?}", join_err),
145         })
146     }
147 }
148 
149 impl fmt::Debug for GaiFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result150     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151         f.pad("GaiFuture")
152     }
153 }
154 
155 impl Iterator for GaiAddrs {
156     type Item = IpAddr;
157 
next(&mut self) -> Option<Self::Item>158     fn next(&mut self) -> Option<Self::Item> {
159         self.inner.next().map(|sa| sa.ip())
160     }
161 }
162 
163 impl fmt::Debug for GaiAddrs {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result164     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165         f.pad("GaiAddrs")
166     }
167 }
168 
169 pub(super) struct IpAddrs {
170     iter: vec::IntoIter<SocketAddr>,
171 }
172 
173 impl IpAddrs {
new(addrs: Vec<SocketAddr>) -> Self174     pub(super) fn new(addrs: Vec<SocketAddr>) -> Self {
175         IpAddrs {
176             iter: addrs.into_iter(),
177         }
178     }
179 
try_parse(host: &str, port: u16) -> Option<IpAddrs>180     pub(super) fn try_parse(host: &str, port: u16) -> Option<IpAddrs> {
181         if let Ok(addr) = host.parse::<Ipv4Addr>() {
182             let addr = SocketAddrV4::new(addr, port);
183             return Some(IpAddrs {
184                 iter: vec![SocketAddr::V4(addr)].into_iter(),
185             });
186         }
187         let host = host.trim_start_matches('[').trim_end_matches(']');
188         if let Ok(addr) = host.parse::<Ipv6Addr>() {
189             let addr = SocketAddrV6::new(addr, port, 0, 0);
190             return Some(IpAddrs {
191                 iter: vec![SocketAddr::V6(addr)].into_iter(),
192             });
193         }
194         None
195     }
196 
split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs)197     pub(super) fn split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs) {
198         if let Some(local_addr) = local_addr {
199             let preferred = self
200                 .iter
201                 .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6())
202                 .collect();
203 
204             (IpAddrs::new(preferred), IpAddrs::new(vec![]))
205         } else {
206             let preferring_v6 = self
207                 .iter
208                 .as_slice()
209                 .first()
210                 .map(SocketAddr::is_ipv6)
211                 .unwrap_or(false);
212 
213             let (preferred, fallback) = self
214                 .iter
215                 .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);
216 
217             (IpAddrs::new(preferred), IpAddrs::new(fallback))
218         }
219     }
220 
is_empty(&self) -> bool221     pub(super) fn is_empty(&self) -> bool {
222         self.iter.as_slice().is_empty()
223     }
224 
len(&self) -> usize225     pub(super) fn len(&self) -> usize {
226         self.iter.as_slice().len()
227     }
228 }
229 
230 impl Iterator for IpAddrs {
231     type Item = SocketAddr;
232     #[inline]
next(&mut self) -> Option<SocketAddr>233     fn next(&mut self) -> Option<SocketAddr> {
234         self.iter.next()
235     }
236 }
237 
238 /*
239 /// A resolver using `getaddrinfo` calls via the `tokio_executor::threadpool::blocking` API.
240 ///
241 /// Unlike the `GaiResolver` this will not spawn dedicated threads, but only works when running on the
242 /// multi-threaded Tokio runtime.
243 #[cfg(feature = "runtime")]
244 #[derive(Clone, Debug)]
245 pub struct TokioThreadpoolGaiResolver(());
246 
247 /// The future returned by `TokioThreadpoolGaiResolver`.
248 #[cfg(feature = "runtime")]
249 #[derive(Debug)]
250 pub struct TokioThreadpoolGaiFuture {
251     name: Name,
252 }
253 
254 #[cfg(feature = "runtime")]
255 impl TokioThreadpoolGaiResolver {
256     /// Creates a new DNS resolver that will use tokio threadpool's blocking
257     /// feature.
258     ///
259     /// **Requires** its futures to be run on the threadpool runtime.
260     pub fn new() -> Self {
261         TokioThreadpoolGaiResolver(())
262     }
263 }
264 
265 #[cfg(feature = "runtime")]
266 impl Service<Name> for TokioThreadpoolGaiResolver {
267     type Response = GaiAddrs;
268     type Error = io::Error;
269     type Future = TokioThreadpoolGaiFuture;
270 
271     fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
272         Poll::Ready(Ok(()))
273     }
274 
275     fn call(&mut self, name: Name) -> Self::Future {
276         TokioThreadpoolGaiFuture { name }
277     }
278 }
279 
280 #[cfg(feature = "runtime")]
281 impl Future for TokioThreadpoolGaiFuture {
282     type Output = Result<GaiAddrs, io::Error>;
283 
284     fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Self::Output> {
285         match ready!(tokio_executor::threadpool::blocking(|| (
286             self.name.as_str(),
287             0
288         )
289             .to_socket_addrs()))
290         {
291             Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs {
292                 inner: IpAddrs { iter },
293             })),
294             Ok(Err(e)) => Poll::Ready(Err(e)),
295             // a BlockingError, meaning not on a tokio_executor::threadpool :(
296             Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
297         }
298     }
299 }
300 */
301 
302 mod sealed {
303     use super::{IpAddr, Name};
304     use crate::common::{task, Future, Poll};
305     use tower_service::Service;
306 
307     // "Trait alias" for `Service<Name, Response = Addrs>`
308     pub trait Resolve {
309         type Addrs: Iterator<Item = IpAddr>;
310         type Error: Into<Box<dyn std::error::Error + Send + Sync>>;
311         type Future: Future<Output = Result<Self::Addrs, Self::Error>>;
312 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>313         fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;
resolve(&mut self, name: Name) -> Self::Future314         fn resolve(&mut self, name: Name) -> Self::Future;
315     }
316 
317     impl<S> Resolve for S
318     where
319         S: Service<Name>,
320         S::Response: Iterator<Item = IpAddr>,
321         S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
322     {
323         type Addrs = S::Response;
324         type Error = S::Error;
325         type Future = S::Future;
326 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>327         fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
328             Service::poll_ready(self, cx)
329         }
330 
resolve(&mut self, name: Name) -> Self::Future331         fn resolve(&mut self, name: Name) -> Self::Future {
332             Service::call(self, name)
333         }
334     }
335 }
336 
resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> where R: Resolve,337 pub(crate) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error>
338 where
339     R: Resolve,
340 {
341     futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?;
342     resolver.resolve(name).await
343 }
344 
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_ip_addrs_split_by_preference() {
        let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into();
        let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into();

        let (mut preferred, mut fallback) = IpAddrs {
            iter: vec![v4_addr, v6_addr].into_iter(),
        }
        .split_by_preference(None);
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.next().unwrap().is_ipv6());

        let (mut preferred, mut fallback) = IpAddrs {
            iter: vec![v6_addr, v4_addr].into_iter(),
        }
        .split_by_preference(None);
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.next().unwrap().is_ipv4());

        let (mut preferred, fallback) = IpAddrs {
            iter: vec![v4_addr, v6_addr].into_iter(),
        }
        .split_by_preference(Some(v4_addr.ip()));
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.is_empty());

        let (mut preferred, fallback) = IpAddrs {
            iter: vec![v4_addr, v6_addr].into_iter(),
        }
        .split_by_preference(Some(v6_addr.ip()));
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.is_empty());
    }

    #[test]
    fn test_ip_addrs_split_by_preference_edge_cases() {
        let v4_a: SocketAddr = (Ipv4Addr::new(127, 0, 0, 1), 80).into();
        let v4_b: SocketAddr = (Ipv4Addr::new(10, 0, 0, 1), 80).into();

        // Without a local address, a single-family list is entirely preferred.
        let (preferred, fallback) = IpAddrs::new(vec![v4_a, v4_b]).split_by_preference(None);
        assert_eq!(preferred.collect::<Vec<_>>(), vec![v4_a, v4_b]);
        assert!(fallback.is_empty());

        // A local address of the other family filters everything out.
        let v6_local = IpAddr::from(Ipv6Addr::LOCALHOST);
        let (preferred, fallback) =
            IpAddrs::new(vec![v4_a, v4_b]).split_by_preference(Some(v6_local));
        assert!(preferred.is_empty());
        assert!(fallback.is_empty());

        // An empty list splits into two empty lists.
        let (preferred, fallback) = IpAddrs::new(vec![]).split_by_preference(None);
        assert!(preferred.is_empty());
        assert!(fallback.is_empty());
    }

    #[test]
    fn test_name_from_str() {
        const DOMAIN: &str = "test.example.com";
        let name = Name::from_str(DOMAIN).expect("Should be a valid domain");
        assert_eq!(name.as_str(), DOMAIN);
        assert_eq!(name.to_string(), DOMAIN);
    }

    #[test]
    fn ip_addrs_try_parse_v4() {
        let mut addrs = IpAddrs::try_parse("127.0.0.1", 8080).expect("try_parse");

        let expected = "127.0.0.1:8080".parse::<SocketAddr>().expect("expected");

        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs.next(), Some(expected));
        assert_eq!(addrs.next(), None);
    }

    #[test]
    fn ip_addrs_try_parse_v6() {
        let dst = ::http::Uri::from_static("http://[::1]:8080/");

        let mut addrs =
            IpAddrs::try_parse(dst.host().expect("host"), dst.port_u16().expect("port"))
                .expect("try_parse");

        let expected = "[::1]:8080".parse::<SocketAddr>().expect("expected");

        assert_eq!(addrs.next(), Some(expected));
    }

    #[test]
    fn ip_addrs_try_parse_rejects_non_ip_hosts() {
        // Hostnames (and the empty string) are not literals; they need DNS.
        assert!(IpAddrs::try_parse("example.com", 80).is_none());
        assert!(IpAddrs::try_parse("", 80).is_none());
    }
}
405