1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11 
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 
18 use super::dns::{self, resolve, GaiResolver, Resolve};
19 use super::{Connected, Connection};
20 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
21 
22 /// A connector for the `http` scheme.
23 ///
24 /// Performs DNS resolution in a thread pool, and then connects over TCP.
25 ///
26 /// # Note
27 ///
28 /// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
29 /// transport information such as the remote socket address used.
30 #[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
31 #[derive(Clone)]
32 pub struct HttpConnector<R = GaiResolver> {
33     config: Arc<Config>,
34     resolver: R,
35 }
36 
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
/// }
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// This value is only present in the extensions when the connection was made
/// by an [`HttpConnector`](HttpConnector). Other connectors may attach their
/// own "extra" information instead; consult their documentation.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    /// Address of the remote end of the transport.
    remote_addr: SocketAddr,
}
69 
/// Settings shared by every connection an `HttpConnector` makes.
#[derive(Clone)]
struct Config {
    /// Total connect timeout, split across the addresses being tried.
    connect_timeout: Option<Duration>,
    /// Require URIs to use the `http` scheme.
    enforce_http: bool,
    /// Delay before starting the fallback attempt (RFC 6555); `None` disables it.
    happy_eyeballs_timeout: Option<Duration>,
    /// Idle time before TCP keepalive probes are sent; `None` leaves keepalive unset.
    keep_alive_timeout: Option<Duration>,
    /// Local address to bind to for IPv4 destinations.
    local_address_ipv4: Option<Ipv4Addr>,
    /// Local address to bind to for IPv6 destinations.
    local_address_ipv6: Option<Ipv6Addr>,
    /// Value for `TCP_NODELAY` on connected streams.
    nodelay: bool,
    /// Whether to set `SO_REUSEADDR` on new sockets.
    reuse_address: bool,
    /// `SO_SNDBUF` value, if any.
    send_buffer_size: Option<usize>,
    /// `SO_RCVBUF` value, if any.
    recv_buffer_size: Option<usize>,
}
83 
84 // ===== impl HttpConnector =====
85 
86 impl HttpConnector {
87     /// Construct a new HttpConnector.
new() -> HttpConnector88     pub fn new() -> HttpConnector {
89         HttpConnector::new_with_resolver(GaiResolver::new())
90     }
91 }
92 
93 /*
94 #[cfg(feature = "runtime")]
95 impl HttpConnector<TokioThreadpoolGaiResolver> {
96     /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
97     ///
98     /// This resolver **requires** the threadpool runtime to be used.
99     pub fn new_with_tokio_threadpool_resolver() -> Self {
100         HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
101     }
102 }
103 */
104 
105 impl<R> HttpConnector<R> {
106     /// Construct a new HttpConnector.
107     ///
108     /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>109     pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
110         HttpConnector {
111             config: Arc::new(Config {
112                 connect_timeout: None,
113                 enforce_http: true,
114                 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
115                 keep_alive_timeout: None,
116                 local_address_ipv4: None,
117                 local_address_ipv6: None,
118                 nodelay: false,
119                 reuse_address: false,
120                 send_buffer_size: None,
121                 recv_buffer_size: None,
122             }),
123             resolver,
124         }
125     }
126 
127     /// Option to enforce all `Uri`s have the `http` scheme.
128     ///
129     /// Enabled by default.
130     #[inline]
enforce_http(&mut self, is_enforced: bool)131     pub fn enforce_http(&mut self, is_enforced: bool) {
132         self.config_mut().enforce_http = is_enforced;
133     }
134 
135     /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
136     ///
137     /// If `None`, the option will not be set.
138     ///
139     /// Default is `None`.
140     #[inline]
set_keepalive(&mut self, dur: Option<Duration>)141     pub fn set_keepalive(&mut self, dur: Option<Duration>) {
142         self.config_mut().keep_alive_timeout = dur;
143     }
144 
145     /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
146     ///
147     /// Default is `false`.
148     #[inline]
set_nodelay(&mut self, nodelay: bool)149     pub fn set_nodelay(&mut self, nodelay: bool) {
150         self.config_mut().nodelay = nodelay;
151     }
152 
153     /// Sets the value of the SO_SNDBUF option on the socket.
154     #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)155     pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
156         self.config_mut().send_buffer_size = size;
157     }
158 
159     /// Sets the value of the SO_RCVBUF option on the socket.
160     #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)161     pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
162         self.config_mut().recv_buffer_size = size;
163     }
164 
165     /// Set that all sockets are bound to the configured address before connection.
166     ///
167     /// If `None`, the sockets will not be bound.
168     ///
169     /// Default is `None`.
170     #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)171     pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
172         let (v4, v6) = match addr {
173             Some(IpAddr::V4(a)) => (Some(a), None),
174             Some(IpAddr::V6(a)) => (None, Some(a)),
175             _ => (None, None),
176         };
177 
178         let cfg = self.config_mut();
179 
180         cfg.local_address_ipv4 = v4;
181         cfg.local_address_ipv6 = v6;
182     }
183 
184     /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
185     /// preferences) before connection.
186     #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)187     pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
188         let cfg = self.config_mut();
189 
190         cfg.local_address_ipv4 = Some(addr_ipv4);
191         cfg.local_address_ipv6 = Some(addr_ipv6);
192     }
193 
194     /// Set the connect timeout.
195     ///
196     /// If a domain resolves to multiple IP addresses, the timeout will be
197     /// evenly divided across them.
198     ///
199     /// Default is `None`.
200     #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)201     pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
202         self.config_mut().connect_timeout = dur;
203     }
204 
205     /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
206     ///
207     /// If hostname resolves to both IPv4 and IPv6 addresses and connection
208     /// cannot be established using preferred address family before timeout
209     /// elapses, then connector will in parallel attempt connection using other
210     /// address family.
211     ///
212     /// If `None`, parallel connection attempts are disabled.
213     ///
214     /// Default is 300 milliseconds.
215     ///
216     /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
217     #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)218     pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
219         self.config_mut().happy_eyeballs_timeout = dur;
220     }
221 
222     /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
223     ///
224     /// Default is `false`.
225     #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self226     pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
227         self.config_mut().reuse_address = reuse_address;
228         self
229     }
230 
231     // private
232 
config_mut(&mut self) -> &mut Config233     fn config_mut(&mut self) -> &mut Config {
234         // If the are HttpConnector clones, this will clone the inner
235         // config. So mutating the config won't ever affect previous
236         // clones.
237         Arc::make_mut(&mut self.config)
238     }
239 }
240 
/// Rejection message when `enforce_http` is on and the scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Rejection message when `enforce_http` is off and the URI has no scheme.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Rejection message when the URI has no host.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
244 
245 // R: Debug required for now to allow adding it to debug output later...
246 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result247     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248         f.debug_struct("HttpConnector").finish()
249     }
250 }
251 
252 impl<R> tower_service::Service<Uri> for HttpConnector<R>
253 where
254     R: Resolve + Clone + Send + Sync + 'static,
255     R::Future: Send,
256 {
257     type Response = TcpStream;
258     type Error = ConnectError;
259     type Future = HttpConnecting<R>;
260 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>261     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
262         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
263         Poll::Ready(Ok(()))
264     }
265 
call(&mut self, dst: Uri) -> Self::Future266     fn call(&mut self, dst: Uri) -> Self::Future {
267         let mut self_ = self.clone();
268         HttpConnecting {
269             fut: Box::pin(async move { self_.call_async(dst).await }),
270             _marker: PhantomData,
271         }
272     }
273 }
274 
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>275 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
276     trace!(
277         "Http::connect; scheme={:?}, host={:?}, port={:?}",
278         dst.scheme(),
279         dst.host(),
280         dst.port(),
281     );
282 
283     if config.enforce_http {
284         if dst.scheme() != Some(&Scheme::HTTP) {
285             return Err(ConnectError {
286                 msg: INVALID_NOT_HTTP.into(),
287                 cause: None,
288             });
289         }
290     } else if dst.scheme().is_none() {
291         return Err(ConnectError {
292             msg: INVALID_MISSING_SCHEME.into(),
293             cause: None,
294         });
295     }
296 
297     let host = match dst.host() {
298         Some(s) => s,
299         None => {
300             return Err(ConnectError {
301                 msg: INVALID_MISSING_HOST.into(),
302                 cause: None,
303             })
304         }
305     };
306     let port = match dst.port() {
307         Some(port) => port.as_u16(),
308         None => {
309             if dst.scheme() == Some(&Scheme::HTTPS) {
310                 443
311             } else {
312                 80
313             }
314         }
315     };
316 
317     Ok((host, port))
318 }
319 
320 impl<R> HttpConnector<R>
321 where
322     R: Resolve,
323 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>324     async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
325         let config = &self.config;
326 
327         let (host, port) = get_host_port(config, &dst)?;
328 
329         // If the host is already an IP addr (v4 or v6),
330         // skip resolving the dns and start connecting right away.
331         let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
332             addrs
333         } else {
334             let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
335                 .await
336                 .map_err(ConnectError::dns)?;
337             let addrs = addrs
338                 .map(|mut addr| {
339                     addr.set_port(port);
340                     addr
341                 })
342                 .collect();
343             dns::SocketAddrs::new(addrs)
344         };
345 
346         let c = ConnectingTcp::new(addrs, config);
347 
348         let sock = c.connect().await?;
349 
350         if let Err(e) = sock.set_nodelay(config.nodelay) {
351             warn!("tcp set_nodelay error: {}", e);
352         }
353 
354         Ok(sock)
355     }
356 }
357 
358 impl Connection for TcpStream {
connected(&self) -> Connected359     fn connected(&self) -> Connected {
360         let connected = Connected::new();
361         if let Ok(remote_addr) = self.peer_addr() {
362             connected.extra(HttpInfo { remote_addr })
363         } else {
364             connected
365         }
366     }
367 }
368 
369 impl HttpInfo {
370     /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr371     pub fn remote_addr(&self) -> SocketAddr {
372         self.remote_addr
373     }
374 }
375 
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // This is the `Future` type returned by `HttpConnector`'s
    // `Service::call`.
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed connect future; polling `HttpConnecting` polls this.
        #[pin]
        fut: BoxConnecting,
        // Keeps the resolver type `R` in the public type without storing an `R`.
        _marker: PhantomData<R>,
    }
}
390 
391 type ConnectResult = Result<TcpStream, ConnectError>;
392 type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
393 
394 impl<R: Resolve> Future for HttpConnecting<R> {
395     type Output = ConnectResult;
396 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>397     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
398         self.project().fut.poll(cx)
399     }
400 }
401 
// Not publicly exported (so missing_docs doesn't trigger).
/// Error returned when `HttpConnector` fails to produce a connection.
///
/// Holds a short description of the failing step (`msg`) and, optionally,
/// the lower-level error that caused it (`cause`, exposed via `source`).
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}

