1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11 
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project::pin_project;
15 use tokio::net::TcpStream;
16 use tokio::time::Delay;
17 
18 use super::dns::{self, resolve, GaiResolver, Resolve};
19 use super::{Connected, Connection};
20 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
21 
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Connector settings, shared between clones. Mutation goes through
    // `Arc::make_mut`, so changing one connector never affects its clones.
    config: Arc<Config>,
    // Resolver used to turn the destination host name into IP addresses.
    resolver: R,
}
35 
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address of the connected `TcpStream`, as reported by `peer_addr()`.
    remote_addr: SocketAddr,
}
68 
// Settings applied to every connection made by an `HttpConnector`.
#[derive(Clone)]
struct Config {
    // Overall connect timeout; divided by the number of candidate addresses
    // to get the per-attempt timeout. `None` means no timeout.
    connect_timeout: Option<Duration>,
    // When `true`, destinations whose scheme is not `http` are rejected.
    enforce_http: bool,
    // Delay before the fallback address family is tried in parallel.
    // `None` disables the fallback entirely.
    happy_eyeballs_timeout: Option<Duration>,
    // Keep-alive duration passed to `set_keepalive` on connected sockets.
    keep_alive_timeout: Option<Duration>,
    // Local address to bind before connecting to an IPv4 destination.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local address to bind before connecting to an IPv6 destination.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `set_nodelay` on every connected socket.
    nodelay: bool,
    // When `true`, `set_reuse_address(true)` is applied before binding.
    reuse_address: bool,
    // Send buffer size applied to connected sockets, if set.
    send_buffer_size: Option<usize>,
    // Receive buffer size applied to connected sockets, if set.
    recv_buffer_size: Option<usize>,
}
82 
83 // ===== impl HttpConnector =====
84 
85 impl HttpConnector {
86     /// Construct a new HttpConnector.
new() -> HttpConnector87     pub fn new() -> HttpConnector {
88         HttpConnector::new_with_resolver(GaiResolver::new())
89     }
90 }
91 
92 /*
93 #[cfg(feature = "runtime")]
94 impl HttpConnector<TokioThreadpoolGaiResolver> {
95     /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
96     ///
97     /// This resolver **requires** the threadpool runtime to be used.
98     pub fn new_with_tokio_threadpool_resolver() -> Self {
99         HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
100     }
101 }
102 */
103 
104 impl<R> HttpConnector<R> {
105     /// Construct a new HttpConnector.
106     ///
107     /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>108     pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
109         HttpConnector {
110             config: Arc::new(Config {
111                 connect_timeout: None,
112                 enforce_http: true,
113                 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
114                 keep_alive_timeout: None,
115                 local_address_ipv4: None,
116                 local_address_ipv6: None,
117                 nodelay: false,
118                 reuse_address: false,
119                 send_buffer_size: None,
120                 recv_buffer_size: None,
121             }),
122             resolver,
123         }
124     }
125 
126     /// Option to enforce all `Uri`s have the `http` scheme.
127     ///
128     /// Enabled by default.
129     #[inline]
enforce_http(&mut self, is_enforced: bool)130     pub fn enforce_http(&mut self, is_enforced: bool) {
131         self.config_mut().enforce_http = is_enforced;
132     }
133 
134     /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
135     ///
136     /// If `None`, the option will not be set.
137     ///
138     /// Default is `None`.
139     #[inline]
set_keepalive(&mut self, dur: Option<Duration>)140     pub fn set_keepalive(&mut self, dur: Option<Duration>) {
141         self.config_mut().keep_alive_timeout = dur;
142     }
143 
144     /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
145     ///
146     /// Default is `false`.
147     #[inline]
set_nodelay(&mut self, nodelay: bool)148     pub fn set_nodelay(&mut self, nodelay: bool) {
149         self.config_mut().nodelay = nodelay;
150     }
151 
152     /// Sets the value of the SO_SNDBUF option on the socket.
153     #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)154     pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
155         self.config_mut().send_buffer_size = size;
156     }
157 
158     /// Sets the value of the SO_RCVBUF option on the socket.
159     #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)160     pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
161         self.config_mut().recv_buffer_size = size;
162     }
163 
164     /// Set that all sockets are bound to the configured address before connection.
165     ///
166     /// If `None`, the sockets will not be bound.
167     ///
168     /// Default is `None`.
169     #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)170     pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
171         let (v4, v6) = match addr {
172             Some(IpAddr::V4(a)) => (Some(a), None),
173             Some(IpAddr::V6(a)) => (None, Some(a)),
174             _ => (None, None),
175         };
176 
177         let cfg = self.config_mut();
178 
179         cfg.local_address_ipv4 = v4;
180         cfg.local_address_ipv6 = v6;
181     }
182 
183     /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
184     /// preferences) before connection.
185     #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)186     pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
187         let cfg = self.config_mut();
188 
189         cfg.local_address_ipv4 = Some(addr_ipv4);
190         cfg.local_address_ipv6 = Some(addr_ipv6);
191     }
192 
193     /// Set the connect timeout.
194     ///
195     /// If a domain resolves to multiple IP addresses, the timeout will be
196     /// evenly divided across them.
197     ///
198     /// Default is `None`.
199     #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)200     pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
201         self.config_mut().connect_timeout = dur;
202     }
203 
204     /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
205     ///
206     /// If hostname resolves to both IPv4 and IPv6 addresses and connection
207     /// cannot be established using preferred address family before timeout
208     /// elapses, then connector will in parallel attempt connection using other
209     /// address family.
210     ///
211     /// If `None`, parallel connection attempts are disabled.
212     ///
213     /// Default is 300 milliseconds.
214     ///
215     /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
216     #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)217     pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
218         self.config_mut().happy_eyeballs_timeout = dur;
219     }
220 
221     /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
222     ///
223     /// Default is `false`.
224     #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self225     pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
226         self.config_mut().reuse_address = reuse_address;
227         self
228     }
229 
230     // private
231 
config_mut(&mut self) -> &mut Config232     fn config_mut(&mut self) -> &mut Config {
233         // If the are HttpConnector clones, this will clone the inner
234         // config. So mutating the config won't ever affect previous
235         // clones.
236         Arc::make_mut(&mut self.config)
237     }
238 }
239 
// Messages for the `ConnectError`s returned when the destination `Uri` is
// rejected before any DNS or TCP work starts.

