1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17
18 use super::dns::{self, resolve, GaiResolver, Resolve};
19 use super::{Connected, Connection};
20 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
21
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Shared settings. Mutation goes through `config_mut`, which uses
    // `Arc::make_mut`, so changing one clone's settings never affects others.
    config: Arc<Config>,
    // Resolver used for hosts that are not already IP literals.
    resolver: R,
}
36
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address reported by `TcpStream::peer_addr` when the connection's
    // `Connected` metadata was built.
    remote_addr: SocketAddr,
}
69
/// Settings shared by all connections made through one `HttpConnector`.
#[derive(Clone)]
struct Config {
    /// Total connect timeout, split evenly across the addresses of each
    /// attempt group. `None` means no timeout.
    connect_timeout: Option<Duration>,
    /// When true, only URIs with the `http` scheme are accepted.
    enforce_http: bool,
    /// Delay before the fallback address family is tried in parallel;
    /// `None` disables the fallback entirely.
    happy_eyeballs_timeout: Option<Duration>,
    /// When set, keepalive is enabled with this value passed to
    /// `TcpKeepalive::with_time`.
    keep_alive_timeout: Option<Duration>,
    /// Local address bound before connecting to an IPv4 destination.
    local_address_ipv4: Option<Ipv4Addr>,
    /// Local address bound before connecting to an IPv6 destination.
    local_address_ipv6: Option<Ipv6Addr>,
    /// Value passed to `TcpStream::set_nodelay` after connecting.
    nodelay: bool,
    /// Whether `SO_REUSEADDR` is enabled on new sockets.
    reuse_address: bool,
    /// `SO_SNDBUF` value (clamped to `u32::MAX`); `None` leaves the OS default.
    send_buffer_size: Option<usize>,
    /// `SO_RCVBUF` value (clamped to `u32::MAX`); `None` leaves the OS default.
    recv_buffer_size: Option<usize>,
}
83
84 // ===== impl HttpConnector =====
85
86 impl HttpConnector {
87 /// Construct a new HttpConnector.
new() -> HttpConnector88 pub fn new() -> HttpConnector {
89 HttpConnector::new_with_resolver(GaiResolver::new())
90 }
91 }
92
93 /*
94 #[cfg(feature = "runtime")]
95 impl HttpConnector<TokioThreadpoolGaiResolver> {
96 /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
97 ///
98 /// This resolver **requires** the threadpool runtime to be used.
99 pub fn new_with_tokio_threadpool_resolver() -> Self {
100 HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
101 }
102 }
103 */
104
105 impl<R> HttpConnector<R> {
106 /// Construct a new HttpConnector.
107 ///
108 /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>109 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
110 HttpConnector {
111 config: Arc::new(Config {
112 connect_timeout: None,
113 enforce_http: true,
114 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
115 keep_alive_timeout: None,
116 local_address_ipv4: None,
117 local_address_ipv6: None,
118 nodelay: false,
119 reuse_address: false,
120 send_buffer_size: None,
121 recv_buffer_size: None,
122 }),
123 resolver,
124 }
125 }
126
127 /// Option to enforce all `Uri`s have the `http` scheme.
128 ///
129 /// Enabled by default.
130 #[inline]
enforce_http(&mut self, is_enforced: bool)131 pub fn enforce_http(&mut self, is_enforced: bool) {
132 self.config_mut().enforce_http = is_enforced;
133 }
134
135 /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
136 ///
137 /// If `None`, the option will not be set.
138 ///
139 /// Default is `None`.
140 #[inline]
set_keepalive(&mut self, dur: Option<Duration>)141 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
142 self.config_mut().keep_alive_timeout = dur;
143 }
144
145 /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
146 ///
147 /// Default is `false`.
148 #[inline]
set_nodelay(&mut self, nodelay: bool)149 pub fn set_nodelay(&mut self, nodelay: bool) {
150 self.config_mut().nodelay = nodelay;
151 }
152
153 /// Sets the value of the SO_SNDBUF option on the socket.
154 #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)155 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
156 self.config_mut().send_buffer_size = size;
157 }
158
159 /// Sets the value of the SO_RCVBUF option on the socket.
160 #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)161 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
162 self.config_mut().recv_buffer_size = size;
163 }
164
165 /// Set that all sockets are bound to the configured address before connection.
166 ///
167 /// If `None`, the sockets will not be bound.
168 ///
169 /// Default is `None`.
170 #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)171 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
172 let (v4, v6) = match addr {
173 Some(IpAddr::V4(a)) => (Some(a), None),
174 Some(IpAddr::V6(a)) => (None, Some(a)),
175 _ => (None, None),
176 };
177
178 let cfg = self.config_mut();
179
180 cfg.local_address_ipv4 = v4;
181 cfg.local_address_ipv6 = v6;
182 }
183
184 /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
185 /// preferences) before connection.
186 #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)187 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
188 let cfg = self.config_mut();
189
190 cfg.local_address_ipv4 = Some(addr_ipv4);
191 cfg.local_address_ipv6 = Some(addr_ipv6);
192 }
193
194 /// Set the connect timeout.
195 ///
196 /// If a domain resolves to multiple IP addresses, the timeout will be
197 /// evenly divided across them.
198 ///
199 /// Default is `None`.
200 #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)201 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
202 self.config_mut().connect_timeout = dur;
203 }
204
205 /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
206 ///
207 /// If hostname resolves to both IPv4 and IPv6 addresses and connection
208 /// cannot be established using preferred address family before timeout
209 /// elapses, then connector will in parallel attempt connection using other
210 /// address family.
211 ///
212 /// If `None`, parallel connection attempts are disabled.
213 ///
214 /// Default is 300 milliseconds.
215 ///
216 /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
217 #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)218 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
219 self.config_mut().happy_eyeballs_timeout = dur;
220 }
221
222 /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
223 ///
224 /// Default is `false`.
225 #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self226 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
227 self.config_mut().reuse_address = reuse_address;
228 self
229 }
230
231 // private
232
config_mut(&mut self) -> &mut Config233 fn config_mut(&mut self) -> &mut Config {
234 // If the are HttpConnector clones, this will clone the inner
235 // config. So mutating the config won't ever affect previous
236 // clones.
237 Arc::make_mut(&mut self.config)
238 }
239 }
240
/// Rejection message for a URI whose scheme is not `http` while `enforce_http` is on.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Rejection message for a URI with no scheme (checked when `enforce_http` is off).
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Rejection message for a URI without a host component.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
244
245 // R: Debug required for now to allow adding it to debug output later...
246 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 f.debug_struct("HttpConnector").finish()
249 }
250 }
251
252 impl<R> tower_service::Service<Uri> for HttpConnector<R>
253 where
254 R: Resolve + Clone + Send + Sync + 'static,
255 R::Future: Send,
256 {
257 type Response = TcpStream;
258 type Error = ConnectError;
259 type Future = HttpConnecting<R>;
260
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>261 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
262 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
263 Poll::Ready(Ok(()))
264 }
265
call(&mut self, dst: Uri) -> Self::Future266 fn call(&mut self, dst: Uri) -> Self::Future {
267 let mut self_ = self.clone();
268 HttpConnecting {
269 fut: Box::pin(async move { self_.call_async(dst).await }),
270 _marker: PhantomData,
271 }
272 }
273 }
274
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>275 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
276 trace!(
277 "Http::connect; scheme={:?}, host={:?}, port={:?}",
278 dst.scheme(),
279 dst.host(),
280 dst.port(),
281 );
282
283 if config.enforce_http {
284 if dst.scheme() != Some(&Scheme::HTTP) {
285 return Err(ConnectError {
286 msg: INVALID_NOT_HTTP.into(),
287 cause: None,
288 });
289 }
290 } else if dst.scheme().is_none() {
291 return Err(ConnectError {
292 msg: INVALID_MISSING_SCHEME.into(),
293 cause: None,
294 });
295 }
296
297 let host = match dst.host() {
298 Some(s) => s,
299 None => {
300 return Err(ConnectError {
301 msg: INVALID_MISSING_HOST.into(),
302 cause: None,
303 })
304 }
305 };
306 let port = match dst.port() {
307 Some(port) => port.as_u16(),
308 None => {
309 if dst.scheme() == Some(&Scheme::HTTPS) {
310 443
311 } else {
312 80
313 }
314 }
315 };
316
317 Ok((host, port))
318 }
319
320 impl<R> HttpConnector<R>
321 where
322 R: Resolve,
323 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>324 async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
325 let config = &self.config;
326
327 let (host, port) = get_host_port(config, &dst)?;
328
329 // If the host is already an IP addr (v4 or v6),
330 // skip resolving the dns and start connecting right away.
331 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
332 addrs
333 } else {
334 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
335 .await
336 .map_err(ConnectError::dns)?;
337 let addrs = addrs
338 .map(|mut addr| {
339 addr.set_port(port);
340 addr
341 })
342 .collect();
343 dns::SocketAddrs::new(addrs)
344 };
345
346 let c = ConnectingTcp::new(addrs, config);
347
348 let sock = c.connect().await?;
349
350 if let Err(e) = sock.set_nodelay(config.nodelay) {
351 warn!("tcp set_nodelay error: {}", e);
352 }
353
354 Ok(sock)
355 }
356 }
357
358 impl Connection for TcpStream {
connected(&self) -> Connected359 fn connected(&self) -> Connected {
360 let connected = Connected::new();
361 if let Ok(remote_addr) = self.peer_addr() {
362 connected.extra(HttpInfo { remote_addr })
363 } else {
364 connected
365 }
366 }
367 }
368
369 impl HttpInfo {
370 /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr371 pub fn remote_addr(&self) -> SocketAddr {
372 self.remote_addr
373 }
374 }
375
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed `call_async` future built by `HttpConnector::call`;
        // polling `HttpConnecting` just forwards to it.
        #[pin]
        fut: BoxConnecting,
        // Ties the future to the resolver type `R` without storing a value.
        _marker: PhantomData<R>,
    }
}
390
391 type ConnectResult = Result<TcpStream, ConnectError>;
392 type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
393
394 impl<R: Resolve> Future for HttpConnecting<R> {
395 type Output = ConnectResult;
396
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>397 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
398 self.project().fut.poll(cx)
399 }
400 }
401
// Not publicly exported (so missing_docs doesn't trigger).
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}

