1 //! DNS Resolution used by the `HttpConnector`.
2 //!
3 //! This module contains:
4 //!
5 //! - A [`GaiResolver`](GaiResolver) that is the default resolver for the
6 //!   `HttpConnector`.
7 //! - The `Name` type used as an argument to custom resolvers.
8 //!
9 //! # Resolvers are `Service`s
10 //!
11 //! A resolver is just a
12 //! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`.
13 //!
14 //! A simple resolver that ignores the name and always returns a specific
15 //! address:
16 //!
17 //! ```rust,ignore
18 //! use std::{convert::Infallible, iter, net::SocketAddr};
19 //!
20 //! let resolver = tower::service_fn(|_name| async {
21 //!     Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080))))
22 //! });
23 //! ```
24 use std::error::Error;
25 use std::future::Future;
26 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
27 use std::pin::Pin;
28 use std::str::FromStr;
29 use std::task::{self, Poll};
30 use std::{fmt, io, vec};
31 
32 use tokio::task::JoinHandle;
33 use tower_service::Service;
34 use tracing::debug;
35 
36 pub(super) use self::sealed::Resolve;
37 
/// A hostname that a resolver is asked to turn into IP addresses.
///
/// Two `Name`s are equal (and hash equally) exactly when their hostname
/// strings are equal.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Name {
    host: String,
}
43 
/// The default resolver for `HttpConnector`.
///
/// Each lookup runs the blocking `getaddrinfo`-based system resolver on a
/// background thread pool rather than on the calling async task.
#[derive(Clone)]
pub struct GaiResolver {
    // Private field: prevents construction outside this module.
    _priv: (),
}
49 
50 /// An iterator of IP addresses returned from `getaddrinfo`.
51 pub struct GaiAddrs {
52     inner: SocketAddrs,
53 }
54 
55 /// A future to resolve a name returned by `GaiResolver`.
56 pub struct GaiFuture {
57     inner: JoinHandle<Result<SocketAddrs, io::Error>>,
58 }
59 
60 impl Name {
new(host: String) -> Name61     pub(super) fn new(host: String) -> Name {
62         Name { host }
63     }
64 
65     /// View the hostname as a string slice.
as_str(&self) -> &str66     pub fn as_str(&self) -> &str {
67         &self.host
68     }
69 }
70 
71 impl fmt::Debug for Name {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result72     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73         fmt::Debug::fmt(&self.host, f)
74     }
75 }
76 
77 impl fmt::Display for Name {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result78     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79         fmt::Display::fmt(&self.host, f)
80     }
81 }
82 
83 impl FromStr for Name {
84     type Err = InvalidNameError;
85 
from_str(host: &str) -> Result<Self, Self::Err>86     fn from_str(host: &str) -> Result<Self, Self::Err> {
87         // Possibly add validation later
88         Ok(Name::new(host.to_owned()))
89     }
90 }
91 
/// Error indicating a given string was not a valid domain name.
///
/// This is the `FromStr::Err` type for `Name`.
#[derive(Debug)]
pub struct InvalidNameError(());

impl fmt::Display for InvalidNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Not a valid domain name")
    }
}