impl ConnectError {
    /// Builds an error from a description and an underlying cause.
    fn new<S, E>(msg: S, cause: E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let msg = msg.into();
        let cause = Some(cause.into());
        ConnectError { msg, cause }
    }

    /// Wraps a resolver failure as a "dns error".
    fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let description = "dns error";
        ConnectError::new(description, cause)
    }

    /// Returns a closure that wraps an error with the description `msg`.
    ///
    /// Intended for use with `map_err`.
    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        move |cause: E| ConnectError::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => f
                .debug_tuple("ConnectError")
                .field(&self.msg)
                .field(cause)
                .finish(),
            // With no cause, the description alone is printed.
            None => self.msg.fmt(f),
        }
    }
}

impl fmt::Display for ConnectError {
    /// Renders `msg`, followed by `": {cause}"` when a cause is present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => write!(f, "{}: {}", self.msg, cause),
            None => f.write_str(&self.msg),
        }
    }
}

impl StdError for ConnectError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause = self.cause.as_ref()?;
        let err: &(dyn StdError + 'static) = &**cause;
        Some(err)
    }
}
466 
467 struct ConnectingTcp<'a> {
468     preferred: ConnectingTcpRemote,
469     fallback: Option<ConnectingTcpFallback>,
470     config: &'a Config,
471 }
472 
473 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self474     fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
475         if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
476             let (preferred_addrs, fallback_addrs) = remote_addrs
477                 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
478             if fallback_addrs.is_empty() {
479                 return ConnectingTcp {
480                     preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
481                     fallback: None,
482                     config,
483                 };
484             }
485 
486             ConnectingTcp {
487                 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
488                 fallback: Some(ConnectingTcpFallback {
489                     delay: tokio::time::sleep(fallback_timeout),
490                     remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
491                 }),
492                 config,
493             }
494         } else {
495             ConnectingTcp {
496                 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
497                 fallback: None,
498                 config,
499             }
500         }
501     }
502 }
503 
504 struct ConnectingTcpFallback {
505     delay: Sleep,
506     remote: ConnectingTcpRemote,
507 }
508 
509 struct ConnectingTcpRemote {
510     addrs: dns::SocketAddrs,
511     connect_timeout: Option<Duration>,
512 }
513 
514 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self515     fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
516         let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
517 
518         Self {
519             addrs,
520             connect_timeout,
521         }
522     }
523 }
524 
525 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>526     async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
527         let mut err = None;
528         for addr in &mut self.addrs {
529             debug!("connecting to {}", addr);
530             match connect(&addr, config, self.connect_timeout)?.await {
531                 Ok(tcp) => {
532                     debug!("connected to {}", addr);
533                     return Ok(tcp);
534                 }
535                 Err(e) => {
536                     trace!("connect error for {}: {:?}", addr, e);
537                     err = Some(e);
538                 }
539             }
540         }
541 
542         match err {
543             Some(e) => Err(e),
544             None => Err(ConnectError::new(
545                 "tcp connect error",
546                 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
547             )),
548         }
549     }
550 }
551 
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>552 fn bind_local_address(
553     socket: &socket2::Socket,
554     dst_addr: &SocketAddr,
555     local_addr_ipv4: &Option<Ipv4Addr>,
556     local_addr_ipv6: &Option<Ipv6Addr>,
557 ) -> io::Result<()> {
558     match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
559         (SocketAddr::V4(_), Some(addr), _) => {
560             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
561         }
562         (SocketAddr::V6(_), _, Some(addr)) => {
563             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
564         }
565         _ => {
566             if cfg!(windows) {
567                 // Windows requires a socket be bound before calling connect
568                 let any: SocketAddr = match *dst_addr {
569                     SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
570                     SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
571                 };
572                 socket.bind(&any.into())?;
573             }
574         }
575     }
576 
577     Ok(())
578 }
579 
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>580 fn connect(
581     addr: &SocketAddr,
582     config: &Config,
583     connect_timeout: Option<Duration>,
584 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
585     // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
586     // keepalive timeout, it would be nice to use that instead of socket2,
587     // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
588     use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
589     use std::convert::TryInto;
590 
591     let domain = Domain::for_address(*addr);
592     let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
593         .map_err(ConnectError::m("tcp open error"))?;
594 
595     // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
596     // responsible for ensuring O_NONBLOCK is set.
597     socket
598         .set_nonblocking(true)
599         .map_err(ConnectError::m("tcp set_nonblocking error"))?;
600 
601     if let Some(dur) = config.keep_alive_timeout {
602         let conf = TcpKeepalive::new().with_time(dur);
603         if let Err(e) = socket.set_tcp_keepalive(&conf) {
604             warn!("tcp set_keepalive error: {}", e);
605         }
606     }
607 
608     bind_local_address(
609         &socket,
610         addr,
611         &config.local_address_ipv4,
612         &config.local_address_ipv6,
613     )
614     .map_err(ConnectError::m("tcp bind local error"))?;
615 
616     #[cfg(unix)]
617     let socket = unsafe {
618         // Safety: `from_raw_fd` is only safe to call if ownership of the raw
619         // file descriptor is transferred. Since we call `into_raw_fd` on the
620         // socket2 socket, it gives up ownership of the fd and will not close
621         // it, so this is safe.
622         use std::os::unix::io::{FromRawFd, IntoRawFd};
623         TcpSocket::from_raw_fd(socket.into_raw_fd())
624     };
625     #[cfg(windows)]
626     let socket = unsafe {
627         // Safety: `from_raw_socket` is only safe to call if ownership of the raw
628         // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
629         // socket2 socket, it gives up ownership of the SOCKET and will not close
630         // it, so this is safe.
631         use std::os::windows::io::{FromRawSocket, IntoRawSocket};
632         TcpSocket::from_raw_socket(socket.into_raw_socket())
633     };
634 
635     if config.reuse_address {
636         if let Err(e) = socket.set_reuseaddr(true) {
637             warn!("tcp set_reuse_address error: {}", e);
638         }
639     }
640 
641     if let Some(size) = config.send_buffer_size {
642         if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
643             warn!("tcp set_buffer_size error: {}", e);
644         }
645     }
646 
647     if let Some(size) = config.recv_buffer_size {
648         if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
649             warn!("tcp set_recv_buffer_size error: {}", e);
650         }
651     }
652 
653     let connect = socket.connect(*addr);
654     Ok(async move {
655         match connect_timeout {
656             Some(dur) => match tokio::time::timeout(dur, connect).await {
657                 Ok(Ok(s)) => Ok(s),
658                 Ok(Err(e)) => Err(e),
659                 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
660             },
661             None => connect.await,
662         }
663         .map_err(ConnectError::m("tcp connect error"))
664     })
665 }
666 
impl ConnectingTcp<'_> {
    /// Runs the planned connection attempt(s).
    ///
    /// - Without a fallback, only the preferred addresses are tried.
    /// - With a fallback, the preferred attempt starts immediately. The
    ///   fallback attempt starts when the fallback delay elapses, or as soon
    ///   as the preferred attempt fails if that happens first. The first
    ///   attempt to complete decides the outcome: on success its stream is
    ///   returned; on failure the other attempt is awaited and its result is
    ///   returned.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Futures are lazy: creating this does not start the fallback
                // attempt; it only starts once it is polled below.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                // `result` is the first completed attempt; `future` is the
                // remaining attempt to fall back on if `result` is an error.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        // Preferred finished before the delay: the fallback
                        // has not been polled yet and is kept in reserve.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
