1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11 
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 use tracing::{debug, trace, warn};
18 
19 use super::dns::{self, resolve, GaiResolver, Resolve};
20 use super::{Connected, Connection};
21 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
22 
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Connector settings. Clones share this `Arc`; the setters go through
    // `Arc::make_mut`, so mutating one connector never affects its clones.
    config: Arc<Config>,
    // Used to look up hostnames that are not already IP literals.
    resolver: R,
}
37 
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
///
/// The value is also absent if the peer address of the TCP stream could not
/// be queried once the connection was established.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address reported by the connected `TcpStream`.
    remote_addr: SocketAddr,
}
70 
// Settings applied by `HttpConnector` to every connection attempt. Each field
// is changed through the matching `HttpConnector` setter.
#[derive(Clone)]
struct Config {
    // Overall connect timeout; divided evenly across the addresses of each
    // attempt group. `None` means no timeout.
    connect_timeout: Option<Duration>,
    // When true, URIs whose scheme is not `http` are rejected; when false,
    // only URIs with no scheme at all are rejected.
    enforce_http: bool,
    // Delay before the fallback address family is tried (Happy Eyeballs).
    // `None` disables the parallel fallback attempt.
    happy_eyeballs_timeout: Option<Duration>,
    // Keepalive idle time passed to `TcpKeepalive::with_time`; `None` leaves
    // keepalive unconfigured.
    keep_alive_timeout: Option<Duration>,
    // Local addresses to bind before connecting, picked by the destination's
    // address family.
    local_address_ipv4: Option<Ipv4Addr>,
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `set_nodelay` on the connected stream.
    nodelay: bool,
    // Whether `SO_REUSEADDR` is enabled on each socket.
    reuse_address: bool,
    // `SO_SNDBUF` / `SO_RCVBUF` sizes; `None` leaves the option unset.
    send_buffer_size: Option<usize>,
    recv_buffer_size: Option<usize>,
}
84 
85 // ===== impl HttpConnector =====
86 
87 impl HttpConnector {
88     /// Construct a new HttpConnector.
new() -> HttpConnector89     pub fn new() -> HttpConnector {
90         HttpConnector::new_with_resolver(GaiResolver::new())
91     }
92 }
93 
94 /*
95 #[cfg(feature = "runtime")]
96 impl HttpConnector<TokioThreadpoolGaiResolver> {
97     /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
98     ///
99     /// This resolver **requires** the threadpool runtime to be used.
100     pub fn new_with_tokio_threadpool_resolver() -> Self {
101         HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
102     }
103 }
104 */
105 
106 impl<R> HttpConnector<R> {
107     /// Construct a new HttpConnector.
108     ///
109     /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>110     pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
111         HttpConnector {
112             config: Arc::new(Config {
113                 connect_timeout: None,
114                 enforce_http: true,
115                 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
116                 keep_alive_timeout: None,
117                 local_address_ipv4: None,
118                 local_address_ipv6: None,
119                 nodelay: false,
120                 reuse_address: false,
121                 send_buffer_size: None,
122                 recv_buffer_size: None,
123             }),
124             resolver,
125         }
126     }
127 
128     /// Option to enforce all `Uri`s have the `http` scheme.
129     ///
130     /// Enabled by default.
131     #[inline]
enforce_http(&mut self, is_enforced: bool)132     pub fn enforce_http(&mut self, is_enforced: bool) {
133         self.config_mut().enforce_http = is_enforced;
134     }
135 
136     /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
137     ///
138     /// If `None`, the option will not be set.
139     ///
140     /// Default is `None`.
141     #[inline]
set_keepalive(&mut self, dur: Option<Duration>)142     pub fn set_keepalive(&mut self, dur: Option<Duration>) {
143         self.config_mut().keep_alive_timeout = dur;
144     }
145 
146     /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
147     ///
148     /// Default is `false`.
149     #[inline]
set_nodelay(&mut self, nodelay: bool)150     pub fn set_nodelay(&mut self, nodelay: bool) {
151         self.config_mut().nodelay = nodelay;
152     }
153 
154     /// Sets the value of the SO_SNDBUF option on the socket.
155     #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)156     pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
157         self.config_mut().send_buffer_size = size;
158     }
159 
160     /// Sets the value of the SO_RCVBUF option on the socket.
161     #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)162     pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
163         self.config_mut().recv_buffer_size = size;
164     }
165 
166     /// Set that all sockets are bound to the configured address before connection.
167     ///
168     /// If `None`, the sockets will not be bound.
169     ///
170     /// Default is `None`.
171     #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)172     pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
173         let (v4, v6) = match addr {
174             Some(IpAddr::V4(a)) => (Some(a), None),
175             Some(IpAddr::V6(a)) => (None, Some(a)),
176             _ => (None, None),
177         };
178 
179         let cfg = self.config_mut();
180 
181         cfg.local_address_ipv4 = v4;
182         cfg.local_address_ipv6 = v6;
183     }
184 
185     /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
186     /// preferences) before connection.
187     #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)188     pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
189         let cfg = self.config_mut();
190 
191         cfg.local_address_ipv4 = Some(addr_ipv4);
192         cfg.local_address_ipv6 = Some(addr_ipv6);
193     }
194 
195     /// Set the connect timeout.
196     ///
197     /// If a domain resolves to multiple IP addresses, the timeout will be
198     /// evenly divided across them.
199     ///
200     /// Default is `None`.
201     #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)202     pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
203         self.config_mut().connect_timeout = dur;
204     }
205 
206     /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
207     ///
208     /// If hostname resolves to both IPv4 and IPv6 addresses and connection
209     /// cannot be established using preferred address family before timeout
210     /// elapses, then connector will in parallel attempt connection using other
211     /// address family.
212     ///
213     /// If `None`, parallel connection attempts are disabled.
214     ///
215     /// Default is 300 milliseconds.
216     ///
217     /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
218     #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)219     pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
220         self.config_mut().happy_eyeballs_timeout = dur;
221     }
222 
223     /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
224     ///
225     /// Default is `false`.
226     #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self227     pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
228         self.config_mut().reuse_address = reuse_address;
229         self
230     }
231 
232     // private
233 
config_mut(&mut self) -> &mut Config234     fn config_mut(&mut self) -> &mut Config {
235         // If the are HttpConnector clones, this will clone the inner
236         // config. So mutating the config won't ever affect previous
237         // clones.
238         Arc::make_mut(&mut self.config)
239     }
240 }
241 
/// Rejection message used when `enforce_http` is on and the scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Rejection message used when `enforce_http` is off and the URI has no scheme.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Rejection message used when the URI has no host component.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
245 
246 // R: Debug required for now to allow adding it to debug output later...
247 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result248     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249         f.debug_struct("HttpConnector").finish()
250     }
251 }
252 
253 impl<R> tower_service::Service<Uri> for HttpConnector<R>
254 where
255     R: Resolve + Clone + Send + Sync + 'static,
256     R::Future: Send,
257 {
258     type Response = TcpStream;
259     type Error = ConnectError;
260     type Future = HttpConnecting<R>;
261 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>262     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
263         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
264         Poll::Ready(Ok(()))
265     }
266 
call(&mut self, dst: Uri) -> Self::Future267     fn call(&mut self, dst: Uri) -> Self::Future {
268         let mut self_ = self.clone();
269         HttpConnecting {
270             fut: Box::pin(async move { self_.call_async(dst).await }),
271             _marker: PhantomData,
272         }
273     }
274 }
275 
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>276 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
277     trace!(
278         "Http::connect; scheme={:?}, host={:?}, port={:?}",
279         dst.scheme(),
280         dst.host(),
281         dst.port(),
282     );
283 
284     if config.enforce_http {
285         if dst.scheme() != Some(&Scheme::HTTP) {
286             return Err(ConnectError {
287                 msg: INVALID_NOT_HTTP.into(),
288                 cause: None,
289             });
290         }
291     } else if dst.scheme().is_none() {
292         return Err(ConnectError {
293             msg: INVALID_MISSING_SCHEME.into(),
294             cause: None,
295         });
296     }
297 
298     let host = match dst.host() {
299         Some(s) => s,
300         None => {
301             return Err(ConnectError {
302                 msg: INVALID_MISSING_HOST.into(),
303                 cause: None,
304             })
305         }
306     };
307     let port = match dst.port() {
308         Some(port) => port.as_u16(),
309         None => {
310             if dst.scheme() == Some(&Scheme::HTTPS) {
311                 443
312             } else {
313                 80
314             }
315         }
316     };
317 
318     Ok((host, port))
319 }
320 
321 impl<R> HttpConnector<R>
322 where
323     R: Resolve,
324 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>325     async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
326         let config = &self.config;
327 
328         let (host, port) = get_host_port(config, &dst)?;
329         let host = host.trim_start_matches('[').trim_end_matches(']');
330 
331         // If the host is already an IP addr (v4 or v6),
332         // skip resolving the dns and start connecting right away.
333         let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
334             addrs
335         } else {
336             let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
337                 .await
338                 .map_err(ConnectError::dns)?;
339             let addrs = addrs
340                 .map(|mut addr| {
341                     addr.set_port(port);
342                     addr
343                 })
344                 .collect();
345             dns::SocketAddrs::new(addrs)
346         };
347 
348         let c = ConnectingTcp::new(addrs, config);
349 
350         let sock = c.connect().await?;
351 
352         if let Err(e) = sock.set_nodelay(config.nodelay) {
353             warn!("tcp set_nodelay error: {}", e);
354         }
355 
356         Ok(sock)
357     }
358 }
359 
360 impl Connection for TcpStream {
connected(&self) -> Connected361     fn connected(&self) -> Connected {
362         let connected = Connected::new();
363         if let Ok(remote_addr) = self.peer_addr() {
364             connected.extra(HttpInfo { remote_addr })
365         } else {
366             connected
367         }
368     }
369 }
370 
371 impl HttpInfo {
372     /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr373     pub fn remote_addr(&self) -> SocketAddr {
374         self.remote_addr
375     }
376 }
377 
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    //
    // This is the future returned by `HttpConnector::call`; polling it polls
    // the boxed connection future.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed connection future that does the real work.
        #[pin]
        fut: BoxConnecting,
        // Ties this type to the resolver type `R` without storing a value of it.
        _marker: PhantomData<R>,
    }
}
392 
// Outcome of a single connection attempt made by `HttpConnector`.
type ConnectResult = Result<TcpStream, ConnectError>;
// Type-erased, heap-pinned connection future; `Send` so it may move between threads.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
395 
396 impl<R: Resolve> Future for HttpConnecting<R> {
397     type Output = ConnectResult;
398 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>399     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
400         self.project().fut.poll(cx)
401     }
402 }
403 
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error produced by `HttpConnector`: a short description (`msg`) and, when
// available, the underlying error that caused it (`cause`).
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
409 
410 impl ConnectError {
new<S, E>(msg: S, cause: E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,411     fn new<S, E>(msg: S, cause: E) -> ConnectError
412     where
413         S: Into<Box<str>>,
414         E: Into<Box<dyn StdError + Send + Sync>>,
415     {
416         ConnectError {
417             msg: msg.into(),
418             cause: Some(cause.into()),
419         }
420     }
421 
dns<E>(cause: E) -> ConnectError where E: Into<Box<dyn StdError + Send + Sync>>,422     fn dns<E>(cause: E) -> ConnectError
423     where
424         E: Into<Box<dyn StdError + Send + Sync>>,
425     {
426         ConnectError::new("dns error", cause)
427     }
428 
m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,429     fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
430     where
431         S: Into<Box<str>>,
432         E: Into<Box<dyn StdError + Send + Sync>>,
433     {
434         move |cause| ConnectError::new(msg, cause)
435     }
436 }
437 
438 impl fmt::Debug for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result439     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440         if let Some(ref cause) = self.cause {
441             f.debug_tuple("ConnectError")
442                 .field(&self.msg)
443                 .field(cause)
444                 .finish()
445         } else {
446             self.msg.fmt(f)
447         }
448     }
449 }
450 
451 impl fmt::Display for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result452     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453         f.write_str(&self.msg)?;
454 
455         if let Some(ref cause) = self.cause {
456             write!(f, ": {}", cause)?;
457         }
458 
459         Ok(())
460     }
461 }
462 
463 impl StdError for ConnectError {
source(&self) -> Option<&(dyn StdError + 'static)>464     fn source(&self) -> Option<&(dyn StdError + 'static)> {
465         self.cause.as_ref().map(|e| &**e as _)
466     }
467 }
468 
// A planned connection attempt over a set of resolved addresses, optionally
// with a delayed fallback group (Happy Eyeballs, RFC 6555).
struct ConnectingTcp<'a> {
    // Addresses that are tried first.
    preferred: ConnectingTcpRemote,
    // Delayed alternative group; `None` when Happy Eyeballs is disabled or
    // there are no fallback addresses.
    fallback: Option<ConnectingTcpFallback>,
    // Connector settings applied to every socket.
    config: &'a Config,
}
474 
475 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self476     fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
477         if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
478             let (preferred_addrs, fallback_addrs) = remote_addrs
479                 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
480             if fallback_addrs.is_empty() {
481                 return ConnectingTcp {
482                     preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
483                     fallback: None,
484                     config,
485                 };
486             }
487 
488             ConnectingTcp {
489                 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
490                 fallback: Some(ConnectingTcpFallback {
491                     delay: tokio::time::sleep(fallback_timeout),
492                     remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
493                 }),
494                 config,
495             }
496         } else {
497             ConnectingTcp {
498                 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
499                 fallback: None,
500                 config,
501             }
502         }
503     }
504 }
505 
// The fallback half of a Happy Eyeballs connection attempt.
struct ConnectingTcpFallback {
    // Timer that must elapse before the fallback starts, unless the preferred
    // attempt fails first.
    delay: Sleep,
    // The fallback addresses to try.
    remote: ConnectingTcpRemote,
}
510 
511 struct ConnectingTcpRemote {
512     addrs: dns::SocketAddrs,
513     connect_timeout: Option<Duration>,
514 }
515 
516 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self517     fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
518         let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
519 
520         Self {
521             addrs,
522             connect_timeout,
523         }
524     }
525 }
526 
527 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>528     async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
529         let mut err = None;
530         for addr in &mut self.addrs {
531             debug!("connecting to {}", addr);
532             match connect(&addr, config, self.connect_timeout)?.await {
533                 Ok(tcp) => {
534                     debug!("connected to {}", addr);
535                     return Ok(tcp);
536                 }
537                 Err(e) => {
538                     trace!("connect error for {}: {:?}", addr, e);
539                     err = Some(e);
540                 }
541             }
542         }
543 
544         match err {
545             Some(e) => Err(e),
546             None => Err(ConnectError::new(
547                 "tcp connect error",
548                 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
549             )),
550         }
551     }
552 }
553 
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>554 fn bind_local_address(
555     socket: &socket2::Socket,
556     dst_addr: &SocketAddr,
557     local_addr_ipv4: &Option<Ipv4Addr>,
558     local_addr_ipv6: &Option<Ipv6Addr>,
559 ) -> io::Result<()> {
560     match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
561         (SocketAddr::V4(_), Some(addr), _) => {
562             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
563         }
564         (SocketAddr::V6(_), _, Some(addr)) => {
565             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
566         }
567         _ => {
568             if cfg!(windows) {
569                 // Windows requires a socket be bound before calling connect
570                 let any: SocketAddr = match *dst_addr {
571                     SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
572                     SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
573                 };
574                 socket.bind(&any.into())?;
575             }
576         }
577     }
578 
579     Ok(())
580 }
581 
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>582 fn connect(
583     addr: &SocketAddr,
584     config: &Config,
585     connect_timeout: Option<Duration>,
586 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
587     // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
588     // keepalive timeout, it would be nice to use that instead of socket2,
589     // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
590     use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
591     use std::convert::TryInto;
592 
593     let domain = Domain::for_address(*addr);
594     let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
595         .map_err(ConnectError::m("tcp open error"))?;
596 
597     // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
598     // responsible for ensuring O_NONBLOCK is set.
599     socket
600         .set_nonblocking(true)
601         .map_err(ConnectError::m("tcp set_nonblocking error"))?;
602 
603     if let Some(dur) = config.keep_alive_timeout {
604         let conf = TcpKeepalive::new().with_time(dur);
605         if let Err(e) = socket.set_tcp_keepalive(&conf) {
606             warn!("tcp set_keepalive error: {}", e);
607         }
608     }
609 
610     bind_local_address(
611         &socket,
612         addr,
613         &config.local_address_ipv4,
614         &config.local_address_ipv6,
615     )
616     .map_err(ConnectError::m("tcp bind local error"))?;
617 
618     #[cfg(unix)]
619     let socket = unsafe {
620         // Safety: `from_raw_fd` is only safe to call if ownership of the raw
621         // file descriptor is transferred. Since we call `into_raw_fd` on the
622         // socket2 socket, it gives up ownership of the fd and will not close
623         // it, so this is safe.
624         use std::os::unix::io::{FromRawFd, IntoRawFd};
625         TcpSocket::from_raw_fd(socket.into_raw_fd())
626     };
627     #[cfg(windows)]
628     let socket = unsafe {
629         // Safety: `from_raw_socket` is only safe to call if ownership of the raw
630         // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
631         // socket2 socket, it gives up ownership of the SOCKET and will not close
632         // it, so this is safe.
633         use std::os::windows::io::{FromRawSocket, IntoRawSocket};
634         TcpSocket::from_raw_socket(socket.into_raw_socket())
635     };
636 
637     if config.reuse_address {
638         if let Err(e) = socket.set_reuseaddr(true) {
639             warn!("tcp set_reuse_address error: {}", e);
640         }
641     }
642 
643     if let Some(size) = config.send_buffer_size {
644         if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
645             warn!("tcp set_buffer_size error: {}", e);
646         }
647     }
648 
649     if let Some(size) = config.recv_buffer_size {
650         if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
651             warn!("tcp set_recv_buffer_size error: {}", e);
652         }
653     }
654 
655     let connect = socket.connect(*addr);
656     Ok(async move {
657         match connect_timeout {
658             Some(dur) => match tokio::time::timeout(dur, connect).await {
659                 Ok(Ok(s)) => Ok(s),
660                 Ok(Err(e)) => Err(e),
661                 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
662             },
663             None => connect.await,
664         }
665         .map_err(ConnectError::m("tcp connect error"))
666     })
667 }
668 
impl ConnectingTcp<'_> {
    /// Runs the planned connection attempt.
    ///
    /// Without a fallback group, the preferred addresses are tried in order.
    /// With a fallback group (Happy Eyeballs), the preferred attempt runs on its
    /// own until it finishes or the fallback delay elapses:
    /// - if the preferred attempt finishes first, its result is used; on error
    ///   the fallback attempt is awaited right away, without waiting out the
    ///   rest of the delay;
    /// - if the delay elapses first, both attempts are polled concurrently and
    ///   the first to finish wins, unless it failed, in which case the other
    ///   one is awaited.
    ///
    /// When both attempts fail, the error of the one awaited last is returned.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Creating this future does not start connecting; it only makes
                // progress once it is polled by one of the `select`s / `await`s below.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