// Returned when `enforce_http` is enabled and the scheme is not `http`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
// Returned when `enforce_http` is disabled and the `Uri` has no scheme.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
// Returned when the `Uri` has no host component.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
243 
244 // R: Debug required for now to allow adding it to debug output later...
245 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result246     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247         f.debug_struct("HttpConnector").finish()
248     }
249 }
250 
251 impl<R> tower_service::Service<Uri> for HttpConnector<R>
252 where
253     R: Resolve + Clone + Send + Sync + 'static,
254     R::Future: Send,
255 {
256     type Response = TcpStream;
257     type Error = ConnectError;
258     type Future = HttpConnecting<R>;
259 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>260     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
261         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
262         Poll::Ready(Ok(()))
263     }
264 
call(&mut self, dst: Uri) -> Self::Future265     fn call(&mut self, dst: Uri) -> Self::Future {
266         let mut self_ = self.clone();
267         HttpConnecting {
268             fut: Box::pin(async move { self_.call_async(dst).await }),
269             _marker: PhantomData,
270         }
271     }
272 }
273 
274 impl<R> HttpConnector<R>
275 where
276     R: Resolve,
277 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>278     async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
279         trace!(
280             "Http::connect; scheme={:?}, host={:?}, port={:?}",
281             dst.scheme(),
282             dst.host(),
283             dst.port(),
284         );
285 
286         if self.config.enforce_http {
287             if dst.scheme() != Some(&Scheme::HTTP) {
288                 return Err(ConnectError {
289                     msg: INVALID_NOT_HTTP.into(),
290                     cause: None,
291                 });
292             }
293         } else if dst.scheme().is_none() {
294             return Err(ConnectError {
295                 msg: INVALID_MISSING_SCHEME.into(),
296                 cause: None,
297             });
298         }
299 
300         let host = match dst.host() {
301             Some(s) => s,
302             None => {
303                 return Err(ConnectError {
304                     msg: INVALID_MISSING_HOST.into(),
305                     cause: None,
306                 })
307             }
308         };
309         let port = match dst.port() {
310             Some(port) => port.as_u16(),
311             None => {
312                 if dst.scheme() == Some(&Scheme::HTTPS) {
313                     443
314                 } else {
315                     80
316                 }
317             }
318         };
319 
320         let config = &self.config;
321 
322         // If the host is already an IP addr (v4 or v6),
323         // skip resolving the dns and start connecting right away.
324         let addrs = if let Some(addrs) = dns::IpAddrs::try_parse(host, port) {
325             addrs
326         } else {
327             let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
328                 .await
329                 .map_err(ConnectError::dns)?;
330             let addrs = addrs.map(|addr| SocketAddr::new(addr, port)).collect();
331             dns::IpAddrs::new(addrs)
332         };
333 
334         let c = ConnectingTcp::new(
335             config.local_address_ipv4,
336             config.local_address_ipv6,
337             addrs,
338             config.connect_timeout,
339             config.happy_eyeballs_timeout,
340             config.reuse_address,
341         );
342 
343         let sock = c
344             .connect()
345             .await
346             .map_err(ConnectError::m("tcp connect error"))?;
347 
348         if let Some(dur) = config.keep_alive_timeout {
349             sock.set_keepalive(Some(dur))
350                 .map_err(ConnectError::m("tcp set_keepalive error"))?;
351         }
352 
353         if let Some(size) = config.send_buffer_size {
354             sock.set_send_buffer_size(size)
355                 .map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
356         }
357 
358         if let Some(size) = config.recv_buffer_size {
359             sock.set_recv_buffer_size(size)
360                 .map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
361         }
362 
363         sock.set_nodelay(config.nodelay)
364             .map_err(ConnectError::m("tcp set_nodelay error"))?;
365 
366         Ok(sock)
367     }
368 }
369 
370 impl Connection for TcpStream {
connected(&self) -> Connected371     fn connected(&self) -> Connected {
372         let connected = Connected::new();
373         if let Ok(remote_addr) = self.peer_addr() {
374             connected.extra(HttpInfo { remote_addr })
375         } else {
376             connected
377         }
378     }
379 }
380 
381 impl HttpInfo {
382     /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr383     pub fn remote_addr(&self) -> SocketAddr {
384         self.remote_addr
385     }
386 }
387 
// Not publicly exported (so missing_docs doesn't trigger).
//
// We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
// so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
// (and thus we can change the type in the future).
#[must_use = "futures do nothing unless polled"]
#[pin_project]
#[allow(missing_debug_implementations)]
pub struct HttpConnecting<R> {
    // The boxed connect future that actually does the work.
    #[pin]
    fut: BoxConnecting,
    // Keeps the resolver type `R` in the signature without storing a value.
    _marker: PhantomData<R>,
}

