1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 use tracing::{debug, trace, warn};
18
19 use super::dns::{self, resolve, GaiResolver, Resolve};
20 use super::{Connected, Connection};
21 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
22
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// Cloning an `HttpConnector` is cheap: the configuration lives behind an
/// `Arc` and is only copied when a clone is later modified through one of the
/// `set_*` methods.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Shared connection settings; mutated via `Arc::make_mut` in `config_mut`.
    config: Arc<Config>,
    // Resolver used for hosts that are not already IP literals.
    resolver: R,
}
37
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
/// }
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address of the TCP connection, as reported by `TcpStream::peer_addr`.
    remote_addr: SocketAddr,
}
70
/// Connection settings shared by an `HttpConnector` and its clones.
#[derive(Clone)]
struct Config {
    // Total connect timeout; split evenly across the addresses of each attempt list.
    connect_timeout: Option<Duration>,
    // When true, only `http` URIs are accepted; otherwise any URI with a scheme.
    enforce_http: bool,
    // Delay before the other address family is tried in parallel; `None` disables it.
    happy_eyeballs_timeout: Option<Duration>,
    // TCP keepalive time applied to every socket, if set.
    keep_alive_timeout: Option<Duration>,
    // Local address to bind before connecting to IPv4 destinations.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local address to bind before connecting to IPv6 destinations.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value for `TCP_NODELAY` on established connections.
    nodelay: bool,
    // Whether to enable `SO_REUSEADDR` on new sockets.
    reuse_address: bool,
    // `SO_SNDBUF` size, if set.
    send_buffer_size: Option<usize>,
    // `SO_RCVBUF` size, if set.
    recv_buffer_size: Option<usize>,
}
84
85 // ===== impl HttpConnector =====
86
87 impl HttpConnector {
88 /// Construct a new HttpConnector.
new() -> HttpConnector89 pub fn new() -> HttpConnector {
90 HttpConnector::new_with_resolver(GaiResolver::new())
91 }
92 }
93
94 /*
95 #[cfg(feature = "runtime")]
96 impl HttpConnector<TokioThreadpoolGaiResolver> {
97 /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
98 ///
99 /// This resolver **requires** the threadpool runtime to be used.
100 pub fn new_with_tokio_threadpool_resolver() -> Self {
101 HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
102 }
103 }
104 */
105
106 impl<R> HttpConnector<R> {
107 /// Construct a new HttpConnector.
108 ///
109 /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>110 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
111 HttpConnector {
112 config: Arc::new(Config {
113 connect_timeout: None,
114 enforce_http: true,
115 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
116 keep_alive_timeout: None,
117 local_address_ipv4: None,
118 local_address_ipv6: None,
119 nodelay: false,
120 reuse_address: false,
121 send_buffer_size: None,
122 recv_buffer_size: None,
123 }),
124 resolver,
125 }
126 }
127
128 /// Option to enforce all `Uri`s have the `http` scheme.
129 ///
130 /// Enabled by default.
131 #[inline]
enforce_http(&mut self, is_enforced: bool)132 pub fn enforce_http(&mut self, is_enforced: bool) {
133 self.config_mut().enforce_http = is_enforced;
134 }
135
136 /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
137 ///
138 /// If `None`, the option will not be set.
139 ///
140 /// Default is `None`.
141 #[inline]
set_keepalive(&mut self, dur: Option<Duration>)142 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
143 self.config_mut().keep_alive_timeout = dur;
144 }
145
146 /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
147 ///
148 /// Default is `false`.
149 #[inline]
set_nodelay(&mut self, nodelay: bool)150 pub fn set_nodelay(&mut self, nodelay: bool) {
151 self.config_mut().nodelay = nodelay;
152 }
153
154 /// Sets the value of the SO_SNDBUF option on the socket.
155 #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)156 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
157 self.config_mut().send_buffer_size = size;
158 }
159
160 /// Sets the value of the SO_RCVBUF option on the socket.
161 #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)162 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
163 self.config_mut().recv_buffer_size = size;
164 }
165
166 /// Set that all sockets are bound to the configured address before connection.
167 ///
168 /// If `None`, the sockets will not be bound.
169 ///
170 /// Default is `None`.
171 #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)172 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
173 let (v4, v6) = match addr {
174 Some(IpAddr::V4(a)) => (Some(a), None),
175 Some(IpAddr::V6(a)) => (None, Some(a)),
176 _ => (None, None),
177 };
178
179 let cfg = self.config_mut();
180
181 cfg.local_address_ipv4 = v4;
182 cfg.local_address_ipv6 = v6;
183 }
184
185 /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
186 /// preferences) before connection.
187 #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)188 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
189 let cfg = self.config_mut();
190
191 cfg.local_address_ipv4 = Some(addr_ipv4);
192 cfg.local_address_ipv6 = Some(addr_ipv6);
193 }
194
195 /// Set the connect timeout.
196 ///
197 /// If a domain resolves to multiple IP addresses, the timeout will be
198 /// evenly divided across them.
199 ///
200 /// Default is `None`.
201 #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)202 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
203 self.config_mut().connect_timeout = dur;
204 }
205
206 /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
207 ///
208 /// If hostname resolves to both IPv4 and IPv6 addresses and connection
209 /// cannot be established using preferred address family before timeout
210 /// elapses, then connector will in parallel attempt connection using other
211 /// address family.
212 ///
213 /// If `None`, parallel connection attempts are disabled.
214 ///
215 /// Default is 300 milliseconds.
216 ///
217 /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
218 #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)219 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
220 self.config_mut().happy_eyeballs_timeout = dur;
221 }
222
223 /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
224 ///
225 /// Default is `false`.
226 #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self227 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
228 self.config_mut().reuse_address = reuse_address;
229 self
230 }
231
232 // private
233
config_mut(&mut self) -> &mut Config234 fn config_mut(&mut self) -> &mut Config {
235 // If the are HttpConnector clones, this will clone the inner
236 // config. So mutating the config won't ever affect previous
237 // clones.
238 Arc::make_mut(&mut self.config)
239 }
240 }
241
// Messages used by `get_host_port` when it rejects a destination `Uri`.
// The unit tests compare error messages against these values directly.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
245
246 // R: Debug required for now to allow adding it to debug output later...
247 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 f.debug_struct("HttpConnector").finish()
250 }
251 }
252
253 impl<R> tower_service::Service<Uri> for HttpConnector<R>
254 where
255 R: Resolve + Clone + Send + Sync + 'static,
256 R::Future: Send,
257 {
258 type Response = TcpStream;
259 type Error = ConnectError;
260 type Future = HttpConnecting<R>;
261
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>262 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
263 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
264 Poll::Ready(Ok(()))
265 }
266
call(&mut self, dst: Uri) -> Self::Future267 fn call(&mut self, dst: Uri) -> Self::Future {
268 let mut self_ = self.clone();
269 HttpConnecting {
270 fut: Box::pin(async move { self_.call_async(dst).await }),
271 _marker: PhantomData,
272 }
273 }
274 }
275
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>276 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
277 trace!(
278 "Http::connect; scheme={:?}, host={:?}, port={:?}",
279 dst.scheme(),
280 dst.host(),
281 dst.port(),
282 );
283
284 if config.enforce_http {
285 if dst.scheme() != Some(&Scheme::HTTP) {
286 return Err(ConnectError {
287 msg: INVALID_NOT_HTTP.into(),
288 cause: None,
289 });
290 }
291 } else if dst.scheme().is_none() {
292 return Err(ConnectError {
293 msg: INVALID_MISSING_SCHEME.into(),
294 cause: None,
295 });
296 }
297
298 let host = match dst.host() {
299 Some(s) => s,
300 None => {
301 return Err(ConnectError {
302 msg: INVALID_MISSING_HOST.into(),
303 cause: None,
304 })
305 }
306 };
307 let port = match dst.port() {
308 Some(port) => port.as_u16(),
309 None => {
310 if dst.scheme() == Some(&Scheme::HTTPS) {
311 443
312 } else {
313 80
314 }
315 }
316 };
317
318 Ok((host, port))
319 }
320
321 impl<R> HttpConnector<R>
322 where
323 R: Resolve,
324 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>325 async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
326 let config = &self.config;
327
328 let (host, port) = get_host_port(config, &dst)?;
329
330 // If the host is already an IP addr (v4 or v6),
331 // skip resolving the dns and start connecting right away.
332 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
333 addrs
334 } else {
335 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
336 .await
337 .map_err(ConnectError::dns)?;
338 let addrs = addrs
339 .map(|mut addr| {
340 addr.set_port(port);
341 addr
342 })
343 .collect();
344 dns::SocketAddrs::new(addrs)
345 };
346
347 let c = ConnectingTcp::new(addrs, config);
348
349 let sock = c.connect().await?;
350
351 if let Err(e) = sock.set_nodelay(config.nodelay) {
352 warn!("tcp set_nodelay error: {}", e);
353 }
354
355 Ok(sock)
356 }
357 }
358
359 impl Connection for TcpStream {
connected(&self) -> Connected360 fn connected(&self) -> Connected {
361 let connected = Connected::new();
362 if let Ok(remote_addr) = self.peer_addr() {
363 connected.extra(HttpInfo { remote_addr })
364 } else {
365 connected
366 }
367 }
368 }
369
370 impl HttpInfo {
371 /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr372 pub fn remote_addr(&self) -> SocketAddr {
373 self.remote_addr
374 }
375 }
376
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    //
    // `R` is only held in `PhantomData`; the actual work is the boxed future
    // created in `HttpConnector::call`.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        #[pin]
        fut: BoxConnecting,
        _marker: PhantomData<R>,
    }
}
391
// Outcome of a connection attempt.
type ConnectResult = Result<TcpStream, ConnectError>;
// Type-erased, heap-allocated connecting future stored in `HttpConnecting`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
394
395 impl<R: Resolve> Future for HttpConnecting<R> {
396 type Output = ConnectResult;
397
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>398 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
399 self.project().fut.poll(cx)
400 }
401 }
402
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error produced while validating, resolving or connecting to a destination:
// a human-readable message plus an optional underlying cause.
pub struct ConnectError {
    msg: Box<str>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}

