1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11 
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 use tracing::{debug, trace, warn};
18 
19 use super::dns::{self, resolve, GaiResolver, Resolve};
20 use super::{Connected, Connection};
21 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
22 
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
/// Hosts that are already IP literals skip DNS resolution entirely.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Settings shared between clones; mutated copy-on-write through
    // `Arc::make_mut`, so changing one clone never affects the others.
    config: Arc<Config>,
    // DNS resolver used for hosts that are not IP literals.
    resolver: R,
}
37 
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
/// }
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address of the TCP stream, as reported by `TcpStream::peer_addr`.
    remote_addr: SocketAddr,
}
70 
/// Settings applied by an `HttpConnector` to every connection it makes.
///
/// Held behind an `Arc` and cloned on write, so connector clones can be
/// reconfigured independently.
#[derive(Clone)]
struct Config {
    /// Overall TCP connect timeout, split evenly across the addresses of each
    /// attempt list (see `ConnectingTcpRemote::new`).
    connect_timeout: Option<Duration>,
    /// When `true`, reject any `Uri` whose scheme is not exactly `http`.
    enforce_http: bool,
    /// Delay before the fallback address family is tried (RFC 6555);
    /// `None` disables parallel attempts.
    happy_eyeballs_timeout: Option<Duration>,
    /// Passed to `TcpKeepalive::with_time` when set.
    keep_alive_timeout: Option<Duration>,
    /// Local address to bind to when connecting to an IPv4 remote.
    local_address_ipv4: Option<Ipv4Addr>,
    /// Local address to bind to when connecting to an IPv6 remote.
    local_address_ipv6: Option<Ipv6Addr>,
    /// Value passed to `TcpStream::set_nodelay` after connecting.
    nodelay: bool,
    /// When `true`, `SO_REUSEADDR` is enabled on each socket.
    reuse_address: bool,
    /// `SO_SNDBUF` size; left at the OS default when `None`.
    send_buffer_size: Option<usize>,
    /// `SO_RCVBUF` size; left at the OS default when `None`.
    recv_buffer_size: Option<usize>,
}
84 
85 // ===== impl HttpConnector =====
86 
87 impl HttpConnector {
88     /// Construct a new HttpConnector.
new() -> HttpConnector89     pub fn new() -> HttpConnector {
90         HttpConnector::new_with_resolver(GaiResolver::new())
91     }
92 }
93 
94 /*
95 #[cfg(feature = "runtime")]
96 impl HttpConnector<TokioThreadpoolGaiResolver> {
97     /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
98     ///
99     /// This resolver **requires** the threadpool runtime to be used.
100     pub fn new_with_tokio_threadpool_resolver() -> Self {
101         HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
102     }
103 }
104 */
105 
106 impl<R> HttpConnector<R> {
107     /// Construct a new HttpConnector.
108     ///
109     /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>110     pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
111         HttpConnector {
112             config: Arc::new(Config {
113                 connect_timeout: None,
114                 enforce_http: true,
115                 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
116                 keep_alive_timeout: None,
117                 local_address_ipv4: None,
118                 local_address_ipv6: None,
119                 nodelay: false,
120                 reuse_address: false,
121                 send_buffer_size: None,
122                 recv_buffer_size: None,
123             }),
124             resolver,
125         }
126     }
127 
128     /// Option to enforce all `Uri`s have the `http` scheme.
129     ///
130     /// Enabled by default.
131     #[inline]
enforce_http(&mut self, is_enforced: bool)132     pub fn enforce_http(&mut self, is_enforced: bool) {
133         self.config_mut().enforce_http = is_enforced;
134     }
135 
136     /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
137     ///
138     /// If `None`, the option will not be set.
139     ///
140     /// Default is `None`.
141     #[inline]
set_keepalive(&mut self, dur: Option<Duration>)142     pub fn set_keepalive(&mut self, dur: Option<Duration>) {
143         self.config_mut().keep_alive_timeout = dur;
144     }
145 
146     /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
147     ///
148     /// Default is `false`.
149     #[inline]
set_nodelay(&mut self, nodelay: bool)150     pub fn set_nodelay(&mut self, nodelay: bool) {
151         self.config_mut().nodelay = nodelay;
152     }
153 
154     /// Sets the value of the SO_SNDBUF option on the socket.
155     #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)156     pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
157         self.config_mut().send_buffer_size = size;
158     }
159 
160     /// Sets the value of the SO_RCVBUF option on the socket.
161     #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)162     pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
163         self.config_mut().recv_buffer_size = size;
164     }
165 
166     /// Set that all sockets are bound to the configured address before connection.
167     ///
168     /// If `None`, the sockets will not be bound.
169     ///
170     /// Default is `None`.
171     #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)172     pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
173         let (v4, v6) = match addr {
174             Some(IpAddr::V4(a)) => (Some(a), None),
175             Some(IpAddr::V6(a)) => (None, Some(a)),
176             _ => (None, None),
177         };
178 
179         let cfg = self.config_mut();
180 
181         cfg.local_address_ipv4 = v4;
182         cfg.local_address_ipv6 = v6;
183     }
184 
185     /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
186     /// preferences) before connection.
187     #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)188     pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
189         let cfg = self.config_mut();
190 
191         cfg.local_address_ipv4 = Some(addr_ipv4);
192         cfg.local_address_ipv6 = Some(addr_ipv6);
193     }
194 
195     /// Set the connect timeout.
196     ///
197     /// If a domain resolves to multiple IP addresses, the timeout will be
198     /// evenly divided across them.
199     ///
200     /// Default is `None`.
201     #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)202     pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
203         self.config_mut().connect_timeout = dur;
204     }
205 
206     /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
207     ///
208     /// If hostname resolves to both IPv4 and IPv6 addresses and connection
209     /// cannot be established using preferred address family before timeout
210     /// elapses, then connector will in parallel attempt connection using other
211     /// address family.
212     ///
213     /// If `None`, parallel connection attempts are disabled.
214     ///
215     /// Default is 300 milliseconds.
216     ///
217     /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
218     #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)219     pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
220         self.config_mut().happy_eyeballs_timeout = dur;
221     }
222 
223     /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
224     ///
225     /// Default is `false`.
226     #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self227     pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
228         self.config_mut().reuse_address = reuse_address;
229         self
230     }
231 
232     // private
233 
config_mut(&mut self) -> &mut Config234     fn config_mut(&mut self) -> &mut Config {
235         // If the are HttpConnector clones, this will clone the inner
236         // config. So mutating the config won't ever affect previous
237         // clones.
238         Arc::make_mut(&mut self.config)
239     }
240 }
241 
/// Error message used when `enforce_http` is on and the scheme is not `http`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message used when `enforce_http` is off and the `Uri` has no scheme.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message used when the `Uri` has no host component.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
245 
246 // R: Debug required for now to allow adding it to debug output later...
247 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result248     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249         f.debug_struct("HttpConnector").finish()
250     }
251 }
252 
253 impl<R> tower_service::Service<Uri> for HttpConnector<R>
254 where
255     R: Resolve + Clone + Send + Sync + 'static,
256     R::Future: Send,
257 {
258     type Response = TcpStream;
259     type Error = ConnectError;
260     type Future = HttpConnecting<R>;
261 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>262     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
263         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
264         Poll::Ready(Ok(()))
265     }
266 
call(&mut self, dst: Uri) -> Self::Future267     fn call(&mut self, dst: Uri) -> Self::Future {
268         let mut self_ = self.clone();
269         HttpConnecting {
270             fut: Box::pin(async move { self_.call_async(dst).await }),
271             _marker: PhantomData,
272         }
273     }
274 }
275 
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>276 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
277     trace!(
278         "Http::connect; scheme={:?}, host={:?}, port={:?}",
279         dst.scheme(),
280         dst.host(),
281         dst.port(),
282     );
283 
284     if config.enforce_http {
285         if dst.scheme() != Some(&Scheme::HTTP) {
286             return Err(ConnectError {
287                 msg: INVALID_NOT_HTTP.into(),
288                 cause: None,
289             });
290         }
291     } else if dst.scheme().is_none() {
292         return Err(ConnectError {
293             msg: INVALID_MISSING_SCHEME.into(),
294             cause: None,
295         });
296     }
297 
298     let host = match dst.host() {
299         Some(s) => s,
300         None => {
301             return Err(ConnectError {
302                 msg: INVALID_MISSING_HOST.into(),
303                 cause: None,
304             })
305         }
306     };
307     let port = match dst.port() {
308         Some(port) => port.as_u16(),
309         None => {
310             if dst.scheme() == Some(&Scheme::HTTPS) {
311                 443
312             } else {
313                 80
314             }
315         }
316     };
317 
318     Ok((host, port))
319 }
320 
321 impl<R> HttpConnector<R>
322 where
323     R: Resolve,
324 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>325     async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
326         let config = &self.config;
327 
328         let (host, port) = get_host_port(config, &dst)?;
329 
330         // If the host is already an IP addr (v4 or v6),
331         // skip resolving the dns and start connecting right away.
332         let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
333             addrs
334         } else {
335             let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
336                 .await
337                 .map_err(ConnectError::dns)?;
338             let addrs = addrs
339                 .map(|mut addr| {
340                     addr.set_port(port);
341                     addr
342                 })
343                 .collect();
344             dns::SocketAddrs::new(addrs)
345         };
346 
347         let c = ConnectingTcp::new(addrs, config);
348 
349         let sock = c.connect().await?;
350 
351         if let Err(e) = sock.set_nodelay(config.nodelay) {
352             warn!("tcp set_nodelay error: {}", e);
353         }
354 
355         Ok(sock)
356     }
357 }
358 
359 impl Connection for TcpStream {
connected(&self) -> Connected360     fn connected(&self) -> Connected {
361         let connected = Connected::new();
362         if let Ok(remote_addr) = self.peer_addr() {
363             connected.extra(HttpInfo { remote_addr })
364         } else {
365             connected
366         }
367     }
368 }
369 
370 impl HttpInfo {
371     /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr372     pub fn remote_addr(&self) -> SocketAddr {
373         self.remote_addr
374     }
375 }
376 
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    //
    // Resolves to the connected `TcpStream`, or to a `ConnectError`.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed connection future built by `HttpConnector::call`.
        #[pin]
        fut: BoxConnecting,
        // Ties the resolver type `R` to this future's type without storing one.
        _marker: PhantomData<R>,
    }
}
391 
/// Outcome of a single `HttpConnector` connection attempt.
type ConnectResult = Result<TcpStream, ConnectError>;
/// Heap-allocated, type-erased, `Send` future producing a `ConnectResult`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
394 
395 impl<R: Resolve> Future for HttpConnecting<R> {
396     type Output = ConnectResult;
397 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>398     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
399         self.project().fut.poll(cx)
400     }
401 }
402 
// Not publicly exported (so missing_docs doesn't trigger).
/// Error produced when an `HttpConnector` fails to establish a connection.
///
/// Holds a short message naming the step that failed (for example
/// "dns error" or "tcp connect error") and, optionally, the underlying cause.
pub struct ConnectError {
    /// Description of the step that failed.
    msg: Box<str>,
    /// Underlying error, if any; exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
408 
409 impl ConnectError {
new<S, E>(msg: S, cause: E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,410     fn new<S, E>(msg: S, cause: E) -> ConnectError
411     where
412         S: Into<Box<str>>,
413         E: Into<Box<dyn StdError + Send + Sync>>,
414     {
415         ConnectError {
416             msg: msg.into(),
417             cause: Some(cause.into()),
418         }
419     }
420 
dns<E>(cause: E) -> ConnectError where E: Into<Box<dyn StdError + Send + Sync>>,421     fn dns<E>(cause: E) -> ConnectError
422     where
423         E: Into<Box<dyn StdError + Send + Sync>>,
424     {
425         ConnectError::new("dns error", cause)
426     }
427 
m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,428     fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
429     where
430         S: Into<Box<str>>,
431         E: Into<Box<dyn StdError + Send + Sync>>,
432     {
433         move |cause| ConnectError::new(msg, cause)
434     }
435 }
436 
437 impl fmt::Debug for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result438     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439         if let Some(ref cause) = self.cause {
440             f.debug_tuple("ConnectError")
441                 .field(&self.msg)
442                 .field(cause)
443                 .finish()
444         } else {
445             self.msg.fmt(f)
446         }
447     }
448 }
449 
450 impl fmt::Display for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result451     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452         f.write_str(&self.msg)?;
453 
454         if let Some(ref cause) = self.cause {
455             write!(f, ": {}", cause)?;
456         }
457 
458         Ok(())
459     }
460 }
461 
462 impl StdError for ConnectError {
source(&self) -> Option<&(dyn StdError + 'static)>463     fn source(&self) -> Option<&(dyn StdError + 'static)> {
464         self.cause.as_ref().map(|e| &**e as _)
465     }
466 }
467 
/// The planned TCP connection attempts for a single destination.
struct ConnectingTcp<'a> {
    /// Addresses attempted first.
    preferred: ConnectingTcpRemote,
    /// Happy Eyeballs fallback attempt, present only when the split left
    /// fallback addresses.
    fallback: Option<ConnectingTcpFallback>,
    /// Connector settings used when opening each socket.
    config: &'a Config,
}
473 
474 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self475     fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
476         if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
477             let (preferred_addrs, fallback_addrs) = remote_addrs
478                 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
479             if fallback_addrs.is_empty() {
480                 return ConnectingTcp {
481                     preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
482                     fallback: None,
483                     config,
484                 };
485             }
486 
487             ConnectingTcp {
488                 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
489                 fallback: Some(ConnectingTcpFallback {
490                     delay: tokio::time::sleep(fallback_timeout),
491                     remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
492                 }),
493                 config,
494             }
495         } else {
496             ConnectingTcp {
497                 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
498                 fallback: None,
499                 config,
500             }
501         }
502     }
503 }
504 
/// The secondary (Happy Eyeballs) connection attempt.
struct ConnectingTcpFallback {
    /// Delay after which the fallback starts racing the preferred attempt.
    /// The fallback starts sooner if the preferred attempt finishes with an
    /// error first.
    delay: Sleep,
    /// Addresses for the fallback attempt.
    remote: ConnectingTcpRemote,
}
509 
/// An ordered list of addresses that are tried one after another.
struct ConnectingTcpRemote {
    /// Addresses to try, in order.
    addrs: dns::SocketAddrs,
    /// Timeout applied to each individual address: the overall connect
    /// timeout divided by the number of addresses in this list.
    connect_timeout: Option<Duration>,
}
514 
515 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self516     fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
517         let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
518 
519         Self {
520             addrs,
521             connect_timeout,
522         }
523     }
524 }
525 
526 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>527     async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
528         let mut err = None;
529         for addr in &mut self.addrs {
530             debug!("connecting to {}", addr);
531             match connect(&addr, config, self.connect_timeout)?.await {
532                 Ok(tcp) => {
533                     debug!("connected to {}", addr);
534                     return Ok(tcp);
535                 }
536                 Err(e) => {
537                     trace!("connect error for {}: {:?}", addr, e);
538                     err = Some(e);
539                 }
540             }
541         }
542 
543         match err {
544             Some(e) => Err(e),
545             None => Err(ConnectError::new(
546                 "tcp connect error",
547                 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
548             )),
549         }
550     }
551 }
552 
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>553 fn bind_local_address(
554     socket: &socket2::Socket,
555     dst_addr: &SocketAddr,
556     local_addr_ipv4: &Option<Ipv4Addr>,
557     local_addr_ipv6: &Option<Ipv6Addr>,
558 ) -> io::Result<()> {
559     match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
560         (SocketAddr::V4(_), Some(addr), _) => {
561             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
562         }
563         (SocketAddr::V6(_), _, Some(addr)) => {
564             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
565         }
566         _ => {
567             if cfg!(windows) {
568                 // Windows requires a socket be bound before calling connect
569                 let any: SocketAddr = match *dst_addr {
570                     SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
571                     SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
572                 };
573                 socket.bind(&any.into())?;
574             }
575         }
576     }
577 
578     Ok(())
579 }
580 
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>581 fn connect(
582     addr: &SocketAddr,
583     config: &Config,
584     connect_timeout: Option<Duration>,
585 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
586     // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
587     // keepalive timeout, it would be nice to use that instead of socket2,
588     // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
589     use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
590     use std::convert::TryInto;
591 
592     let domain = Domain::for_address(*addr);
593     let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
594         .map_err(ConnectError::m("tcp open error"))?;
595 
596     // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
597     // responsible for ensuring O_NONBLOCK is set.
598     socket
599         .set_nonblocking(true)
600         .map_err(ConnectError::m("tcp set_nonblocking error"))?;
601 
602     if let Some(dur) = config.keep_alive_timeout {
603         let conf = TcpKeepalive::new().with_time(dur);
604         if let Err(e) = socket.set_tcp_keepalive(&conf) {
605             warn!("tcp set_keepalive error: {}", e);
606         }
607     }
608 
609     bind_local_address(
610         &socket,
611         addr,
612         &config.local_address_ipv4,
613         &config.local_address_ipv6,
614     )
615     .map_err(ConnectError::m("tcp bind local error"))?;
616 
617     #[cfg(unix)]
618     let socket = unsafe {
619         // Safety: `from_raw_fd` is only safe to call if ownership of the raw
620         // file descriptor is transferred. Since we call `into_raw_fd` on the
621         // socket2 socket, it gives up ownership of the fd and will not close
622         // it, so this is safe.
623         use std::os::unix::io::{FromRawFd, IntoRawFd};
624         TcpSocket::from_raw_fd(socket.into_raw_fd())
625     };
626     #[cfg(windows)]
627     let socket = unsafe {
628         // Safety: `from_raw_socket` is only safe to call if ownership of the raw
629         // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
630         // socket2 socket, it gives up ownership of the SOCKET and will not close
631         // it, so this is safe.
632         use std::os::windows::io::{FromRawSocket, IntoRawSocket};
633         TcpSocket::from_raw_socket(socket.into_raw_socket())
634     };
635 
636     if config.reuse_address {
637         if let Err(e) = socket.set_reuseaddr(true) {
638             warn!("tcp set_reuse_address error: {}", e);
639         }
640     }
641 
642     if let Some(size) = config.send_buffer_size {
643         if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
644             warn!("tcp set_buffer_size error: {}", e);
645         }
646     }
647 
648     if let Some(size) = config.recv_buffer_size {
649         if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
650             warn!("tcp set_recv_buffer_size error: {}", e);
651         }
652     }
653 
654     let connect = socket.connect(*addr);
655     Ok(async move {
656         match connect_timeout {
657             Some(dur) => match tokio::time::timeout(dur, connect).await {
658                 Ok(Ok(s)) => Ok(s),
659                 Ok(Err(e)) => Err(e),
660                 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
661             },
662             None => connect.await,
663         }
664         .map_err(ConnectError::m("tcp connect error"))
665     })
666 }
667 
impl ConnectingTcp<'_> {
    /// Runs the planned attempts and returns the first stream obtained.
    ///
    /// Without a fallback, this simply tries the preferred address list.
    ///
    /// With a fallback (Happy Eyeballs), the preferred attempt runs alone
    /// until it finishes or the fallback delay elapses. After the delay, the
    /// preferred and fallback attempts race. The first attempt to finish
    /// decides the outcome if it succeeded. If it failed, the other attempt is
    /// awaited and its result returned, so that attempt's error wins when both
    /// fail.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Created now but not polled until it is selected or awaited below.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        // Preferred finished before the delay; keep the (unstarted)
                        // fallback in reserve in case that result was an error.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
