1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 use tracing::{debug, trace, warn};
18
19 use super::dns::{self, resolve, GaiResolver, Resolve};
20 use super::{Connected, Connection};
21 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
22
23 /// A connector for the `http` scheme.
24 ///
25 /// Performs DNS resolution in a thread pool, and then connects over TCP.
26 ///
27 /// # Note
28 ///
29 /// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
30 /// transport information such as the remote socket address used.
31 #[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
32 #[derive(Clone)]
33 pub struct HttpConnector<R = GaiResolver> {
34 config: Arc<Config>,
35 resolver: R,
36 }
37
/// Transport details recorded for connections made by an `HttpConnector`.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
/// }
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// Only [`HttpConnector`](HttpConnector) attaches this value. With any other
/// connector it will be missing from the response extensions; refer to that
/// connector's documentation for any "extra" information it provides.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Socket address of the peer the transport is connected to.
    remote_addr: SocketAddr,
}
70
/// Connector settings, shared behind an `Arc` by an `HttpConnector` and its clones.
#[derive(Clone)]
struct Config {
    /// Connect timeout for a whole address group, split evenly across its addresses.
    connect_timeout: Option<Duration>,
    /// Reject destinations whose scheme is not `http`.
    enforce_http: bool,
    /// Delay before the other address family is tried in parallel; `None` disables it.
    happy_eyeballs_timeout: Option<Duration>,
    /// Keepalive time applied to new sockets; `None` leaves keepalive untouched.
    keep_alive_timeout: Option<Duration>,
    /// Local address that sockets to IPv4 destinations are bound to.
    local_address_ipv4: Option<Ipv4Addr>,
    /// Local address that sockets to IPv6 destinations are bound to.
    local_address_ipv6: Option<Ipv6Addr>,
    /// Value passed to `set_nodelay` once a stream is connected.
    nodelay: bool,
    /// Whether `SO_REUSEADDR` is enabled on new sockets.
    reuse_address: bool,
    /// `SO_SNDBUF` size; `None` keeps the OS default.
    send_buffer_size: Option<usize>,
    /// `SO_RCVBUF` size; `None` keeps the OS default.
    recv_buffer_size: Option<usize>,
}
84
85 // ===== impl HttpConnector =====
86
87 impl HttpConnector {
88 /// Construct a new HttpConnector.
new() -> HttpConnector89 pub fn new() -> HttpConnector {
90 HttpConnector::new_with_resolver(GaiResolver::new())
91 }
92 }
93
94 /*
95 #[cfg(feature = "runtime")]
96 impl HttpConnector<TokioThreadpoolGaiResolver> {
97 /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
98 ///
99 /// This resolver **requires** the threadpool runtime to be used.
100 pub fn new_with_tokio_threadpool_resolver() -> Self {
101 HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
102 }
103 }
104 */
105
106 impl<R> HttpConnector<R> {
107 /// Construct a new HttpConnector.
108 ///
109 /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>110 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
111 HttpConnector {
112 config: Arc::new(Config {
113 connect_timeout: None,
114 enforce_http: true,
115 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
116 keep_alive_timeout: None,
117 local_address_ipv4: None,
118 local_address_ipv6: None,
119 nodelay: false,
120 reuse_address: false,
121 send_buffer_size: None,
122 recv_buffer_size: None,
123 }),
124 resolver,
125 }
126 }
127
128 /// Option to enforce all `Uri`s have the `http` scheme.
129 ///
130 /// Enabled by default.
131 #[inline]
enforce_http(&mut self, is_enforced: bool)132 pub fn enforce_http(&mut self, is_enforced: bool) {
133 self.config_mut().enforce_http = is_enforced;
134 }
135
136 /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
137 ///
138 /// If `None`, the option will not be set.
139 ///
140 /// Default is `None`.
141 #[inline]
set_keepalive(&mut self, dur: Option<Duration>)142 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
143 self.config_mut().keep_alive_timeout = dur;
144 }
145
146 /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
147 ///
148 /// Default is `false`.
149 #[inline]
set_nodelay(&mut self, nodelay: bool)150 pub fn set_nodelay(&mut self, nodelay: bool) {
151 self.config_mut().nodelay = nodelay;
152 }
153
154 /// Sets the value of the SO_SNDBUF option on the socket.
155 #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)156 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
157 self.config_mut().send_buffer_size = size;
158 }
159
160 /// Sets the value of the SO_RCVBUF option on the socket.
161 #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)162 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
163 self.config_mut().recv_buffer_size = size;
164 }
165
166 /// Set that all sockets are bound to the configured address before connection.
167 ///
168 /// If `None`, the sockets will not be bound.
169 ///
170 /// Default is `None`.
171 #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)172 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
173 let (v4, v6) = match addr {
174 Some(IpAddr::V4(a)) => (Some(a), None),
175 Some(IpAddr::V6(a)) => (None, Some(a)),
176 _ => (None, None),
177 };
178
179 let cfg = self.config_mut();
180
181 cfg.local_address_ipv4 = v4;
182 cfg.local_address_ipv6 = v6;
183 }
184
185 /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
186 /// preferences) before connection.
187 #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)188 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
189 let cfg = self.config_mut();
190
191 cfg.local_address_ipv4 = Some(addr_ipv4);
192 cfg.local_address_ipv6 = Some(addr_ipv6);
193 }
194
195 /// Set the connect timeout.
196 ///
197 /// If a domain resolves to multiple IP addresses, the timeout will be
198 /// evenly divided across them.
199 ///
200 /// Default is `None`.
201 #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)202 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
203 self.config_mut().connect_timeout = dur;
204 }
205
206 /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
207 ///
208 /// If hostname resolves to both IPv4 and IPv6 addresses and connection
209 /// cannot be established using preferred address family before timeout
210 /// elapses, then connector will in parallel attempt connection using other
211 /// address family.
212 ///
213 /// If `None`, parallel connection attempts are disabled.
214 ///
215 /// Default is 300 milliseconds.
216 ///
217 /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
218 #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)219 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
220 self.config_mut().happy_eyeballs_timeout = dur;
221 }
222
223 /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
224 ///
225 /// Default is `false`.
226 #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self227 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
228 self.config_mut().reuse_address = reuse_address;
229 self
230 }
231
232 // private
233
config_mut(&mut self) -> &mut Config234 fn config_mut(&mut self) -> &mut Config {
235 // If the are HttpConnector clones, this will clone the inner
236 // config. So mutating the config won't ever affect previous
237 // clones.
238 Arc::make_mut(&mut self.config)
239 }
240 }
241
/// Error message used when `enforce_http` is on and the scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message used when `enforce_http` is off and the URI has no scheme.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message used when the URI has no host.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
245
246 // R: Debug required for now to allow adding it to debug output later...
247 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 f.debug_struct("HttpConnector").finish()
250 }
251 }
252
253 impl<R> tower_service::Service<Uri> for HttpConnector<R>
254 where
255 R: Resolve + Clone + Send + Sync + 'static,
256 R::Future: Send,
257 {
258 type Response = TcpStream;
259 type Error = ConnectError;
260 type Future = HttpConnecting<R>;
261
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>262 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
263 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
264 Poll::Ready(Ok(()))
265 }
266
call(&mut self, dst: Uri) -> Self::Future267 fn call(&mut self, dst: Uri) -> Self::Future {
268 let mut self_ = self.clone();
269 HttpConnecting {
270 fut: Box::pin(async move { self_.call_async(dst).await }),
271 _marker: PhantomData,
272 }
273 }
274 }
275
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>276 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
277 trace!(
278 "Http::connect; scheme={:?}, host={:?}, port={:?}",
279 dst.scheme(),
280 dst.host(),
281 dst.port(),
282 );
283
284 if config.enforce_http {
285 if dst.scheme() != Some(&Scheme::HTTP) {
286 return Err(ConnectError {
287 msg: INVALID_NOT_HTTP.into(),
288 cause: None,
289 });
290 }
291 } else if dst.scheme().is_none() {
292 return Err(ConnectError {
293 msg: INVALID_MISSING_SCHEME.into(),
294 cause: None,
295 });
296 }
297
298 let host = match dst.host() {
299 Some(s) => s,
300 None => {
301 return Err(ConnectError {
302 msg: INVALID_MISSING_HOST.into(),
303 cause: None,
304 })
305 }
306 };
307 let port = match dst.port() {
308 Some(port) => port.as_u16(),
309 None => {
310 if dst.scheme() == Some(&Scheme::HTTPS) {
311 443
312 } else {
313 80
314 }
315 }
316 };
317
318 Ok((host, port))
319 }
320
321 impl<R> HttpConnector<R>
322 where
323 R: Resolve,
324 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>325 async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
326 let config = &self.config;
327
328 let (host, port) = get_host_port(config, &dst)?;
329 let host = host.trim_start_matches('[').trim_end_matches(']');
330
331 // If the host is already an IP addr (v4 or v6),
332 // skip resolving the dns and start connecting right away.
333 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
334 addrs
335 } else {
336 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
337 .await
338 .map_err(ConnectError::dns)?;
339 let addrs = addrs
340 .map(|mut addr| {
341 addr.set_port(port);
342 addr
343 })
344 .collect();
345 dns::SocketAddrs::new(addrs)
346 };
347
348 let c = ConnectingTcp::new(addrs, config);
349
350 let sock = c.connect().await?;
351
352 if let Err(e) = sock.set_nodelay(config.nodelay) {
353 warn!("tcp set_nodelay error: {}", e);
354 }
355
356 Ok(sock)
357 }
358 }
359
360 impl Connection for TcpStream {
connected(&self) -> Connected361 fn connected(&self) -> Connected {
362 let connected = Connected::new();
363 if let Ok(remote_addr) = self.peer_addr() {
364 connected.extra(HttpInfo { remote_addr })
365 } else {
366 connected
367 }
368 }
369 }
370
371 impl HttpInfo {
372 /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr373 pub fn remote_addr(&self) -> SocketAddr {
374 self.remote_addr
375 }
376 }
377
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // The future returned by `HttpConnector::call`.
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed connection future that does all the work.
        #[pin]
        fut: BoxConnecting,
        // Keeps the resolver type `R` in the signature without storing one.
        _marker: PhantomData<R>,
    }
}
392
393 type ConnectResult = Result<TcpStream, ConnectError>;
394 type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
395
396 impl<R: Resolve> Future for HttpConnecting<R> {
397 type Output = ConnectResult;
398
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>399 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
400 self.project().fut.poll(cx)
401 }
402 }
403
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error produced by `HttpConnector`: a short description of the step that
// failed plus, optionally, the underlying error.
pub struct ConnectError {
    // Human-readable summary, e.g. "tcp connect error".
    msg: Box<str>,
    // Underlying error, if any; returned by `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
409
410 impl ConnectError {
new<S, E>(msg: S, cause: E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,411 fn new<S, E>(msg: S, cause: E) -> ConnectError
412 where
413 S: Into<Box<str>>,
414 E: Into<Box<dyn StdError + Send + Sync>>,
415 {
416 ConnectError {
417 msg: msg.into(),
418 cause: Some(cause.into()),
419 }
420 }
421
dns<E>(cause: E) -> ConnectError where E: Into<Box<dyn StdError + Send + Sync>>,422 fn dns<E>(cause: E) -> ConnectError
423 where
424 E: Into<Box<dyn StdError + Send + Sync>>,
425 {
426 ConnectError::new("dns error", cause)
427 }
428
m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,429 fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
430 where
431 S: Into<Box<str>>,
432 E: Into<Box<dyn StdError + Send + Sync>>,
433 {
434 move |cause| ConnectError::new(msg, cause)
435 }
436 }
437
438 impl fmt::Debug for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440 if let Some(ref cause) = self.cause {
441 f.debug_tuple("ConnectError")
442 .field(&self.msg)
443 .field(cause)
444 .finish()
445 } else {
446 self.msg.fmt(f)
447 }
448 }
449 }
450
451 impl fmt::Display for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result452 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453 f.write_str(&self.msg)?;
454
455 if let Some(ref cause) = self.cause {
456 write!(f, ": {}", cause)?;
457 }
458
459 Ok(())
460 }
461 }
462
463 impl StdError for ConnectError {
source(&self) -> Option<&(dyn StdError + 'static)>464 fn source(&self) -> Option<&(dyn StdError + 'static)> {
465 self.cause.as_ref().map(|e| &**e as _)
466 }
467 }
468
469 struct ConnectingTcp<'a> {
470 preferred: ConnectingTcpRemote,
471 fallback: Option<ConnectingTcpFallback>,
472 config: &'a Config,
473 }
474
475 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self476 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
477 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
478 let (preferred_addrs, fallback_addrs) = remote_addrs
479 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
480 if fallback_addrs.is_empty() {
481 return ConnectingTcp {
482 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
483 fallback: None,
484 config,
485 };
486 }
487
488 ConnectingTcp {
489 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
490 fallback: Some(ConnectingTcpFallback {
491 delay: tokio::time::sleep(fallback_timeout),
492 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
493 }),
494 config,
495 }
496 } else {
497 ConnectingTcp {
498 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
499 fallback: None,
500 config,
501 }
502 }
503 }
504 }
505
506 struct ConnectingTcpFallback {
507 delay: Sleep,
508 remote: ConnectingTcpRemote,
509 }
510
511 struct ConnectingTcpRemote {
512 addrs: dns::SocketAddrs,
513 connect_timeout: Option<Duration>,
514 }
515
516 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self517 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
518 let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
519
520 Self {
521 addrs,
522 connect_timeout,
523 }
524 }
525 }
526
527 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>528 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
529 let mut err = None;
530 for addr in &mut self.addrs {
531 debug!("connecting to {}", addr);
532 match connect(&addr, config, self.connect_timeout)?.await {
533 Ok(tcp) => {
534 debug!("connected to {}", addr);
535 return Ok(tcp);
536 }
537 Err(e) => {
538 trace!("connect error for {}: {:?}", addr, e);
539 err = Some(e);
540 }
541 }
542 }
543
544 match err {
545 Some(e) => Err(e),
546 None => Err(ConnectError::new(
547 "tcp connect error",
548 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
549 )),
550 }
551 }
552 }
553
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>554 fn bind_local_address(
555 socket: &socket2::Socket,
556 dst_addr: &SocketAddr,
557 local_addr_ipv4: &Option<Ipv4Addr>,
558 local_addr_ipv6: &Option<Ipv6Addr>,
559 ) -> io::Result<()> {
560 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
561 (SocketAddr::V4(_), Some(addr), _) => {
562 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
563 }
564 (SocketAddr::V6(_), _, Some(addr)) => {
565 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
566 }
567 _ => {
568 if cfg!(windows) {
569 // Windows requires a socket be bound before calling connect
570 let any: SocketAddr = match *dst_addr {
571 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
572 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
573 };
574 socket.bind(&any.into())?;
575 }
576 }
577 }
578
579 Ok(())
580 }
581
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>582 fn connect(
583 addr: &SocketAddr,
584 config: &Config,
585 connect_timeout: Option<Duration>,
586 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
587 // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
588 // keepalive timeout, it would be nice to use that instead of socket2,
589 // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
590 use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
591 use std::convert::TryInto;
592
593 let domain = Domain::for_address(*addr);
594 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
595 .map_err(ConnectError::m("tcp open error"))?;
596
597 // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
598 // responsible for ensuring O_NONBLOCK is set.
599 socket
600 .set_nonblocking(true)
601 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
602
603 if let Some(dur) = config.keep_alive_timeout {
604 let conf = TcpKeepalive::new().with_time(dur);
605 if let Err(e) = socket.set_tcp_keepalive(&conf) {
606 warn!("tcp set_keepalive error: {}", e);
607 }
608 }
609
610 bind_local_address(
611 &socket,
612 addr,
613 &config.local_address_ipv4,
614 &config.local_address_ipv6,
615 )
616 .map_err(ConnectError::m("tcp bind local error"))?;
617
618 #[cfg(unix)]
619 let socket = unsafe {
620 // Safety: `from_raw_fd` is only safe to call if ownership of the raw
621 // file descriptor is transferred. Since we call `into_raw_fd` on the
622 // socket2 socket, it gives up ownership of the fd and will not close
623 // it, so this is safe.
624 use std::os::unix::io::{FromRawFd, IntoRawFd};
625 TcpSocket::from_raw_fd(socket.into_raw_fd())
626 };
627 #[cfg(windows)]
628 let socket = unsafe {
629 // Safety: `from_raw_socket` is only safe to call if ownership of the raw
630 // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
631 // socket2 socket, it gives up ownership of the SOCKET and will not close
632 // it, so this is safe.
633 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
634 TcpSocket::from_raw_socket(socket.into_raw_socket())
635 };
636
637 if config.reuse_address {
638 if let Err(e) = socket.set_reuseaddr(true) {
639 warn!("tcp set_reuse_address error: {}", e);
640 }
641 }
642
643 if let Some(size) = config.send_buffer_size {
644 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
645 warn!("tcp set_buffer_size error: {}", e);
646 }
647 }
648
649 if let Some(size) = config.recv_buffer_size {
650 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
651 warn!("tcp set_recv_buffer_size error: {}", e);
652 }
653 }
654
655 let connect = socket.connect(*addr);
656 Ok(async move {
657 match connect_timeout {
658 Some(dur) => match tokio::time::timeout(dur, connect).await {
659 Ok(Ok(s)) => Ok(s),
660 Ok(Err(e)) => Err(e),
661 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
662 },
663 None => connect.await,
664 }
665 .map_err(ConnectError::m("tcp connect error"))
666 })
667 }
668
impl ConnectingTcp<'_> {
    /// Carries out the connection attempt planned by `ConnectingTcp::new`.
    ///
    /// Without a fallback group, the preferred addresses are simply tried in
    /// order. With one (Happy Eyeballs), the preferred attempt runs alone
    /// until it finishes or the fallback delay elapses:
    ///
    /// * if the preferred attempt finishes first, its result is kept and the
    ///   fallback attempt is the one left over;
    /// * if the delay elapses first, the preferred and fallback attempts are
    ///   raced; the first to finish is kept and the other is left over.
    ///
    /// A kept success is returned as is. A kept error is discarded and the
    /// left-over attempt is awaited instead. In the first case above, this
    /// means the fallback starts immediately rather than after the delay.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                // These are `async fn` futures: creating them does no work
                // until they are polled.
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                // Pinned on the stack so `select` (which takes `Unpin`
                // futures) can poll it.
                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