705 
706 #[cfg(test)]
707 mod tests {
708     use std::io;
709 
710     use ::http::Uri;
711 
712     use super::super::sealed::{Connect, ConnectSvc};
713     use super::{Config, ConnectError, HttpConnector};
714 
connect<C>( connector: C, dst: Uri, ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> where C: Connect,715     async fn connect<C>(
716         connector: C,
717         dst: Uri,
718     ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
719     where
720         C: Connect,
721     {
722         connector.connect(super::super::sealed::Internal, dst).await
723     }
724 
725     #[tokio::test]
test_errors_enforce_http()726     async fn test_errors_enforce_http() {
727         let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
728         let connector = HttpConnector::new();
729 
730         let err = connect(connector, dst).await.unwrap_err();
731         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
732     }
733 
734     #[cfg(any(target_os = "linux", target_os = "macos"))]
get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>)735     fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
736         use std::net::{IpAddr, TcpListener};
737 
738         let mut ip_v4 = None;
739         let mut ip_v6 = None;
740 
741         let ips = pnet_datalink::interfaces()
742             .into_iter()
743             .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
744 
745         for ip in ips {
746             match ip {
747                 IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
748                 IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
749                 _ => (),
750             }
751 
752             if ip_v4.is_some() && ip_v6.is_some() {
753                 break;
754             }
755         }
756 
757         (ip_v4, ip_v6)
758     }
759 
760     #[tokio::test]
test_errors_missing_scheme()761     async fn test_errors_missing_scheme() {
762         let dst = "example.domain".parse().unwrap();
763         let mut connector = HttpConnector::new();
764         connector.enforce_http(false);
765 
766         let err = connect(connector, dst).await.unwrap_err();
767         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
768     }
769 
770     // NOTE: pnet crate that we use in this test doesn't compile on Windows
771     #[cfg(any(target_os = "linux", target_os = "macos"))]
772     #[tokio::test]
local_address()773     async fn local_address() {
774         use std::net::{IpAddr, TcpListener};
775         let _ = pretty_env_logger::try_init();
776 
777         let (bind_ip_v4, bind_ip_v6) = get_local_ips();
778         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
779         let port = server4.local_addr().unwrap().port();
780         let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
781 
782         let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
783             let mut connector = HttpConnector::new();
784 
785             match (bind_ip_v4, bind_ip_v6) {
786                 (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
787                 (Some(v4), None) => connector.set_local_address(Some(v4.into())),
788                 (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
789                 _ => unreachable!(),
790             }
791 
792             connect(connector, dst.parse().unwrap()).await.unwrap();
793 
794             let (_, client_addr) = server.accept().unwrap();
795 
796             assert_eq!(client_addr.ip(), expected_ip);
797         };
798 
799         if let Some(ip) = bind_ip_v4 {
800             assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
801         }
802 
803         if let Some(ip) = bind_ip_v6 {
804             assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
805         }
806     }
807 
808     #[test]
809     #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
client_happy_eyeballs()810     fn client_happy_eyeballs() {
811         use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
812         use std::time::{Duration, Instant};
813 
814         use super::dns;
815         use super::ConnectingTcp;
816 
817         let _ = pretty_env_logger::try_init();
818         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
819         let addr = server4.local_addr().unwrap();
820         let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
821         let rt = tokio::runtime::Builder::new_current_thread()
822             .enable_all()
823             .build()
824             .unwrap();
825 
826         let local_timeout = Duration::default();
827         let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
828         let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
829         let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
830             + Duration::from_millis(250);
831 
832         let scenarios = &[
833             // Fast primary, without fallback.
834             (&[local_ipv4_addr()][..], 4, local_timeout, false),
835             (&[local_ipv6_addr()][..], 6, local_timeout, false),
836             // Fast primary, with (unused) fallback.
837             (
838                 &[local_ipv4_addr(), local_ipv6_addr()][..],
839                 4,
840                 local_timeout,
841                 false,
842             ),
843             (
844                 &[local_ipv6_addr(), local_ipv4_addr()][..],
845                 6,
846                 local_timeout,
847                 false,
848             ),
849             // Unreachable + fast primary, without fallback.
850             (
851                 &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
852                 4,
853                 unreachable_v4_timeout,
854                 false,
855             ),
856             (
857                 &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
858                 6,
859                 unreachable_v6_timeout,
860                 false,
861             ),
862             // Unreachable + fast primary, with (unused) fallback.
863             (
864                 &[
865                     unreachable_ipv4_addr(),
866                     local_ipv4_addr(),
867                     local_ipv6_addr(),
868                 ][..],
869                 4,
870                 unreachable_v4_timeout,
871                 false,
872             ),
873             (
874                 &[
875                     unreachable_ipv6_addr(),
876                     local_ipv6_addr(),
877                     local_ipv4_addr(),
878                 ][..],
879                 6,
880                 unreachable_v6_timeout,
881                 true,
882             ),
883             // Slow primary, with (used) fallback.
884             (
885                 &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
886                 6,
887                 fallback_timeout,
888                 false,
889             ),
890             (
891                 &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
892                 4,
893                 fallback_timeout,
894                 true,
895             ),
896             // Slow primary, with (used) unreachable + fast fallback.
897             (
898                 &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
899                 6,
900                 fallback_timeout + unreachable_v6_timeout,
901                 false,
902             ),
903             (
904                 &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
905                 4,
906                 fallback_timeout + unreachable_v4_timeout,
907                 true,
908             ),
909         ];
910 
911         // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
912         // Otherwise, connection to "slow" IPv6 address will error-out immediately.
913         let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
914 
915         for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
916             if needs_ipv6_access && !ipv6_accessible {
917                 continue;
918             }
919 
920             let (start, stream) = rt
921                 .block_on(async move {
922                     let addrs = hosts
923                         .iter()
924                         .map(|host| (host.clone(), addr.port()).into())
925                         .collect();
926                     let cfg = Config {
927                         local_address_ipv4: None,
928                         local_address_ipv6: None,
929                         connect_timeout: None,
930                         keep_alive_timeout: None,
931                         happy_eyeballs_timeout: Some(fallback_timeout),
932                         nodelay: false,
933                         reuse_address: false,
934                         enforce_http: false,
935                         send_buffer_size: None,
936                         recv_buffer_size: None,
937                     };
938                     let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
939                     let start = Instant::now();
940                     Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
941                 })
942                 .unwrap();
943             let res = if stream.peer_addr().unwrap().is_ipv4() {
944                 4
945             } else {
946                 6
947             };
948             let duration = start.elapsed();
949 
950             // Allow actual duration to be +/- 150ms off.
951             let min_duration = if timeout >= Duration::from_millis(150) {
952                 timeout - Duration::from_millis(150)
953             } else {
954                 Duration::default()
955             };
956             let max_duration = timeout + Duration::from_millis(150);
957 
958             assert_eq!(res, family);
959             assert!(duration >= min_duration);
960             assert!(duration <= max_duration);
961         }
962 
963         fn local_ipv4_addr() -> IpAddr {
964             Ipv4Addr::new(127, 0, 0, 1).into()
965         }
966 
967         fn local_ipv6_addr() -> IpAddr {
968             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
969         }
970 
971         fn unreachable_ipv4_addr() -> IpAddr {
972             Ipv4Addr::new(127, 0, 0, 2).into()
973         }
974 
975         fn unreachable_ipv6_addr() -> IpAddr {
976             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
977         }
978 
979         fn slow_ipv4_addr() -> IpAddr {
980             // RFC 6890 reserved IPv4 address.
981             Ipv4Addr::new(198, 18, 0, 25).into()
982         }
983 
984         fn slow_ipv6_addr() -> IpAddr {
985             // RFC 6890 reserved IPv6 address.
986             Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
987         }
988 
989         fn measure_connect(addr: IpAddr) -> (bool, Duration) {
990             let start = Instant::now();
991             let result =
992                 std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
993 
994             let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
995             let duration = start.elapsed();
996             (reachable, duration)
997         }
998     }
999 }
1000