1 //! DNS Resolution used by the `HttpConnector`.
2 //!
3 //! This module contains:
4 //!
5 //! - A [`GaiResolver`](GaiResolver) that is the default resolver for the
6 //! `HttpConnector`.
7 //! - The `Name` type used as an argument to custom resolvers.
8 //!
9 //! # Resolvers are `Service`s
10 //!
//! A resolver is just a
//! `Service<Name, Response = impl Iterator<Item = IpAddr>>` whose error type
//! can be converted into `Box<dyn std::error::Error + Send + Sync>`.
13 //!
14 //! A simple resolver that ignores the name and always returns a specific
15 //! address:
16 //!
17 //! ```rust,ignore
18 //! use std::{convert::Infallible, iter, net::IpAddr};
19 //!
20 //! let resolver = tower::service_fn(|_name| async {
21 //! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1])))
22 //! });
23 //! ```
24 use std::error::Error;
25 use std::future::Future;
26 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
27 use std::pin::Pin;
28 use std::str::FromStr;
29 use std::task::{self, Poll};
30 use std::{fmt, io, vec};
31
32 use tokio::task::JoinHandle;
33 use tower_service::Service;
34
35 pub(super) use self::sealed::Resolve;
36
/// A domain name to resolve into IP addresses.
///
/// This is a thin wrapper around the raw host string handed to a resolver.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Name {
    // The unvalidated hostname, exactly as supplied by the caller.
    host: String,
}
42
/// A resolver that performs blocking `getaddrinfo` lookups on a threadpool.
#[derive(Clone)]
pub struct GaiResolver {
    // Private placeholder so the type can only be constructed inside this module.
    _priv: (),
}
48
49 /// An iterator of IP addresses returned from `getaddrinfo`.
50 pub struct GaiAddrs {
51 inner: IpAddrs,
52 }
53
54 /// A future to resolve a name returned by `GaiResolver`.
55 pub struct GaiFuture {
56 inner: JoinHandle<Result<IpAddrs, io::Error>>,
57 }
58
59 impl Name {
new(host: String) -> Name60 pub(super) fn new(host: String) -> Name {
61 Name { host }
62 }
63
64 /// View the hostname as a string slice.
as_str(&self) -> &str65 pub fn as_str(&self) -> &str {
66 &self.host
67 }
68 }
69
70 impl fmt::Debug for Name {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 fmt::Debug::fmt(&self.host, f)
73 }
74 }
75
76 impl fmt::Display for Name {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 fmt::Display::fmt(&self.host, f)
79 }
80 }
81
82 impl FromStr for Name {
83 type Err = InvalidNameError;
84
from_str(host: &str) -> Result<Self, Self::Err>85 fn from_str(host: &str) -> Result<Self, Self::Err> {
86 // Possibly add validation later
87 Ok(Name::new(host.to_owned()))
88 }
89 }
90
/// Error returned when a string cannot be used as a domain name.
#[derive(Debug)]
pub struct InvalidNameError(());

impl fmt::Display for InvalidNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Not a valid domain name")
    }
}