706 
707 #[cfg(test)]
708 mod tests {
709     use std::io;
710 
711     use ::http::Uri;
712 
713     use super::super::sealed::{Connect, ConnectSvc};
714     use super::{Config, ConnectError, HttpConnector};
715 
connect<C>( connector: C, dst: Uri, ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> where C: Connect,716     async fn connect<C>(
717         connector: C,
718         dst: Uri,
719     ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
720     where
721         C: Connect,
722     {
723         connector.connect(super::super::sealed::Internal, dst).await
724     }
725 
726     #[tokio::test]
test_errors_enforce_http()727     async fn test_errors_enforce_http() {
728         let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
729         let connector = HttpConnector::new();
730 
731         let err = connect(connector, dst).await.unwrap_err();
732         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
733     }
734 
    /// Finds one IPv4 and one IPv6 address on this host's interfaces that a
    /// `TcpListener` can bind to.
    ///
    /// Interfaces are scanned in order until both families have been found.
    /// Until then, a later bindable address of a family replaces an earlier
    /// one. Either element is `None` if no bindable address of that family
    /// exists.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            // Only keep addresses that can actually be bound (port 0 = any free port).
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }
760 
761     #[tokio::test]
test_errors_missing_scheme()762     async fn test_errors_missing_scheme() {
763         let dst = "example.domain".parse().unwrap();
764         let mut connector = HttpConnector::new();
765         connector.enforce_http(false);
766 
767         let err = connect(connector, dst).await.unwrap_err();
768         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
769     }
770 
    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        // Checks that outgoing connections are bound to the configured local
        // address. For each family with a bindable local IP, connect to a
        // loopback listener and compare the client IP the listener observes.
        use std::net::{IpAddr, TcpListener};
        let _ = pretty_env_logger::try_init();

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        // The IPv6 listener reuses the IPv4 listener's port so one `port` serves both.
        // NOTE(review): this unwrap panics on hosts without an IPv6 loopback.
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        // Connects to `dst` through a connector bound to the discovered local
        // address(es), then asserts `server` saw `expected_ip` as the client IP.
        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                // Neither family found: the closure is never invoked below in that case.
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            // Blocking std accept; the connection above has already been established.
            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }
808 
809     #[test]
810     #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
client_happy_eyeballs()811     fn client_happy_eyeballs() {
812         use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
813         use std::time::{Duration, Instant};
814 
815         use super::dns;
816         use super::ConnectingTcp;
817 
818         let _ = pretty_env_logger::try_init();
819         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
820         let addr = server4.local_addr().unwrap();
821         let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
822         let rt = tokio::runtime::Builder::new_current_thread()
823             .enable_all()
824             .build()
825             .unwrap();
826 
827         let local_timeout = Duration::default();
828         let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
829         let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
830         let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
831             + Duration::from_millis(250);
832 
833         let scenarios = &[
834             // Fast primary, without fallback.
835             (&[local_ipv4_addr()][..], 4, local_timeout, false),
836             (&[local_ipv6_addr()][..], 6, local_timeout, false),
837             // Fast primary, with (unused) fallback.
838             (
839                 &[local_ipv4_addr(), local_ipv6_addr()][..],
840                 4,
841                 local_timeout,
842                 false,
843             ),
844             (
845                 &[local_ipv6_addr(), local_ipv4_addr()][..],
846                 6,
847                 local_timeout,
848                 false,
849             ),
850             // Unreachable + fast primary, without fallback.
851             (
852                 &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
853                 4,
854                 unreachable_v4_timeout,
855                 false,
856             ),
857             (
858                 &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
859                 6,
860                 unreachable_v6_timeout,
861                 false,
862             ),
863             // Unreachable + fast primary, with (unused) fallback.
864             (
865                 &[
866                     unreachable_ipv4_addr(),
867                     local_ipv4_addr(),
868                     local_ipv6_addr(),
869                 ][..],
870                 4,
871                 unreachable_v4_timeout,
872                 false,
873             ),
874             (
875                 &[
876                     unreachable_ipv6_addr(),
877                     local_ipv6_addr(),
878                     local_ipv4_addr(),
879                 ][..],
880                 6,
881                 unreachable_v6_timeout,
882                 true,
883             ),
884             // Slow primary, with (used) fallback.
885             (
886                 &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
887                 6,
888                 fallback_timeout,
889                 false,
890             ),
891             (
892                 &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
893                 4,
894                 fallback_timeout,
895                 true,
896             ),
897             // Slow primary, with (used) unreachable + fast fallback.
898             (
899                 &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
900                 6,
901                 fallback_timeout + unreachable_v6_timeout,
902                 false,
903             ),
904             (
905                 &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
906                 4,
907                 fallback_timeout + unreachable_v4_timeout,
908                 true,
909             ),
910         ];
911 
912         // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
913         // Otherwise, connection to "slow" IPv6 address will error-out immediately.
914         let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
915 
916         for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
917             if needs_ipv6_access && !ipv6_accessible {
918                 continue;
919             }
920 
921             let (start, stream) = rt
922                 .block_on(async move {
923                     let addrs = hosts
924                         .iter()
925                         .map(|host| (host.clone(), addr.port()).into())
926                         .collect();
927                     let cfg = Config {
928                         local_address_ipv4: None,
929                         local_address_ipv6: None,
930                         connect_timeout: None,
931                         keep_alive_timeout: None,
932                         happy_eyeballs_timeout: Some(fallback_timeout),
933                         nodelay: false,
934                         reuse_address: false,
935                         enforce_http: false,
936                         send_buffer_size: None,
937                         recv_buffer_size: None,
938                     };
939                     let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
940                     let start = Instant::now();
941                     Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
942                 })
943                 .unwrap();
944             let res = if stream.peer_addr().unwrap().is_ipv4() {
945                 4
946             } else {
947                 6
948             };
949             let duration = start.elapsed();
950 
951             // Allow actual duration to be +/- 150ms off.
952             let min_duration = if timeout >= Duration::from_millis(150) {
953                 timeout - Duration::from_millis(150)
954             } else {
955                 Duration::default()
956             };
957             let max_duration = timeout + Duration::from_millis(150);
958 
959             assert_eq!(res, family);
960             assert!(duration >= min_duration);
961             assert!(duration <= max_duration);
962         }
963 
964         fn local_ipv4_addr() -> IpAddr {
965             Ipv4Addr::new(127, 0, 0, 1).into()
966         }
967 
968         fn local_ipv6_addr() -> IpAddr {
969             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
970         }
971 
972         fn unreachable_ipv4_addr() -> IpAddr {
973             Ipv4Addr::new(127, 0, 0, 2).into()
974         }
975 
976         fn unreachable_ipv6_addr() -> IpAddr {
977             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
978         }
979 
980         fn slow_ipv4_addr() -> IpAddr {
981             // RFC 6890 reserved IPv4 address.
982             Ipv4Addr::new(198, 18, 0, 25).into()
983         }
984 
985         fn slow_ipv6_addr() -> IpAddr {
986             // RFC 6890 reserved IPv6 address.
987             Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
988         }
989 
990         fn measure_connect(addr: IpAddr) -> (bool, Duration) {
991             let start = Instant::now();
992             let result =
993                 std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
994 
995             let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
996             let duration = start.elapsed();
997             (reachable, duration)
998         }
999     }
1000 }
1001