707
#[cfg(test)]
mod tests {
    use std::io;

    use ::http::Uri;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    // Connects `connector` to `dst` through the crate-internal sealed
    // `Connect` trait and returns the resulting connection or error.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    // With `enforce_http` at its default (enabled), an `https` URI is rejected.
    #[tokio::test]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    // Picks a local interface address of each family that a `TcpListener`
    // can bind to. Either may be `None`; scanning stops once both are found.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    // With `enforce_http` disabled, a URI must still carry some scheme.
    #[tokio::test]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    // Checks that connections originate from the configured local address.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};
        let _ = pretty_env_logger::try_init();

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        // Loopback listeners for both families, sharing one port number.
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        // Connects to `dst` with the local address(es) configured, then
        // asserts the client address accepted by `server` is `expected_ip`.
        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                // Only invoked below when at least one address was found.
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    // Exercises Happy Eyeballs address-family selection and timing against
    // loopback, "unreachable" and "slow" addresses. Ignored unless the
    // `__internal_happy_eyeballs_tests` feature is enabled.
    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let _ = pretty_env_logger::try_init();
        // Listen on the same port on both loopback addresses.
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        // Expected connect durations: ~0 for loopback, the measured failure
        // time for the "unreachable" addresses, and a Happy Eyeballs delay
        // longer than either failure time.
        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // Each entry: (resolved addresses in order, expected IP family of the
        // winning connection, expected connect duration, needs IPv6 access).
        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        keep_alive_timeout: None,
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            // Address family of the peer that actually won.
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        // IPv4 loopback; the listener above accepts here.
        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        // IPv6 loopback; the listener above accepts here.
        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        // Address with no listener, expected to fail without a long wait.
        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        // Address with no listener, expected to fail without a long wait.
        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        // Makes a blocking connect to `addr`:80 with a 1s timeout. Returns
        // whether the address counts as reachable (the connect succeeded or
        // timed out rather than failing outright) and how long it took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }
}
1002