// Uses the default `source()` (none): this error never wraps another one.
impl Error for InvalidNameError {}
102
103 impl GaiResolver {
104 /// Construct a new `GaiResolver`.
new() -> Self105 pub fn new() -> Self {
106 GaiResolver { _priv: () }
107 }
108 }
109
110 impl Service<Name> for GaiResolver {
111 type Response = GaiAddrs;
112 type Error = io::Error;
113 type Future = GaiFuture;
114
poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>>115 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
116 Poll::Ready(Ok(()))
117 }
118
call(&mut self, name: Name) -> Self::Future119 fn call(&mut self, name: Name) -> Self::Future {
120 let blocking = tokio::task::spawn_blocking(move || {
121 debug!("resolving host={:?}", name.host);
122 (&*name.host, 0)
123 .to_socket_addrs()
124 .map(|i| IpAddrs { iter: i })
125 });
126
127 GaiFuture { inner: blocking }
128 }
129 }
130
131 impl fmt::Debug for GaiResolver {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 f.pad("GaiResolver")
134 }
135 }
136
137 impl Future for GaiFuture {
138 type Output = Result<GaiAddrs, io::Error>;
139
poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>140 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
141 Pin::new(&mut self.inner).poll(cx).map(|res| match res {
142 Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }),
143 Ok(Err(err)) => Err(err),
144 Err(join_err) => panic!("gai background task failed: {:?}", join_err),
145 })
146 }
147 }
148
149 impl fmt::Debug for GaiFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 f.pad("GaiFuture")
152 }
153 }
154
155 impl Iterator for GaiAddrs {
156 type Item = IpAddr;
157
next(&mut self) -> Option<Self::Item>158 fn next(&mut self) -> Option<Self::Item> {
159 self.inner.next().map(|sa| sa.ip())
160 }
161 }
162
163 impl fmt::Debug for GaiAddrs {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.pad("GaiAddrs")
166 }
167 }
168
169 pub(super) struct IpAddrs {
170 iter: vec::IntoIter<SocketAddr>,
171 }
172
173 impl IpAddrs {
new(addrs: Vec<SocketAddr>) -> Self174 pub(super) fn new(addrs: Vec<SocketAddr>) -> Self {
175 IpAddrs {
176 iter: addrs.into_iter(),
177 }
178 }
179
try_parse(host: &str, port: u16) -> Option<IpAddrs>180 pub(super) fn try_parse(host: &str, port: u16) -> Option<IpAddrs> {
181 if let Ok(addr) = host.parse::<Ipv4Addr>() {
182 let addr = SocketAddrV4::new(addr, port);
183 return Some(IpAddrs {
184 iter: vec![SocketAddr::V4(addr)].into_iter(),
185 });
186 }
187 let host = host.trim_start_matches('[').trim_end_matches(']');
188 if let Ok(addr) = host.parse::<Ipv6Addr>() {
189 let addr = SocketAddrV6::new(addr, port, 0, 0);
190 return Some(IpAddrs {
191 iter: vec![SocketAddr::V6(addr)].into_iter(),
192 });
193 }
194 None
195 }
196
split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs)197 pub(super) fn split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs) {
198 if let Some(local_addr) = local_addr {
199 let preferred = self
200 .iter
201 .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6())
202 .collect();
203
204 (IpAddrs::new(preferred), IpAddrs::new(vec![]))
205 } else {
206 let preferring_v6 = self
207 .iter
208 .as_slice()
209 .first()
210 .map(SocketAddr::is_ipv6)
211 .unwrap_or(false);
212
213 let (preferred, fallback) = self
214 .iter
215 .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);
216
217 (IpAddrs::new(preferred), IpAddrs::new(fallback))
218 }
219 }
220
is_empty(&self) -> bool221 pub(super) fn is_empty(&self) -> bool {
222 self.iter.as_slice().is_empty()
223 }
224
len(&self) -> usize225 pub(super) fn len(&self) -> usize {
226 self.iter.as_slice().len()
227 }
228 }
229
230 impl Iterator for IpAddrs {
231 type Item = SocketAddr;
232 #[inline]
next(&mut self) -> Option<SocketAddr>233 fn next(&mut self) -> Option<SocketAddr> {
234 self.iter.next()
235 }
236 }
237
238 /*
239 /// A resolver using `getaddrinfo` calls via the `tokio_executor::threadpool::blocking` API.
240 ///
241 /// Unlike the `GaiResolver` this will not spawn dedicated threads, but only works when running on the
242 /// multi-threaded Tokio runtime.
243 #[cfg(feature = "runtime")]
244 #[derive(Clone, Debug)]
245 pub struct TokioThreadpoolGaiResolver(());
246
247 /// The future returned by `TokioThreadpoolGaiResolver`.
248 #[cfg(feature = "runtime")]
249 #[derive(Debug)]
250 pub struct TokioThreadpoolGaiFuture {
251 name: Name,
252 }
253
254 #[cfg(feature = "runtime")]
255 impl TokioThreadpoolGaiResolver {
256 /// Creates a new DNS resolver that will use tokio threadpool's blocking
257 /// feature.
258 ///
259 /// **Requires** its futures to be run on the threadpool runtime.
260 pub fn new() -> Self {
261 TokioThreadpoolGaiResolver(())
262 }
263 }
264
265 #[cfg(feature = "runtime")]
266 impl Service<Name> for TokioThreadpoolGaiResolver {
267 type Response = GaiAddrs;
268 type Error = io::Error;
269 type Future = TokioThreadpoolGaiFuture;
270
271 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
272 Poll::Ready(Ok(()))
273 }
274
275 fn call(&mut self, name: Name) -> Self::Future {
276 TokioThreadpoolGaiFuture { name }
277 }
278 }
279
280 #[cfg(feature = "runtime")]
281 impl Future for TokioThreadpoolGaiFuture {
282 type Output = Result<GaiAddrs, io::Error>;
283
284 fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Self::Output> {
285 match ready!(tokio_executor::threadpool::blocking(|| (
286 self.name.as_str(),
287 0
288 )
289 .to_socket_addrs()))
290 {
291 Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs {
292 inner: IpAddrs { iter },
293 })),
294 Ok(Err(e)) => Poll::Ready(Err(e)),
295 // a BlockingError, meaning not on a tokio_executor::threadpool :(
296 Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
297 }
298 }
299 }
300 */
301
302 mod sealed {
303 use super::{IpAddr, Name};
304 use crate::common::{task, Future, Poll};
305 use tower_service::Service;
306
307 // "Trait alias" for `Service<Name, Response = Addrs>`
308 pub trait Resolve {
309 type Addrs: Iterator<Item = IpAddr>;
310 type Error: Into<Box<dyn std::error::Error + Send + Sync>>;
311 type Future: Future<Output = Result<Self::Addrs, Self::Error>>;
312
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>313 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;
resolve(&mut self, name: Name) -> Self::Future314 fn resolve(&mut self, name: Name) -> Self::Future;
315 }
316
317 impl<S> Resolve for S
318 where
319 S: Service<Name>,
320 S::Response: Iterator<Item = IpAddr>,
321 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
322 {
323 type Addrs = S::Response;
324 type Error = S::Error;
325 type Future = S::Future;
326
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>327 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
328 Service::poll_ready(self, cx)
329 }
330
resolve(&mut self, name: Name) -> Self::Future331 fn resolve(&mut self, name: Name) -> Self::Future {
332 Service::call(self, name)
333 }
334 }
335 }
336
resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> where R: Resolve,337 pub(crate) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error>
338 where
339 R: Resolve,
340 {
341 futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?;
342 resolver.resolve(name).await
343 }
344
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Wrap `list` in an `IpAddrs`, preserving order.
    fn ip_addrs(list: Vec<SocketAddr>) -> IpAddrs {
        IpAddrs {
            iter: list.into_iter(),
        }
    }

    #[test]
    fn test_ip_addrs_split_by_preference() {
        let v4_addr = SocketAddr::from((Ipv4Addr::new(127, 0, 0, 1), 80));
        let v6_addr = SocketAddr::from((Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80));

        // Without a local address, the family of the first entry is preferred.
        {
            let (mut preferred, mut fallback) =
                ip_addrs(vec![v4_addr, v6_addr]).split_by_preference(None);
            assert!(preferred.next().unwrap().is_ipv4());
            assert!(fallback.next().unwrap().is_ipv6());
        }
        {
            let (mut preferred, mut fallback) =
                ip_addrs(vec![v6_addr, v4_addr]).split_by_preference(None);
            assert!(preferred.next().unwrap().is_ipv6());
            assert!(fallback.next().unwrap().is_ipv4());
        }

        // A local address restricts candidates to its own family, with no fallback.
        for local in &[v4_addr, v6_addr] {
            let (mut preferred, fallback) =
                ip_addrs(vec![v4_addr, v6_addr]).split_by_preference(Some(local.ip()));
            assert_eq!(preferred.next().unwrap().is_ipv6(), local.is_ipv6());
            assert!(fallback.is_empty());
        }
    }

    #[test]
    fn test_name_from_str() {
        const DOMAIN: &str = "test.example.com";
        let name: Name = DOMAIN.parse().expect("Should be a valid domain");
        assert_eq!(name.as_str(), DOMAIN);
        assert_eq!(name.to_string(), DOMAIN);
    }

    #[test]
    fn ip_addrs_try_parse_v6() {
        let dst = ::http::Uri::from_static("http://[::1]:8080/");
        let host = dst.host().expect("host");
        let port = dst.port_u16().expect("port");

        let mut addrs = IpAddrs::try_parse(host, port).expect("try_parse");
        let expected: SocketAddr = "[::1]:8080".parse().expect("expected");

        assert_eq!(addrs.next(), Some(expected));
    }
}
405