impl ConnectError {
    /// Builds an error with a message and an underlying cause.
    fn new<S, E>(msg: S, cause: E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let msg = msg.into();
        let cause = Some(cause.into());
        ConnectError { msg, cause }
    }

    /// Wraps a resolver failure under the message "dns error".
    fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        Self::new("dns error", cause)
    }

    /// Returns a closure (for use with `map_err`) that wraps its argument as
    /// the cause of a `ConnectError` carrying `msg`.
    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        move |cause: E| ConnectError::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    /// With a cause: `ConnectError("<msg>", <cause:?>)`; without one, just
    /// the debug form of the message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => f
                .debug_tuple("ConnectError")
                .field(&self.msg)
                .field(cause)
                .finish(),
            None => fmt::Debug::fmt(&self.msg, f),
        }
    }
}

impl fmt::Display for ConnectError {
    /// Writes `<msg>` or, when a cause is present, `<msg>: <cause>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => write!(f, "{}: {}", self.msg, cause),
            None => f.write_str(&self.msg),
        }
    }
}

impl StdError for ConnectError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        self.cause
            .as_deref()
            .map(|cause| cause as &(dyn StdError + 'static))
    }
}
466
467 struct ConnectingTcp<'a> {
468 preferred: ConnectingTcpRemote,
469 fallback: Option<ConnectingTcpFallback>,
470 config: &'a Config,
471 }
472
473 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self474 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
475 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
476 let (preferred_addrs, fallback_addrs) = remote_addrs
477 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
478 if fallback_addrs.is_empty() {
479 return ConnectingTcp {
480 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
481 fallback: None,
482 config,
483 };
484 }
485
486 ConnectingTcp {
487 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
488 fallback: Some(ConnectingTcpFallback {
489 delay: tokio::time::sleep(fallback_timeout),
490 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
491 }),
492 config,
493 }
494 } else {
495 ConnectingTcp {
496 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
497 fallback: None,
498 config,
499 }
500 }
501 }
502 }
503
/// The secondary ("fallback") half of a Happy Eyeballs connection.
struct ConnectingTcpFallback {
    /// Timer created with `happy_eyeballs_timeout`; the fallback attempt is
    /// only raced against the preferred one after it fires.
    delay: Sleep,
    /// Addresses to try once the delay has elapsed.
    remote: ConnectingTcpRemote,
}