// Outcome of a connect attempt: the connected stream or a `ConnectError`.
type ConnectResult = Result<TcpStream, ConnectError>;
// Type-erased, `Send` future producing a `ConnectResult`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
404 
405 impl<R: Resolve> Future for HttpConnecting<R> {
406     type Output = ConnectResult;
407 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>408     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
409         self.project().fut.poll(cx)
410     }
411 }
412 
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error returned while establishing a connection: a short description of
// what failed plus, optionally, the underlying error that caused it.
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}

impl ConnectError {
    // Builds an error from a description and the underlying cause.
    fn new<S, E>(msg: S, cause: E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let msg = msg.into();
        let cause = Some(cause.into());
        ConnectError { msg, cause }
    }

    // Shorthand for an error whose description is "dns error".
    fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        Self::new("dns error", cause)
    }

    // Returns a closure suitable for `map_err` that wraps the incoming error
    // with the description `msg`.
    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        move |cause| Self::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => f
                .debug_tuple("ConnectError")
                .field(&self.msg)
                .field(cause)
                .finish(),
            // Without a cause only the (quoted) description is printed.
            None => fmt::Debug::fmt(&self.msg, f),
        }
    }
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.msg)?;

        // Append the cause, if any, as "<msg>: <cause>".
        match self.cause {
            Some(ref cause) => write!(f, ": {}", cause),
            None => Ok(()),
        }
    }
}