/// No underlying cause is recorded, so `source()` is always `None`.
impl Error for InvalidNameError {}
103 
104 impl GaiResolver {
105     /// Construct a new `GaiResolver`.
new() -> Self106     pub fn new() -> Self {
107         GaiResolver { _priv: () }
108     }
109 }
110 
111 impl Service<Name> for GaiResolver {
112     type Response = GaiAddrs;
113     type Error = io::Error;
114     type Future = GaiFuture;
115 
poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>>116     fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
117         Poll::Ready(Ok(()))
118     }
119 
call(&mut self, name: Name) -> Self::Future120     fn call(&mut self, name: Name) -> Self::Future {
121         let blocking = tokio::task::spawn_blocking(move || {
122             debug!("resolving host={:?}", name.host);
123             (&*name.host, 0)
124                 .to_socket_addrs()
125                 .map(|i| SocketAddrs { iter: i })
126         });
127 
128         GaiFuture { inner: blocking }
129     }
130 }
131 
132 impl fmt::Debug for GaiResolver {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result133     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134         f.pad("GaiResolver")
135     }
136 }
137 
138 impl Future for GaiFuture {
139     type Output = Result<GaiAddrs, io::Error>;
140 
poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>141     fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
142         Pin::new(&mut self.inner).poll(cx).map(|res| match res {
143             Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }),
144             Ok(Err(err)) => Err(err),
145             Err(join_err) => {
146                 if join_err.is_cancelled() {
147                     Err(io::Error::new(io::ErrorKind::Interrupted, join_err))
148                 } else {
149                     panic!("gai background task failed: {:?}", join_err)
150                 }
151             }
152         })
153     }
154 }
155 
156 impl fmt::Debug for GaiFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result157     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158         f.pad("GaiFuture")
159     }
160 }
161 
162 impl Iterator for GaiAddrs {
163     type Item = SocketAddr;
164 
next(&mut self) -> Option<Self::Item>165     fn next(&mut self) -> Option<Self::Item> {
166         self.inner.next()
167     }
168 }
169 
170 impl fmt::Debug for GaiAddrs {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result171     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172         f.pad("GaiAddrs")
173     }
174 }
175 
176 pub(super) struct SocketAddrs {
177     iter: vec::IntoIter<SocketAddr>,
178 }
179 
180 impl SocketAddrs {
new(addrs: Vec<SocketAddr>) -> Self181     pub(super) fn new(addrs: Vec<SocketAddr>) -> Self {
182         SocketAddrs {
183             iter: addrs.into_iter(),
184         }
185     }
186 
try_parse(host: &str, port: u16) -> Option<SocketAddrs>187     pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> {
188         if let Ok(addr) = host.parse::<Ipv4Addr>() {
189             let addr = SocketAddrV4::new(addr, port);
190             return Some(SocketAddrs {
191                 iter: vec![SocketAddr::V4(addr)].into_iter(),
192             });
193         }
194         let host = host.trim_start_matches('[').trim_end_matches(']');
195         if let Ok(addr) = host.parse::<Ipv6Addr>() {
196             let addr = SocketAddrV6::new(addr, port, 0, 0);
197             return Some(SocketAddrs {
198                 iter: vec![SocketAddr::V6(addr)].into_iter(),
199             });
200         }
201         None
202     }
203 
204     #[inline]
filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs205     fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs {
206         SocketAddrs::new(self.iter.filter(predicate).collect())
207     }
208 
split_by_preference( self, local_addr_ipv4: Option<Ipv4Addr>, local_addr_ipv6: Option<Ipv6Addr>, ) -> (SocketAddrs, SocketAddrs)209     pub(super) fn split_by_preference(
210         self,
211         local_addr_ipv4: Option<Ipv4Addr>,
212         local_addr_ipv6: Option<Ipv6Addr>,
213     ) -> (SocketAddrs, SocketAddrs) {
214         match (local_addr_ipv4, local_addr_ipv6) {
215             (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])),
216             (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])),
217             _ => {
218                 let preferring_v6 = self
219                     .iter
220                     .as_slice()
221                     .first()
222                     .map(SocketAddr::is_ipv6)
223                     .unwrap_or(false);
224 
225                 let (preferred, fallback) = self
226                     .iter
227                     .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);
228 
229                 (SocketAddrs::new(preferred), SocketAddrs::new(fallback))
230             }
231         }
232     }
233 
is_empty(&self) -> bool234     pub(super) fn is_empty(&self) -> bool {
235         self.iter.as_slice().is_empty()
236     }
237 
len(&self) -> usize238     pub(super) fn len(&self) -> usize {
239         self.iter.as_slice().len()
240     }
241 }
242 
243 impl Iterator for SocketAddrs {
244     type Item = SocketAddr;
245     #[inline]
next(&mut self) -> Option<SocketAddr>246     fn next(&mut self) -> Option<SocketAddr> {
247         self.iter.next()
248     }
249 }
250 
251 /*
252 /// A resolver using `getaddrinfo` calls via the `tokio_executor::threadpool::blocking` API.
253 ///
254 /// Unlike the `GaiResolver` this will not spawn dedicated threads, but only works when running on the
255 /// multi-threaded Tokio runtime.
256 #[cfg(feature = "runtime")]
257 #[derive(Clone, Debug)]
258 pub struct TokioThreadpoolGaiResolver(());
259 
260 /// The future returned by `TokioThreadpoolGaiResolver`.
261 #[cfg(feature = "runtime")]
262 #[derive(Debug)]
263 pub struct TokioThreadpoolGaiFuture {
264     name: Name,
265 }
266 
267 #[cfg(feature = "runtime")]
268 impl TokioThreadpoolGaiResolver {
269     /// Creates a new DNS resolver that will use tokio threadpool's blocking
270     /// feature.
271     ///
272     /// **Requires** its futures to be run on the threadpool runtime.
273     pub fn new() -> Self {
274         TokioThreadpoolGaiResolver(())
275     }
276 }
277 
278 #[cfg(feature = "runtime")]
279 impl Service<Name> for TokioThreadpoolGaiResolver {
280     type Response = GaiAddrs;
281     type Error = io::Error;
282     type Future = TokioThreadpoolGaiFuture;
283 
284     fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
285         Poll::Ready(Ok(()))
286     }
287 
288     fn call(&mut self, name: Name) -> Self::Future {
289         TokioThreadpoolGaiFuture { name }
290     }
291 }
292 
293 #[cfg(feature = "runtime")]
294 impl Future for TokioThreadpoolGaiFuture {
295     type Output = Result<GaiAddrs, io::Error>;
296 
297     fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Self::Output> {
298         match ready!(tokio_executor::threadpool::blocking(|| (
299             self.name.as_str(),
300             0
301         )
302             .to_socket_addrs()))
303         {
304             Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs {
305                 inner: IpAddrs { iter },
306             })),
307             Ok(Err(e)) => Poll::Ready(Err(e)),
308             // a BlockingError, meaning not on a tokio_executor::threadpool :(
309             Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
310         }
311     }
312 }
313 */
314 
315 mod sealed {
316     use super::{SocketAddr, Name};
317     use crate::common::{task, Future, Poll};
318     use tower_service::Service;
319 
320     // "Trait alias" for `Service<Name, Response = Addrs>`
321     pub trait Resolve {
322         type Addrs: Iterator<Item = SocketAddr>;
323         type Error: Into<Box<dyn std::error::Error + Send + Sync>>;
324         type Future: Future<Output = Result<Self::Addrs, Self::Error>>;
325 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>326         fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;
resolve(&mut self, name: Name) -> Self::Future327         fn resolve(&mut self, name: Name) -> Self::Future;
328     }
329 
330     impl<S> Resolve for S
331     where
332         S: Service<Name>,
333         S::Response: Iterator<Item = SocketAddr>,
334         S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
335     {
336         type Addrs = S::Response;
337         type Error = S::Error;
338         type Future = S::Future;
339 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>340         fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
341             Service::poll_ready(self, cx)
342         }
343 
resolve(&mut self, name: Name) -> Self::Future344         fn resolve(&mut self, name: Name) -> Self::Future {
345             Service::call(self, name)
346         }
347     }
348 }
349 
resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> where R: Resolve,350 pub(super) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error>
351 where
352     R: Resolve,
353 {
354     futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?;
355     resolver.resolve(name).await
356 }
357 
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Builds a `SocketAddrs` that yields `list` in order.
    fn addrs_of(list: &[SocketAddr]) -> SocketAddrs {
        SocketAddrs {
            iter: list.to_vec().into_iter(),
        }
    }

    #[test]
    fn test_ip_addrs_split_by_preference() {
        let ip_v4 = Ipv4Addr::new(127, 0, 0, 1);
        let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
        let v4_addr = SocketAddr::from((ip_v4, 80));
        let v6_addr = SocketAddr::from((ip_v6, 80));

        // With no local address, or with both families given, the family of
        // the first resolved address is preferred.
        for &(local_v4, local_v6) in &[(None, None), (Some(ip_v4), Some(ip_v6))] {
            let (mut preferred, mut fallback) =
                addrs_of(&[v4_addr, v6_addr]).split_by_preference(local_v4, local_v6);
            assert!(preferred.next().unwrap().is_ipv4());
            assert!(fallback.next().unwrap().is_ipv6());

            let (mut preferred, mut fallback) =
                addrs_of(&[v6_addr, v4_addr]).split_by_preference(local_v4, local_v6);
            assert!(preferred.next().unwrap().is_ipv6());
            assert!(fallback.next().unwrap().is_ipv4());
        }

        // A single local address restricts the result to that family.
        let (mut preferred, fallback) =
            addrs_of(&[v4_addr, v6_addr]).split_by_preference(Some(ip_v4), None);
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.is_empty());

        let (mut preferred, fallback) =
            addrs_of(&[v4_addr, v6_addr]).split_by_preference(None, Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.is_empty());
    }

    #[test]
    fn test_name_from_str() {
        const DOMAIN: &str = "test.example.com";
        let name: Name = DOMAIN.parse().expect("Should be a valid domain");
        assert_eq!(name.as_str(), DOMAIN);
        assert_eq!(name.to_string(), DOMAIN);
    }

    #[test]
    fn ip_addrs_try_parse_v6() {
        let dst = ::http::Uri::from_static("http://[::1]:8080/");
        let host = dst.host().expect("host");
        let port = dst.port_u16().expect("port");

        let mut addrs = SocketAddrs::try_parse(host, port).expect("try_parse");
        let expected: SocketAddr = "[::1]:8080".parse().expect("expected");

        assert_eq!(addrs.next(), Some(expected));
    }
}
434