impl ConnectError {
    // Builds an error carrying both a message and its underlying cause.
    fn new<S, E>(msg: S, cause: E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let msg = msg.into();
        let cause = Some(cause.into());
        ConnectError { msg, cause }
    }

    // Wraps a resolver failure under the "dns error" message.
    fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        Self::new("dns error", cause)
    }

    // Returns a closure suitable for `map_err` that attaches `msg` to a cause.
    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
    where
        S: Into<Box<str>>,
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        move |cause| Self::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    // With a cause: `ConnectError("msg", cause)`; without: the quoted message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => f
                .debug_tuple("ConnectError")
                .field(&self.msg)
                .field(cause)
                .finish(),
            None => fmt::Debug::fmt(&self.msg, f),
        }
    }
}

impl fmt::Display for ConnectError {
    // Renders `msg`, followed by `: cause` when a cause is present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.cause {
            Some(ref cause) => write!(f, "{}: {}", self.msg, cause),
            None => f.write_str(&self.msg),
        }
    }
}

impl StdError for ConnectError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let cause: &(dyn StdError + 'static) = self.cause.as_deref()?;
        Some(cause)
    }
}
467
468 struct ConnectingTcp<'a> {
469 preferred: ConnectingTcpRemote,
470 fallback: Option<ConnectingTcpFallback>,
471 config: &'a Config,
472 }
473
474 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self475 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
476 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
477 let (preferred_addrs, fallback_addrs) = remote_addrs
478 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
479 if fallback_addrs.is_empty() {
480 return ConnectingTcp {
481 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
482 fallback: None,
483 config,
484 };
485 }
486
487 ConnectingTcp {
488 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
489 fallback: Some(ConnectingTcpFallback {
490 delay: tokio::time::sleep(fallback_timeout),
491 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
492 }),
493 config,
494 }
495 } else {
496 ConnectingTcp {
497 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
498 fallback: None,
499 config,
500 }
501 }
502 }
503 }
504
// The secondary (other address family) half of a Happy Eyeballs attempt.
struct ConnectingTcpFallback {
    // Once this fires, the fallback attempts race the preferred ones. They also
    // start immediately if the preferred attempts fail before it fires.
    delay: Sleep,
    // Fallback-family addresses, tried in order.
    remote: ConnectingTcpRemote,
}
509
510 struct ConnectingTcpRemote {
511 addrs: dns::SocketAddrs,
512 connect_timeout: Option<Duration>,
513 }
514
515 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self516 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
517 let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
518
519 Self {
520 addrs,
521 connect_timeout,
522 }
523 }
524 }
525
526 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>527 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
528 let mut err = None;
529 for addr in &mut self.addrs {
530 debug!("connecting to {}", addr);
531 match connect(&addr, config, self.connect_timeout)?.await {
532 Ok(tcp) => {
533 debug!("connected to {}", addr);
534 return Ok(tcp);
535 }
536 Err(e) => {
537 trace!("connect error for {}: {:?}", addr, e);
538 err = Some(e);
539 }
540 }
541 }
542
543 match err {
544 Some(e) => Err(e),
545 None => Err(ConnectError::new(
546 "tcp connect error",
547 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
548 )),
549 }
550 }
551 }
552
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>553 fn bind_local_address(
554 socket: &socket2::Socket,
555 dst_addr: &SocketAddr,
556 local_addr_ipv4: &Option<Ipv4Addr>,
557 local_addr_ipv6: &Option<Ipv6Addr>,
558 ) -> io::Result<()> {
559 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
560 (SocketAddr::V4(_), Some(addr), _) => {
561 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
562 }
563 (SocketAddr::V6(_), _, Some(addr)) => {
564 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
565 }
566 _ => {
567 if cfg!(windows) {
568 // Windows requires a socket be bound before calling connect
569 let any: SocketAddr = match *dst_addr {
570 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
571 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
572 };
573 socket.bind(&any.into())?;
574 }
575 }
576 }
577
578 Ok(())
579 }
580
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>581 fn connect(
582 addr: &SocketAddr,
583 config: &Config,
584 connect_timeout: Option<Duration>,
585 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
586 // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
587 // keepalive timeout, it would be nice to use that instead of socket2,
588 // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
589 use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
590 use std::convert::TryInto;
591
592 let domain = Domain::for_address(*addr);
593 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
594 .map_err(ConnectError::m("tcp open error"))?;
595
596 // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
597 // responsible for ensuring O_NONBLOCK is set.
598 socket
599 .set_nonblocking(true)
600 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
601
602 if let Some(dur) = config.keep_alive_timeout {
603 let conf = TcpKeepalive::new().with_time(dur);
604 if let Err(e) = socket.set_tcp_keepalive(&conf) {
605 warn!("tcp set_keepalive error: {}", e);
606 }
607 }
608
609 bind_local_address(
610 &socket,
611 addr,
612 &config.local_address_ipv4,
613 &config.local_address_ipv6,
614 )
615 .map_err(ConnectError::m("tcp bind local error"))?;
616
617 #[cfg(unix)]
618 let socket = unsafe {
619 // Safety: `from_raw_fd` is only safe to call if ownership of the raw
620 // file descriptor is transferred. Since we call `into_raw_fd` on the
621 // socket2 socket, it gives up ownership of the fd and will not close
622 // it, so this is safe.
623 use std::os::unix::io::{FromRawFd, IntoRawFd};
624 TcpSocket::from_raw_fd(socket.into_raw_fd())
625 };
626 #[cfg(windows)]
627 let socket = unsafe {
628 // Safety: `from_raw_socket` is only safe to call if ownership of the raw
629 // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
630 // socket2 socket, it gives up ownership of the SOCKET and will not close
631 // it, so this is safe.
632 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
633 TcpSocket::from_raw_socket(socket.into_raw_socket())
634 };
635
636 if config.reuse_address {
637 if let Err(e) = socket.set_reuseaddr(true) {
638 warn!("tcp set_reuse_address error: {}", e);
639 }
640 }
641
642 if let Some(size) = config.send_buffer_size {
643 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
644 warn!("tcp set_buffer_size error: {}", e);
645 }
646 }
647
648 if let Some(size) = config.recv_buffer_size {
649 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
650 warn!("tcp set_recv_buffer_size error: {}", e);
651 }
652 }
653
654 let connect = socket.connect(*addr);
655 Ok(async move {
656 match connect_timeout {
657 Some(dur) => match tokio::time::timeout(dur, connect).await {
658 Ok(Ok(s)) => Ok(s),
659 Ok(Err(e)) => Err(e),
660 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
661 },
662 None => connect.await,
663 }
664 .map_err(ConnectError::m("tcp connect error"))
665 })
666 }
667
impl ConnectingTcp<'_> {
    /// Drives the connection attempt to completion.
    ///
    /// Without a fallback, the preferred addresses are simply tried in order.
    /// With a fallback, the preferred attempts race the Happy Eyeballs delay:
    /// - If the preferred attempts finish first, their result is used. On
    ///   error, the fallback attempts run immediately instead.
    /// - If the delay fires first, the preferred and fallback attempts are
    ///   polled concurrently and the first to finish wins. If that one failed,
    ///   the other is awaited instead.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Created now, but not polled until the delay fires or the
                // preferred attempts fail.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
706
#[cfg(test)]
mod tests {
    use std::io;

    use ::http::Uri;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    // Drives `connector` through the crate-internal `Connect` trait and
    // returns the resulting connection or error.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    // With `enforce_http` on (the default), an `https` URI must be rejected.
    #[tokio::test]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    // Returns one IPv4 and one IPv6 address from the local interfaces that
    // can actually be bound (either may be `None`).
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    // With `enforce_http` off, a URI without any scheme must still be rejected.
    #[tokio::test]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    // Checks that the configured local addresses are used as the client's
    // source IP, as observed by a loopback listener.
    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};
        let _ = pretty_env_logger::try_init();

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    // Timing-based check of the Happy Eyeballs logic in `ConnectingTcp`.
    // Only runs with the `__internal_happy_eyeballs_tests` feature.
    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let _ = pretty_env_logger::try_init();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // Each scenario: (addresses in resolution order, expected IP family of
        // the winning connection, expected duration, needs IPv6 network access).
        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        keep_alive_timeout: None,
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        // Blocking connect to `addr:80` with a 1s timeout. Returns whether the
        // address counts as reachable and how long the attempt took. Success or
        // a timeout counts as reachable; any other error does not.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }
}
1001