/// An ordered list of addresses to try one after another.
struct ConnectingTcpRemote {
    addrs: dns::SocketAddrs,
    /// Per-address timeout (the overall connect timeout already divided by
    /// the number of addresses).
    connect_timeout: Option<Duration>,
}
513
514 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self515 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
516 let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
517
518 Self {
519 addrs,
520 connect_timeout,
521 }
522 }
523 }
524
525 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>526 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
527 let mut err = None;
528 for addr in &mut self.addrs {
529 debug!("connecting to {}", addr);
530 match connect(&addr, config, self.connect_timeout)?.await {
531 Ok(tcp) => {
532 debug!("connected to {}", addr);
533 return Ok(tcp);
534 }
535 Err(e) => {
536 trace!("connect error for {}: {:?}", addr, e);
537 err = Some(e);
538 }
539 }
540 }
541
542 match err {
543 Some(e) => Err(e),
544 None => Err(ConnectError::new(
545 "tcp connect error",
546 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
547 )),
548 }
549 }
550 }
551
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>552 fn bind_local_address(
553 socket: &socket2::Socket,
554 dst_addr: &SocketAddr,
555 local_addr_ipv4: &Option<Ipv4Addr>,
556 local_addr_ipv6: &Option<Ipv6Addr>,
557 ) -> io::Result<()> {
558 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
559 (SocketAddr::V4(_), Some(addr), _) => {
560 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
561 }
562 (SocketAddr::V6(_), _, Some(addr)) => {
563 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
564 }
565 _ => {
566 if cfg!(windows) {
567 // Windows requires a socket be bound before calling connect
568 let any: SocketAddr = match *dst_addr {
569 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
570 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
571 };
572 socket.bind(&any.into())?;
573 }
574 }
575 }
576
577 Ok(())
578 }
579
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>580 fn connect(
581 addr: &SocketAddr,
582 config: &Config,
583 connect_timeout: Option<Duration>,
584 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
585 // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
586 // keepalive timeout, it would be nice to use that instead of socket2,
587 // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
588 use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
589 use std::convert::TryInto;
590
591 let domain = Domain::for_address(*addr);
592 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
593 .map_err(ConnectError::m("tcp open error"))?;
594
595 // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
596 // responsible for ensuring O_NONBLOCK is set.
597 socket
598 .set_nonblocking(true)
599 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
600
601 if let Some(dur) = config.keep_alive_timeout {
602 let conf = TcpKeepalive::new().with_time(dur);
603 if let Err(e) = socket.set_tcp_keepalive(&conf) {
604 warn!("tcp set_keepalive error: {}", e);
605 }
606 }
607
608 bind_local_address(
609 &socket,
610 addr,
611 &config.local_address_ipv4,
612 &config.local_address_ipv6,
613 )
614 .map_err(ConnectError::m("tcp bind local error"))?;
615
616 #[cfg(unix)]
617 let socket = unsafe {
618 // Safety: `from_raw_fd` is only safe to call if ownership of the raw
619 // file descriptor is transferred. Since we call `into_raw_fd` on the
620 // socket2 socket, it gives up ownership of the fd and will not close
621 // it, so this is safe.
622 use std::os::unix::io::{FromRawFd, IntoRawFd};
623 TcpSocket::from_raw_fd(socket.into_raw_fd())
624 };
625 #[cfg(windows)]
626 let socket = unsafe {
627 // Safety: `from_raw_socket` is only safe to call if ownership of the raw
628 // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
629 // socket2 socket, it gives up ownership of the SOCKET and will not close
630 // it, so this is safe.
631 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
632 TcpSocket::from_raw_socket(socket.into_raw_socket())
633 };
634
635 if config.reuse_address {
636 if let Err(e) = socket.set_reuseaddr(true) {
637 warn!("tcp set_reuse_address error: {}", e);
638 }
639 }
640
641 if let Some(size) = config.send_buffer_size {
642 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
643 warn!("tcp set_buffer_size error: {}", e);
644 }
645 }
646
647 if let Some(size) = config.recv_buffer_size {
648 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
649 warn!("tcp set_recv_buffer_size error: {}", e);
650 }
651 }
652
653 let connect = socket.connect(*addr);
654 Ok(async move {
655 match connect_timeout {
656 Some(dur) => match tokio::time::timeout(dur, connect).await {
657 Ok(Ok(s)) => Ok(s),
658 Ok(Err(e)) => Err(e),
659 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
660 },
661 None => connect.await,
662 }
663 .map_err(ConnectError::m("tcp connect error"))
664 })
665 }
666
impl ConnectingTcp<'_> {
    /// Runs the planned connection attempts.
    ///
    /// Without a fallback, only the preferred addresses are tried. With a
    /// fallback:
    /// - if the preferred attempt finishes (success or error) before the
    ///   Happy Eyeballs delay fires, its result is used; on error the
    ///   fallback attempt is awaited right away;
    /// - otherwise, once the delay fires, both attempts are polled together
    ///   and the first to finish wins; if it failed, the other attempt is
    ///   awaited and its result (success or error) is returned.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                // Race the preferred attempt against the delay timer; the
                // fallback future is not polled until one of them finishes.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
