1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project::pin_project;
15 use tokio::net::TcpStream;
16 use tokio::time::Delay;
17
18 use super::dns::{self, resolve, GaiResolver, Resolve};
19 use super::{Connected, Connection};
20 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
21
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Shared settings. Cloning the connector only clones the `Arc`; the
    // setters go through `Arc::make_mut`, so a clone never observes changes
    // made to another clone after the split.
    config: Arc<Config>,
    // Used to look up hosts that are not already IP literals.
    resolver: R,
}
35
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
/// }
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address of the connected `TcpStream`, as reported by `peer_addr()`.
    remote_addr: SocketAddr,
}
68
// Per-connector settings, shared between clones through an `Arc` and
// copied-on-write by `HttpConnector::config_mut`. `Clone` is required for
// `Arc::make_mut`.
#[derive(Clone)]
struct Config {
    // Total connect budget; split evenly across the addresses of each
    // address group (see `ConnectingTcpRemote::new`). `None` = no timeout.
    connect_timeout: Option<Duration>,
    // When true, only `http://` URIs are accepted.
    enforce_http: bool,
    // Delay before racing the fallback address family; `None` disables it.
    happy_eyeballs_timeout: Option<Duration>,
    // Applied with `set_keepalive` after connecting, only when `Some`.
    keep_alive_timeout: Option<Duration>,
    // Local address to bind IPv4 sockets to before connecting.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local address to bind IPv6 sockets to before connecting.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `set_nodelay` on every connected socket.
    nodelay: bool,
    // When true, `SO_REUSEADDR` is enabled before binding/connecting.
    reuse_address: bool,
    // `SO_SNDBUF` size applied after connecting, only when `Some`.
    send_buffer_size: Option<usize>,
    // `SO_RCVBUF` size applied after connecting, only when `Some`.
    recv_buffer_size: Option<usize>,
}
82
83 // ===== impl HttpConnector =====
84
85 impl HttpConnector {
86 /// Construct a new HttpConnector.
new() -> HttpConnector87 pub fn new() -> HttpConnector {
88 HttpConnector::new_with_resolver(GaiResolver::new())
89 }
90 }
91
92 /*
93 #[cfg(feature = "runtime")]
94 impl HttpConnector<TokioThreadpoolGaiResolver> {
95 /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
96 ///
97 /// This resolver **requires** the threadpool runtime to be used.
98 pub fn new_with_tokio_threadpool_resolver() -> Self {
99 HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
100 }
101 }
102 */
103
104 impl<R> HttpConnector<R> {
105 /// Construct a new HttpConnector.
106 ///
107 /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>108 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
109 HttpConnector {
110 config: Arc::new(Config {
111 connect_timeout: None,
112 enforce_http: true,
113 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
114 keep_alive_timeout: None,
115 local_address_ipv4: None,
116 local_address_ipv6: None,
117 nodelay: false,
118 reuse_address: false,
119 send_buffer_size: None,
120 recv_buffer_size: None,
121 }),
122 resolver,
123 }
124 }
125
126 /// Option to enforce all `Uri`s have the `http` scheme.
127 ///
128 /// Enabled by default.
129 #[inline]
enforce_http(&mut self, is_enforced: bool)130 pub fn enforce_http(&mut self, is_enforced: bool) {
131 self.config_mut().enforce_http = is_enforced;
132 }
133
134 /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
135 ///
136 /// If `None`, the option will not be set.
137 ///
138 /// Default is `None`.
139 #[inline]
set_keepalive(&mut self, dur: Option<Duration>)140 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
141 self.config_mut().keep_alive_timeout = dur;
142 }
143
144 /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
145 ///
146 /// Default is `false`.
147 #[inline]
set_nodelay(&mut self, nodelay: bool)148 pub fn set_nodelay(&mut self, nodelay: bool) {
149 self.config_mut().nodelay = nodelay;
150 }
151
152 /// Sets the value of the SO_SNDBUF option on the socket.
153 #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)154 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
155 self.config_mut().send_buffer_size = size;
156 }
157
158 /// Sets the value of the SO_RCVBUF option on the socket.
159 #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)160 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
161 self.config_mut().recv_buffer_size = size;
162 }
163
164 /// Set that all sockets are bound to the configured address before connection.
165 ///
166 /// If `None`, the sockets will not be bound.
167 ///
168 /// Default is `None`.
169 #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)170 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
171 let (v4, v6) = match addr {
172 Some(IpAddr::V4(a)) => (Some(a), None),
173 Some(IpAddr::V6(a)) => (None, Some(a)),
174 _ => (None, None),
175 };
176
177 let cfg = self.config_mut();
178
179 cfg.local_address_ipv4 = v4;
180 cfg.local_address_ipv6 = v6;
181 }
182
183 /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
184 /// preferences) before connection.
185 #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)186 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
187 let cfg = self.config_mut();
188
189 cfg.local_address_ipv4 = Some(addr_ipv4);
190 cfg.local_address_ipv6 = Some(addr_ipv6);
191 }
192
193 /// Set the connect timeout.
194 ///
195 /// If a domain resolves to multiple IP addresses, the timeout will be
196 /// evenly divided across them.
197 ///
198 /// Default is `None`.
199 #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)200 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
201 self.config_mut().connect_timeout = dur;
202 }
203
204 /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
205 ///
206 /// If hostname resolves to both IPv4 and IPv6 addresses and connection
207 /// cannot be established using preferred address family before timeout
208 /// elapses, then connector will in parallel attempt connection using other
209 /// address family.
210 ///
211 /// If `None`, parallel connection attempts are disabled.
212 ///
213 /// Default is 300 milliseconds.
214 ///
215 /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
216 #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)217 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
218 self.config_mut().happy_eyeballs_timeout = dur;
219 }
220
221 /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
222 ///
223 /// Default is `false`.
224 #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self225 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
226 self.config_mut().reuse_address = reuse_address;
227 self
228 }
229
230 // private
231
config_mut(&mut self) -> &mut Config232 fn config_mut(&mut self) -> &mut Config {
233 // If the are HttpConnector clones, this will clone the inner
234 // config. So mutating the config won't ever affect previous
235 // clones.
236 Arc::make_mut(&mut self.config)
237 }
238 }
239
// Error messages for URIs that `call_async` rejects before any DNS lookup or
// TCP connect is attempted. The tests compare against these exact values.