707 
708 #[cfg(test)]
709 mod tests {
710     use std::io;
711 
712     use ::http::Uri;
713 
714     use super::super::sealed::{Connect, ConnectSvc};
715     use super::{Config, ConnectError, HttpConnector};
716 
connect<C>( connector: C, dst: Uri, ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> where C: Connect,717     async fn connect<C>(
718         connector: C,
719         dst: Uri,
720     ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
721     where
722         C: Connect,
723     {
724         connector.connect(super::super::sealed::Internal, dst).await
725     }
726 
727     #[tokio::test]
test_errors_enforce_http()728     async fn test_errors_enforce_http() {
729         let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
730         let connector = HttpConnector::new();
731 
732         let err = connect(connector, dst).await.unwrap_err();
733         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
734     }
735 
736     #[cfg(any(target_os = "linux", target_os = "macos"))]
get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>)737     fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
738         use std::net::{IpAddr, TcpListener};
739 
740         let mut ip_v4 = None;
741         let mut ip_v6 = None;
742 
743         let ips = pnet_datalink::interfaces()
744             .into_iter()
745             .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
746 
747         for ip in ips {
748             match ip {
749                 IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
750                 IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
751                 _ => (),
752             }
753 
754             if ip_v4.is_some() && ip_v6.is_some() {
755                 break;
756             }
757         }
758 
759         (ip_v4, ip_v6)
760     }
761 
762     #[tokio::test]
test_errors_missing_scheme()763     async fn test_errors_missing_scheme() {
764         let dst = "example.domain".parse().unwrap();
765         let mut connector = HttpConnector::new();
766         connector.enforce_http(false);
767 
768         let err = connect(connector, dst).await.unwrap_err();
769         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
770     }
771 
772     // NOTE: pnet crate that we use in this test doesn't compile on Windows
773     #[cfg(any(target_os = "linux", target_os = "macos"))]
774     #[tokio::test]
local_address()775     async fn local_address() {
776         use std::net::{IpAddr, TcpListener};
777         let _ = pretty_env_logger::try_init();
778 
779         let (bind_ip_v4, bind_ip_v6) = get_local_ips();
780         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
781         let port = server4.local_addr().unwrap().port();
782         let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
783 
784         let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
785             let mut connector = HttpConnector::new();
786 
787             match (bind_ip_v4, bind_ip_v6) {
788                 (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
789                 (Some(v4), None) => connector.set_local_address(Some(v4.into())),
790                 (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
791                 _ => unreachable!(),
792             }
793 
794             connect(connector, dst.parse().unwrap()).await.unwrap();
795 
796             let (_, client_addr) = server.accept().unwrap();
797 
798             assert_eq!(client_addr.ip(), expected_ip);
799         };
800 
801         if let Some(ip) = bind_ip_v4 {
802             assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
803         }
804 
805         if let Some(ip) = bind_ip_v6 {
806             assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
807         }
808     }
809 
810     #[test]
811     #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
client_happy_eyeballs()812     fn client_happy_eyeballs() {
813         use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
814         use std::time::{Duration, Instant};
815 
816         use super::dns;
817         use super::ConnectingTcp;
818 
819         let _ = pretty_env_logger::try_init();
820         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
821         let addr = server4.local_addr().unwrap();
822         let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
823         let rt = tokio::runtime::Builder::new_current_thread()
824             .enable_all()
825             .build()
826             .unwrap();
827 
828         let local_timeout = Duration::default();
829         let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
830         let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
831         let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
832             + Duration::from_millis(250);
833 
834         let scenarios = &[
835             // Fast primary, without fallback.
836             (&[local_ipv4_addr()][..], 4, local_timeout, false),
837             (&[local_ipv6_addr()][..], 6, local_timeout, false),
838             // Fast primary, with (unused) fallback.
839             (
840                 &[local_ipv4_addr(), local_ipv6_addr()][..],
841                 4,
842                 local_timeout,
843                 false,
844             ),
845             (
846                 &[local_ipv6_addr(), local_ipv4_addr()][..],
847                 6,
848                 local_timeout,
849                 false,
850             ),
851             // Unreachable + fast primary, without fallback.
852             (
853                 &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
854                 4,
855                 unreachable_v4_timeout,
856                 false,
857             ),
858             (
859                 &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
860                 6,
861                 unreachable_v6_timeout,
862                 false,
863             ),
864             // Unreachable + fast primary, with (unused) fallback.
865             (
866                 &[
867                     unreachable_ipv4_addr(),
868                     local_ipv4_addr(),
869                     local_ipv6_addr(),
870                 ][..],
871                 4,
872                 unreachable_v4_timeout,
873                 false,
874             ),
875             (
876                 &[
877                     unreachable_ipv6_addr(),
878                     local_ipv6_addr(),
879                     local_ipv4_addr(),
880                 ][..],
881                 6,
882                 unreachable_v6_timeout,
883                 true,
884             ),
885             // Slow primary, with (used) fallback.
886             (
887                 &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
888                 6,
889                 fallback_timeout,
890                 false,
891             ),
892             (
893                 &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
894                 4,
895                 fallback_timeout,
896                 true,
897             ),
898             // Slow primary, with (used) unreachable + fast fallback.
899             (
900                 &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
901                 6,
902                 fallback_timeout + unreachable_v6_timeout,
903                 false,
904             ),
905             (
906                 &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
907                 4,
908                 fallback_timeout + unreachable_v4_timeout,
909                 true,
910             ),
911         ];
912 
913         // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
914         // Otherwise, connection to "slow" IPv6 address will error-out immediately.
915         let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
916 
917         for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
918             if needs_ipv6_access && !ipv6_accessible {
919                 continue;
920             }
921 
922             let (start, stream) = rt
923                 .block_on(async move {
924                     let addrs = hosts
925                         .iter()
926                         .map(|host| (host.clone(), addr.port()).into())
927                         .collect();
928                     let cfg = Config {
929                         local_address_ipv4: None,
930                         local_address_ipv6: None,
931                         connect_timeout: None,
932                         keep_alive_timeout: None,
933                         happy_eyeballs_timeout: Some(fallback_timeout),
934                         nodelay: false,
935                         reuse_address: false,
936                         enforce_http: false,
937                         send_buffer_size: None,
938                         recv_buffer_size: None,
939                     };
940                     let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
941                     let start = Instant::now();
942                     Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
943                 })
944                 .unwrap();
945             let res = if stream.peer_addr().unwrap().is_ipv4() {
946                 4
947             } else {
948                 6
949             };
950             let duration = start.elapsed();
951 
952             // Allow actual duration to be +/- 150ms off.
953             let min_duration = if timeout >= Duration::from_millis(150) {
954                 timeout - Duration::from_millis(150)
955             } else {
956                 Duration::default()
957             };
958             let max_duration = timeout + Duration::from_millis(150);
959 
960             assert_eq!(res, family);
961             assert!(duration >= min_duration);
962             assert!(duration <= max_duration);
963         }
964 
965         fn local_ipv4_addr() -> IpAddr {
966             Ipv4Addr::new(127, 0, 0, 1).into()
967         }
968 
969         fn local_ipv6_addr() -> IpAddr {
970             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
971         }
972 
973         fn unreachable_ipv4_addr() -> IpAddr {
974             Ipv4Addr::new(127, 0, 0, 2).into()
975         }
976 
977         fn unreachable_ipv6_addr() -> IpAddr {
978             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
979         }
980 
981         fn slow_ipv4_addr() -> IpAddr {
982             // RFC 6890 reserved IPv4 address.
983             Ipv4Addr::new(198, 18, 0, 25).into()
984         }
985 
986         fn slow_ipv6_addr() -> IpAddr {
987             // RFC 6890 reserved IPv6 address.
988             Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
989         }
990 
991         fn measure_connect(addr: IpAddr) -> (bool, Duration) {
992             let start = Instant::now();
993             let result =
994                 std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
995 
996             let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
997             let duration = start.elapsed();
998             (reachable, duration)
999         }
1000     }
1001 }
1002