impl StdError for ConnectError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self.cause {
            Some(ref boxed) => {
                // Drop the `Send + Sync` bounds to match the trait signature.
                let inner: &(dyn StdError + 'static) = &**boxed;
                Some(inner)
            }
            None => None,
        }
    }
}
477 
478 struct ConnectingTcp {
479     local_addr_ipv4: Option<Ipv4Addr>,
480     local_addr_ipv6: Option<Ipv6Addr>,
481     preferred: ConnectingTcpRemote,
482     fallback: Option<ConnectingTcpFallback>,
483     reuse_address: bool,
484 }
485 
486 impl ConnectingTcp {
new( local_addr_ipv4: Option<Ipv4Addr>, local_addr_ipv6: Option<Ipv6Addr>, remote_addrs: dns::IpAddrs, connect_timeout: Option<Duration>, fallback_timeout: Option<Duration>, reuse_address: bool, ) -> ConnectingTcp487     fn new(
488         local_addr_ipv4: Option<Ipv4Addr>,
489         local_addr_ipv6: Option<Ipv6Addr>,
490         remote_addrs: dns::IpAddrs,
491         connect_timeout: Option<Duration>,
492         fallback_timeout: Option<Duration>,
493         reuse_address: bool,
494     ) -> ConnectingTcp {
495         if let Some(fallback_timeout) = fallback_timeout {
496             let (preferred_addrs, fallback_addrs) =
497                 remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6);
498             if fallback_addrs.is_empty() {
499                 return ConnectingTcp {
500                     local_addr_ipv4,
501                     local_addr_ipv6,
502                     preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
503                     fallback: None,
504                     reuse_address,
505                 };
506             }
507 
508             ConnectingTcp {
509                 local_addr_ipv4,
510                 local_addr_ipv6,
511                 preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
512                 fallback: Some(ConnectingTcpFallback {
513                     delay: tokio::time::delay_for(fallback_timeout),
514                     remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout),
515                 }),
516                 reuse_address,
517             }
518         } else {
519             ConnectingTcp {
520                 local_addr_ipv4,
521                 local_addr_ipv6,
522                 preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout),
523                 fallback: None,
524                 reuse_address,
525             }
526         }
527     }
528 }
529 
// Fallback half of a connect operation: a list of addresses that is only
// attempted once `delay` has elapsed (or the preferred list has failed).
struct ConnectingTcpFallback {
    // Timer created from the happy-eyeballs timeout.
    delay: Delay,
    // Addresses to try as the fallback.
    remote: ConnectingTcpRemote,
}
534 
535 struct ConnectingTcpRemote {
536     addrs: dns::IpAddrs,
537     connect_timeout: Option<Duration>,
538 }
539 
540 impl ConnectingTcpRemote {
new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self541     fn new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self {
542         let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
543 
544         Self {
545             addrs,
546             connect_timeout,
547         }
548     }
549 }
550 
551 impl ConnectingTcpRemote {
connect( &mut self, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, reuse_address: bool, ) -> io::Result<TcpStream>552     async fn connect(
553         &mut self,
554         local_addr_ipv4: &Option<Ipv4Addr>,
555         local_addr_ipv6: &Option<Ipv6Addr>,
556         reuse_address: bool,
557     ) -> io::Result<TcpStream> {
558         let mut err = None;
559         for addr in &mut self.addrs {
560             debug!("connecting to {}", addr);
561             match connect(
562                 &addr,
563                 local_addr_ipv4,
564                 local_addr_ipv6,
565                 reuse_address,
566                 self.connect_timeout,
567             )?
568             .await
569             {
570                 Ok(tcp) => {
571                     debug!("connected to {}", addr);
572                     return Ok(tcp);
573                 }
574                 Err(e) => {
575                     trace!("connect error for {}: {:?}", addr, e);
576                     err = Some(e);
577                 }
578             }
579         }
580 
581         match err {
582             Some(e) => Err(e),
583             None => Err(std::io::Error::new(
584                 std::io::ErrorKind::NotConnected,
585                 "Network unreachable",
586             )),
587         }
588     }
589 }
590 
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>591 fn bind_local_address(
592     socket: &socket2::Socket,
593     dst_addr: &SocketAddr,
594     local_addr_ipv4: &Option<Ipv4Addr>,
595     local_addr_ipv6: &Option<Ipv6Addr>,
596 ) -> io::Result<()> {
597     match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
598         (SocketAddr::V4(_), Some(addr), _) => {
599             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
600         }
601         (SocketAddr::V6(_), _, Some(addr)) => {
602             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
603         }
604         _ => {
605             if cfg!(windows) {
606                 // Windows requires a socket be bound before calling connect
607                 let any: SocketAddr = match *dst_addr {
608                     SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
609                     SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
610                 };
611                 socket.bind(&any.into())?;
612             }
613         }
614     }
615 
616     Ok(())
617 }
618 
connect( addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, reuse_address: bool, connect_timeout: Option<Duration>, ) -> io::Result<impl Future<Output = io::Result<TcpStream>>>619 fn connect(
620     addr: &SocketAddr,
621     local_addr_ipv4: &Option<Ipv4Addr>,
622     local_addr_ipv6: &Option<Ipv6Addr>,
623     reuse_address: bool,
624     connect_timeout: Option<Duration>,
625 ) -> io::Result<impl Future<Output = io::Result<TcpStream>>> {
626     use socket2::{Domain, Protocol, Socket, Type};
627     let domain = match *addr {
628         SocketAddr::V4(_) => Domain::ipv4(),
629         SocketAddr::V6(_) => Domain::ipv6(),
630     };
631     let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
632 
633     if reuse_address {
634         socket.set_reuse_address(true)?;
635     }
636 
637     bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?;
638 
639     let addr = *addr;
640 
641     let std_tcp = socket.into_tcp_stream();
642 
643     Ok(async move {
644         let connect = TcpStream::connect_std(std_tcp, &addr);
645         match connect_timeout {
646             Some(dur) => match tokio::time::timeout(dur, connect).await {
647                 Ok(Ok(s)) => Ok(s),
648                 Ok(Err(e)) => Err(e),
649                 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
650             },
651             None => connect.await,
652         }
653     })
654 }
655 
impl ConnectingTcp {
    // Connects to the preferred addresses and, when a fallback is configured,
    // races them against the fallback addresses once the fallback delay has
    // elapsed. Returns the first successful stream, or the error of whichever
    // side finished last when both fail.
    async fn connect(mut self) -> io::Result<TcpStream> {
        // Borrow the bind settings out of `self` so they can be shared by the
        // preferred and the fallback attempts below.
        let Self {
            ref local_addr_ipv4,
            ref local_addr_ipv6,
            reuse_address,
            ..
        } = self;
        match self.fallback {
            // No fallback: just walk the preferred address list.
            None => {
                self.preferred
                    .connect(local_addr_ipv4, local_addr_ipv6, reuse_address)
                    .await
            }
            Some(mut fallback) => {
                let preferred_fut =
                    self.preferred
                        .connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
                futures_util::pin_mut!(preferred_fut);

                // Created up front but not polled until it is selected on or
                // awaited below.
                let fallback_fut =
                    fallback
                        .remote
                        .connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
                futures_util::pin_mut!(fallback_fut);

                // First phase: preferred attempt versus the fallback delay.
                // `result` is the outcome of whichever connect finished first;
                // `future` is the connect attempt that is still outstanding.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback.delay).await {
                        // Preferred finished before the delay fired; the
                        // fallback has not been started yet.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
706 
707 #[cfg(test)]
708 mod tests {
709     use std::io;
710 
711     use ::http::Uri;
712 
713     use super::super::sealed::{Connect, ConnectSvc};
714     use super::HttpConnector;
715 
    // Test helper: drives `connector` through the sealed `Connect` trait to
    // connect to `dst`, returning the connector's own connection or error type.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }
725 
726     #[tokio::test]
test_errors_enforce_http()727     async fn test_errors_enforce_http() {
728         let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
729         let connector = HttpConnector::new();
730 
731         let err = connect(connector, dst).await.unwrap_err();
732         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
733     }
734 
735     #[cfg(any(target_os = "linux", target_os = "macos"))]
get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>)736     fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
737         use std::net::{IpAddr, TcpListener};
738 
739         let mut ip_v4 = None;
740         let mut ip_v6 = None;
741 
742         let ips = pnet::datalink::interfaces()
743             .into_iter()
744             .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
745 
746         for ip in ips {
747             match ip {
748                 IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
749                 IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
750                 _ => (),
751             }
752 
753             if ip_v4.is_some() && ip_v6.is_some() {
754                 break;
755             }
756         }
757 
758         (ip_v4, ip_v6)
759     }
760 
761     #[tokio::test]
test_errors_missing_scheme()762     async fn test_errors_missing_scheme() {
763         let dst = "example.domain".parse().unwrap();
764         let mut connector = HttpConnector::new();
765         connector.enforce_http(false);
766 
767         let err = connect(connector, dst).await.unwrap_err();
768         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
769     }
770 
    // Checks that a connector configured with local addresses connects *from*
    // those addresses: a loopback listener accepts the connection and the
    // peer IP it sees must equal the configured local IP.
    //
    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        // Listen on the same port for both families so one port serves both
        // destination URLs below.
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        // Connects to `dst` with the discovered local addresses configured and
        // asserts that `server` sees the client coming from `expected_ip`.
        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                // The helper is only invoked below when at least one address
                // family was found.
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }
807 
808     #[test]
809     #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
client_happy_eyeballs()810     fn client_happy_eyeballs() {
811         use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
812         use std::time::{Duration, Instant};
813 
814         use super::dns;
815         use super::ConnectingTcp;
816 
817         let _ = pretty_env_logger::try_init();
818         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
819         let addr = server4.local_addr().unwrap();
820         let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
821         let mut rt = tokio::runtime::Builder::new()
822             .enable_io()
823             .enable_time()
824             .basic_scheduler()
825             .build()
826             .unwrap();
827 
828         let local_timeout = Duration::default();
829         let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
830         let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
831         let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
832             + Duration::from_millis(250);
833 
834         let scenarios = &[
835             // Fast primary, without fallback.
836             (&[local_ipv4_addr()][..], 4, local_timeout, false),
837             (&[local_ipv6_addr()][..], 6, local_timeout, false),
838             // Fast primary, with (unused) fallback.
839             (
840                 &[local_ipv4_addr(), local_ipv6_addr()][..],
841                 4,
842                 local_timeout,
843                 false,
844             ),
845             (
846                 &[local_ipv6_addr(), local_ipv4_addr()][..],
847                 6,
848                 local_timeout,
849                 false,
850             ),
851             // Unreachable + fast primary, without fallback.
852             (
853                 &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
854                 4,
855                 unreachable_v4_timeout,
856                 false,
857             ),
858             (
859                 &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
860                 6,
861                 unreachable_v6_timeout,
862                 false,
863             ),
864             // Unreachable + fast primary, with (unused) fallback.
865             (
866                 &[
867                     unreachable_ipv4_addr(),
868                     local_ipv4_addr(),
869                     local_ipv6_addr(),
870                 ][..],
871                 4,
872                 unreachable_v4_timeout,
873                 false,
874             ),
875             (
876                 &[
877                     unreachable_ipv6_addr(),
878                     local_ipv6_addr(),
879                     local_ipv4_addr(),
880                 ][..],
881                 6,
882                 unreachable_v6_timeout,
883                 true,
884             ),
885             // Slow primary, with (used) fallback.
886             (
887                 &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
888                 6,
889                 fallback_timeout,
890                 false,
891             ),
892             (
893                 &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
894                 4,
895                 fallback_timeout,
896                 true,
897             ),
898             // Slow primary, with (used) unreachable + fast fallback.
899             (
900                 &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
901                 6,
902                 fallback_timeout + unreachable_v6_timeout,
903                 false,
904             ),
905             (
906                 &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
907                 4,
908                 fallback_timeout + unreachable_v4_timeout,
909                 true,
910             ),
911         ];
912 
913         // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
914         // Otherwise, connection to "slow" IPv6 address will error-out immediately.
915         let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
916 
917         for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
918             if needs_ipv6_access && !ipv6_accessible {
919                 continue;
920             }
921 
922             let (start, stream) = rt
923                 .block_on(async move {
924                     let addrs = hosts
925                         .iter()
926                         .map(|host| (host.clone(), addr.port()).into())
927                         .collect();
928                     let connecting_tcp = ConnectingTcp::new(
929                         None,
930                         None,
931                         dns::IpAddrs::new(addrs),
932                         None,
933                         Some(fallback_timeout),
934                         false,
935                     );
936                     let start = Instant::now();
937                     Ok::<_, io::Error>((start, connecting_tcp.connect().await?))
938                 })
939                 .unwrap();
940             let res = if stream.peer_addr().unwrap().is_ipv4() {
941                 4
942             } else {
943                 6
944             };
945             let duration = start.elapsed();
946 
947             // Allow actual duration to be +/- 150ms off.
948             let min_duration = if timeout >= Duration::from_millis(150) {
949                 timeout - Duration::from_millis(150)
950             } else {
951                 Duration::default()
952             };
953             let max_duration = timeout + Duration::from_millis(150);
954 
955             assert_eq!(res, family);
956             assert!(duration >= min_duration);
957             assert!(duration <= max_duration);
958         }
959 
960         fn local_ipv4_addr() -> IpAddr {
961             Ipv4Addr::new(127, 0, 0, 1).into()
962         }
963 
964         fn local_ipv6_addr() -> IpAddr {
965             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
966         }
967 
968         fn unreachable_ipv4_addr() -> IpAddr {
969             Ipv4Addr::new(127, 0, 0, 2).into()
970         }
971 
972         fn unreachable_ipv6_addr() -> IpAddr {
973             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
974         }
975 
976         fn slow_ipv4_addr() -> IpAddr {
977             // RFC 6890 reserved IPv4 address.
978             Ipv4Addr::new(198, 18, 0, 25).into()
979         }
980 
981         fn slow_ipv6_addr() -> IpAddr {
982             // RFC 6890 reserved IPv6 address.
983             Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
984         }
985 
986         fn measure_connect(addr: IpAddr) -> (bool, Duration) {
987             let start = Instant::now();
988             let result =
989                 std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
990 
991             let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
992             let duration = start.elapsed();
993             (reachable, duration)
994         }
995     }
996 }
997