705
#[cfg(test)]
mod tests {
    use std::io;

    use ::http::Uri;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    // Test helper: drives `connector` through the crate-internal sealed
    // `Connect` trait and returns the resulting connection or error.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    // With `enforce_http` on (the default), an `https` URI is rejected during
    // URI validation, before any DNS lookup or socket is created.
    #[tokio::test]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    // Looks for a bindable IPv4 and a bindable IPv6 address among the local
    // interfaces; either may be `None` if none was found.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            // Stop once a usable address of each family is known.
            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    // With `enforce_http` off, a URI that has no scheme at all is still rejected.
    #[tokio::test]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    // Verifies that outgoing connections are bound to the configured local
    // address: the server must see that address as the client's IP.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};
        let _ = pretty_env_logger::try_init();

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        // Listen on the same port on the IPv6 loopback as well.
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            // The closure is only invoked below when at least one local
            // address was found, so the `(None, None)` case cannot occur.
            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    // Timing-based Happy Eyeballs test; only runs with the
    // `__internal_happy_eyeballs_tests` feature enabled.
    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let _ = pretty_env_logger::try_init();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        // Connecting to the local listeners is expected to be near-instant.
        let local_timeout = Duration::default();
        // Measured time for a connect to an unreachable loopback address to fail.
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        // Happy Eyeballs delay: 250ms more than the slower unreachable failure.
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // Each scenario: (addresses in resolution order, expected address
        // family of the winning connection, expected connect duration,
        // whether the scenario needs a working IPv6 route).
        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        keep_alive_timeout: None,
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            // Address family of the connection that won the race.
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        // Blocking connect to `addr:80` with a 1s timeout. Returns whether the
        // address looked routable (connected or timed out rather than failing
        // with another error) and how long the attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }
}
1000