// `enforce_http` is on and the scheme is not `http` (or is absent).
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
// `enforce_http` is off and the URI has no scheme at all.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
// The URI has no host component.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
243
244 // R: Debug required for now to allow adding it to debug output later...
245 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 f.debug_struct("HttpConnector").finish()
248 }
249 }
250
251 impl<R> tower_service::Service<Uri> for HttpConnector<R>
252 where
253 R: Resolve + Clone + Send + Sync + 'static,
254 R::Future: Send,
255 {
256 type Response = TcpStream;
257 type Error = ConnectError;
258 type Future = HttpConnecting<R>;
259
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>260 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
261 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
262 Poll::Ready(Ok(()))
263 }
264
call(&mut self, dst: Uri) -> Self::Future265 fn call(&mut self, dst: Uri) -> Self::Future {
266 let mut self_ = self.clone();
267 HttpConnecting {
268 fut: Box::pin(async move { self_.call_async(dst).await }),
269 _marker: PhantomData,
270 }
271 }
272 }
273
274 impl<R> HttpConnector<R>
275 where
276 R: Resolve,
277 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>278 async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
279 trace!(
280 "Http::connect; scheme={:?}, host={:?}, port={:?}",
281 dst.scheme(),
282 dst.host(),
283 dst.port(),
284 );
285
286 if self.config.enforce_http {
287 if dst.scheme() != Some(&Scheme::HTTP) {
288 return Err(ConnectError {
289 msg: INVALID_NOT_HTTP.into(),
290 cause: None,
291 });
292 }
293 } else if dst.scheme().is_none() {
294 return Err(ConnectError {
295 msg: INVALID_MISSING_SCHEME.into(),
296 cause: None,
297 });
298 }
299
300 let host = match dst.host() {
301 Some(s) => s,
302 None => {
303 return Err(ConnectError {
304 msg: INVALID_MISSING_HOST.into(),
305 cause: None,
306 })
307 }
308 };
309 let port = match dst.port() {
310 Some(port) => port.as_u16(),
311 None => {
312 if dst.scheme() == Some(&Scheme::HTTPS) {
313 443
314 } else {
315 80
316 }
317 }
318 };
319
320 let config = &self.config;
321
322 // If the host is already an IP addr (v4 or v6),
323 // skip resolving the dns and start connecting right away.
324 let addrs = if let Some(addrs) = dns::IpAddrs::try_parse(host, port) {
325 addrs
326 } else {
327 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
328 .await
329 .map_err(ConnectError::dns)?;
330 let addrs = addrs.map(|addr| SocketAddr::new(addr, port)).collect();
331 dns::IpAddrs::new(addrs)
332 };
333
334 let c = ConnectingTcp::new(
335 config.local_address_ipv4,
336 config.local_address_ipv6,
337 addrs,
338 config.connect_timeout,
339 config.happy_eyeballs_timeout,
340 config.reuse_address,
341 );
342
343 let sock = c
344 .connect()
345 .await
346 .map_err(ConnectError::m("tcp connect error"))?;
347
348 if let Some(dur) = config.keep_alive_timeout {
349 sock.set_keepalive(Some(dur))
350 .map_err(ConnectError::m("tcp set_keepalive error"))?;
351 }
352
353 if let Some(size) = config.send_buffer_size {
354 sock.set_send_buffer_size(size)
355 .map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
356 }
357
358 if let Some(size) = config.recv_buffer_size {
359 sock.set_recv_buffer_size(size)
360 .map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
361 }
362
363 sock.set_nodelay(config.nodelay)
364 .map_err(ConnectError::m("tcp set_nodelay error"))?;
365
366 Ok(sock)
367 }
368 }
369
370 impl Connection for TcpStream {
connected(&self) -> Connected371 fn connected(&self) -> Connected {
372 let connected = Connected::new();
373 if let Ok(remote_addr) = self.peer_addr() {
374 connected.extra(HttpInfo { remote_addr })
375 } else {
376 connected
377 }
378 }
379 }
380
381 impl HttpInfo {
382 /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr383 pub fn remote_addr(&self) -> SocketAddr {
384 self.remote_addr
385 }
386 }
387
// Not publicly exported (so missing_docs doesn't trigger).
//
// We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
// so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
// (and thus we can change the type in the future).
#[must_use = "futures do nothing unless polled"]
#[pin_project]
#[allow(missing_debug_implementations)]
pub struct HttpConnecting<R> {
    // The boxed `call_async` future that does the actual work.
    #[pin]
    fut: BoxConnecting,
    // Ties this type to the resolver type `R` without storing a value of it.
    _marker: PhantomData<R>,
}
401
// Outcome of a single `HttpConnector` connection attempt.
type ConnectResult = Result<TcpStream, ConnectError>;
// Type-erased, `Send` future producing a `ConnectResult`; wrapped by `HttpConnecting`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
404
405 impl<R: Resolve> Future for HttpConnecting<R> {
406 type Output = ConnectResult;
407
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>408 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
409 self.project().fut.poll(cx)
410 }
411 }
412
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error produced by `HttpConnector`: a short message naming the step that
// failed, plus the underlying error when there is one.
pub struct ConnectError {
    // Description of the failing step (e.g. "dns error", "tcp connect error").
    msg: Box<str>,
    // Underlying error, if any; exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
418
419 impl ConnectError {
new<S, E>(msg: S, cause: E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,420 fn new<S, E>(msg: S, cause: E) -> ConnectError
421 where
422 S: Into<Box<str>>,
423 E: Into<Box<dyn StdError + Send + Sync>>,
424 {
425 ConnectError {
426 msg: msg.into(),
427 cause: Some(cause.into()),
428 }
429 }
430
dns<E>(cause: E) -> ConnectError where E: Into<Box<dyn StdError + Send + Sync>>,431 fn dns<E>(cause: E) -> ConnectError
432 where
433 E: Into<Box<dyn StdError + Send + Sync>>,
434 {
435 ConnectError::new("dns error", cause)
436 }
437
m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,438 fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
439 where
440 S: Into<Box<str>>,
441 E: Into<Box<dyn StdError + Send + Sync>>,
442 {
443 move |cause| ConnectError::new(msg, cause)
444 }
445 }
446
447 impl fmt::Debug for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result448 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449 if let Some(ref cause) = self.cause {
450 f.debug_tuple("ConnectError")
451 .field(&self.msg)
452 .field(cause)
453 .finish()
454 } else {
455 self.msg.fmt(f)
456 }
457 }
458 }
459
460 impl fmt::Display for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result461 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462 f.write_str(&self.msg)?;
463
464 if let Some(ref cause) = self.cause {
465 write!(f, ": {}", cause)?;
466 }
467
468 Ok(())
469 }
470 }
471
472 impl StdError for ConnectError {
source(&self) -> Option<&(dyn StdError + 'static)>473 fn source(&self) -> Option<&(dyn StdError + 'static)> {
474 self.cause.as_ref().map(|e| &**e as _)
475 }
476 }
477
// One full TCP connection attempt over a set of resolved addresses, with an
// optional Happy Eyeballs (RFC 6555) fallback group raced after a delay.
struct ConnectingTcp {
    // Local address to bind IPv4 sockets to, if any.
    local_addr_ipv4: Option<Ipv4Addr>,
    // Local address to bind IPv6 sockets to, if any.
    local_addr_ipv6: Option<Ipv6Addr>,
    // Addresses tried first (or all addresses when there is no fallback).
    preferred: ConnectingTcpRemote,
    // `None` when Happy Eyeballs is disabled or there are no fallback addresses.
    fallback: Option<ConnectingTcpFallback>,
    // Whether to set `SO_REUSEADDR` on each socket.
    reuse_address: bool,
}
485
486 impl ConnectingTcp {
new( local_addr_ipv4: Option<Ipv4Addr>, local_addr_ipv6: Option<Ipv6Addr>, remote_addrs: dns::IpAddrs, connect_timeout: Option<Duration>, fallback_timeout: Option<Duration>, reuse_address: bool, ) -> ConnectingTcp487 fn new(
488 local_addr_ipv4: Option<Ipv4Addr>,
489 local_addr_ipv6: Option<Ipv6Addr>,
490 remote_addrs: dns::IpAddrs,
491 connect_timeout: Option<Duration>,
492 fallback_timeout: Option<Duration>,
493 reuse_address: bool,
494 ) -> ConnectingTcp {
495 if let Some(fallback_timeout) = fallback_timeout {
496 let (preferred_addrs, fallback_addrs) =
497 remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6);
498 if fallback_addrs.is_empty() {
499 return ConnectingTcp {
500 local_addr_ipv4,
501 local_addr_ipv6,
502 preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
503 fallback: None,
504 reuse_address,
505 };
506 }
507
508 ConnectingTcp {
509 local_addr_ipv4,
510 local_addr_ipv6,
511 preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
512 fallback: Some(ConnectingTcpFallback {
513 delay: tokio::time::delay_for(fallback_timeout),
514 remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout),
515 }),
516 reuse_address,
517 }
518 } else {
519 ConnectingTcp {
520 local_addr_ipv4,
521 local_addr_ipv6,
522 preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout),
523 fallback: None,
524 reuse_address,
525 }
526 }
527 }
528 }
529
// The Happy Eyeballs fallback group: its addresses are only raced against
// the preferred group once `delay` fires (or right away if the preferred
// group fails first; see `ConnectingTcp::connect`).
struct ConnectingTcpFallback {
    // Timer created with `delay_for(fallback_timeout)` in `ConnectingTcp::new`.
    delay: Delay,
    // Addresses of the fallback group.
    remote: ConnectingTcpRemote,
}
534
// An ordered group of remote addresses that are tried one after another.
struct ConnectingTcpRemote {
    // Remaining addresses to try, in order.
    addrs: dns::IpAddrs,
    // Per-address timeout: the overall budget already divided by the number
    // of addresses in this group (see `ConnectingTcpRemote::new`).
    connect_timeout: Option<Duration>,
}
539
540 impl ConnectingTcpRemote {
new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self541 fn new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self {
542 let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
543
544 Self {
545 addrs,
546 connect_timeout,
547 }
548 }
549 }
550
551 impl ConnectingTcpRemote {
connect( &mut self, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, reuse_address: bool, ) -> io::Result<TcpStream>552 async fn connect(
553 &mut self,
554 local_addr_ipv4: &Option<Ipv4Addr>,
555 local_addr_ipv6: &Option<Ipv6Addr>,
556 reuse_address: bool,
557 ) -> io::Result<TcpStream> {
558 let mut err = None;
559 for addr in &mut self.addrs {
560 debug!("connecting to {}", addr);
561 match connect(
562 &addr,
563 local_addr_ipv4,
564 local_addr_ipv6,
565 reuse_address,
566 self.connect_timeout,
567 )?
568 .await
569 {
570 Ok(tcp) => {
571 debug!("connected to {}", addr);
572 return Ok(tcp);
573 }
574 Err(e) => {
575 trace!("connect error for {}: {:?}", addr, e);
576 err = Some(e);
577 }
578 }
579 }
580
581 match err {
582 Some(e) => Err(e),
583 None => Err(std::io::Error::new(
584 std::io::ErrorKind::NotConnected,
585 "Network unreachable",
586 )),
587 }
588 }
589 }
590
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>591 fn bind_local_address(
592 socket: &socket2::Socket,
593 dst_addr: &SocketAddr,
594 local_addr_ipv4: &Option<Ipv4Addr>,
595 local_addr_ipv6: &Option<Ipv6Addr>,
596 ) -> io::Result<()> {
597 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
598 (SocketAddr::V4(_), Some(addr), _) => {
599 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
600 }
601 (SocketAddr::V6(_), _, Some(addr)) => {
602 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
603 }
604 _ => {
605 if cfg!(windows) {
606 // Windows requires a socket be bound before calling connect
607 let any: SocketAddr = match *dst_addr {
608 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
609 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
610 };
611 socket.bind(&any.into())?;
612 }
613 }
614 }
615
616 Ok(())
617 }
618
connect( addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, reuse_address: bool, connect_timeout: Option<Duration>, ) -> io::Result<impl Future<Output = io::Result<TcpStream>>>619 fn connect(
620 addr: &SocketAddr,
621 local_addr_ipv4: &Option<Ipv4Addr>,
622 local_addr_ipv6: &Option<Ipv6Addr>,
623 reuse_address: bool,
624 connect_timeout: Option<Duration>,
625 ) -> io::Result<impl Future<Output = io::Result<TcpStream>>> {
626 use socket2::{Domain, Protocol, Socket, Type};
627 let domain = match *addr {
628 SocketAddr::V4(_) => Domain::ipv4(),
629 SocketAddr::V6(_) => Domain::ipv6(),
630 };
631 let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
632
633 if reuse_address {
634 socket.set_reuse_address(true)?;
635 }
636
637 bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?;
638
639 let addr = *addr;
640
641 let std_tcp = socket.into_tcp_stream();
642
643 Ok(async move {
644 let connect = TcpStream::connect_std(std_tcp, &addr);
645 match connect_timeout {
646 Some(dur) => match tokio::time::timeout(dur, connect).await {
647 Ok(Ok(s)) => Ok(s),
648 Ok(Err(e)) => Err(e),
649 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
650 },
651 None => connect.await,
652 }
653 })
654 }
655
impl ConnectingTcp {
    // Runs the planned connection attempt to completion.
    //
    // Without a fallback group, the preferred addresses are simply tried in
    // order. With one, the preferred attempt starts immediately and is raced
    // against the fallback delay timer:
    //   * if the preferred attempt finishes before the timer, its result is
    //     used, and on error the fallback attempt is awaited instead;
    //   * if the timer fires first, both attempts are raced, the first to
    //     finish wins, and on error the other one is awaited instead.
    async fn connect(mut self) -> io::Result<TcpStream> {
        // Borrow the shared socket settings once; `fallback` is moved out below.
        let Self {
            ref local_addr_ipv4,
            ref local_addr_ipv6,
            reuse_address,
            ..
        } = self;
        match self.fallback {
            None => {
                self.preferred
                    .connect(local_addr_ipv4, local_addr_ipv6, reuse_address)
                    .await
            }
            Some(mut fallback) => {
                let preferred_fut =
                    self.preferred
                        .connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
                futures_util::pin_mut!(preferred_fut);

                // Created here but not polled until the delay fires (or the
                // preferred attempt fails first).
                let fallback_fut =
                    fallback
                        .remote
                        .connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
                futures_util::pin_mut!(fallback_fut);

                // `future` is whichever attempt did not produce `result`.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback.delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
706
#[cfg(test)]
mod tests {
    use std::io;

    use ::http::Uri;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::HttpConnector;

    // Test helper: drives `connector` through the crate-internal sealed
    // `Connect` trait (passing the `Internal` token) and returns the
    // resulting connection or error.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    // `enforce_http` is on by default, so an `https` URI must be rejected
    // with the "scheme is not http" message.
    #[tokio::test]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    // Scans the host's network interfaces (via `pnet`) for addresses that a
    // `TcpListener` can actually bind to. A later bindable address overwrites
    // an earlier one of the same family, and the scan stops once one of each
    // family is known. Either element may be `None`.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet::datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    // With `enforce_http(false)`, a URI without any scheme must be rejected
    // with the "scheme is missing" message.
    #[tokio::test]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    // Verifies that the configured local address is used as the source
    // address: loopback listeners on the same port (IPv4 and IPv6) accept the
    // connection and check the client's peer IP.
    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        // Connects to `dst` with the discovered local address(es) configured,
        // then asserts that `server` saw `expected_ip` as the client address.
        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    // Timing-based Happy Eyeballs test; only runs with the
    // `__internal_happy_eyeballs_tests` feature (ignored otherwise).
    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let _ = pretty_env_logger::try_init();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let mut rt = tokio::runtime::Builder::new()
            .enable_io()
            .enable_time()
            .basic_scheduler()
            .build()
            .unwrap();

        // Expected durations are calibrated against this machine: how long a
        // failing connect to each "unreachable" address actually takes.
        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // Each scenario: (addresses in resolver order, expected IP family of
        // the winning connection, expected elapsed time, needs IPv6 access).
        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            // The clock starts after `ConnectingTcp::new`, so only the
            // connect itself is timed. No overall connect timeout is set.
            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let connecting_tcp = ConnectingTcp::new(
                        None,
                        None,
                        dns::IpAddrs::new(addrs),
                        None,
                        Some(fallback_timeout),
                        false,
                    );
                    let start = Instant::now();
                    Ok::<_, io::Error>((start, connecting_tcp.connect().await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        // Loopback-range address with no listener bound to it; the test
        // expects connects to it to fail (timing measured by `measure_connect`).
        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        // Address the test expects connects to fail against (see above).
        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        // Blocking connect to `addr:80` with a 1s timeout. Returns whether the
        // address counts as reachable (connected, or timed out, meaning the
        // attempt did not fail fast) and how long the attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }
}
997