1 use futures_util::future::Either;
2 #[cfg(feature = "__tls")]
3 use http::header::HeaderValue;
4 use http::uri::{Authority, Scheme};
5 use http::Uri;
6 use hyper::client::connect::{
7     dns::{GaiResolver, Name},
8     Connected, Connection,
9 };
10 use hyper::service::Service;
11 #[cfg(feature = "native-tls-crate")]
12 use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
13 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14 
15 use pin_project_lite::pin_project;
16 use std::io::IoSlice;
17 use std::net::IpAddr;
18 use std::pin::Pin;
19 use std::sync::Arc;
20 use std::task::{Context, Poll};
21 use std::time::Duration;
22 use std::{collections::HashMap, io};
23 use std::{future::Future, net::SocketAddr};
24 
25 #[cfg(feature = "default-tls")]
26 use self::native_tls_conn::NativeTlsConn;
27 #[cfg(feature = "__rustls")]
28 use self::rustls_tls_conn::RustlsTlsConn;
29 #[cfg(feature = "trust-dns")]
30 use crate::dns::TrustDnsResolver;
31 use crate::error::BoxError;
32 use crate::proxy::{Proxy, ProxyScheme};
33 
34 #[derive(Clone)]
35 pub(crate) enum HttpConnector {
36     Gai(hyper::client::HttpConnector),
37     GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
38     #[cfg(feature = "trust-dns")]
39     TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
40     #[cfg(feature = "trust-dns")]
41     TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
42 }
43 
44 impl HttpConnector {
new_gai() -> Self45     pub(crate) fn new_gai() -> Self {
46         Self::Gai(hyper::client::HttpConnector::new())
47     }
48 
new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self49     pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self {
50         let gai = hyper::client::connect::dns::GaiResolver::new();
51         let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
52         Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
53             overridden_resolver,
54         ))
55     }
56 
57     #[cfg(feature = "trust-dns")]
new_trust_dns() -> crate::Result<HttpConnector>58     pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
59         TrustDnsResolver::new()
60             .map(hyper::client::HttpConnector::new_with_resolver)
61             .map(Self::TrustDns)
62             .map_err(crate::error::builder)
63     }
64 
65     #[cfg(feature = "trust-dns")]
new_trust_dns_with_overrides( overrides: HashMap<String, SocketAddr>, ) -> crate::Result<HttpConnector>66     pub(crate) fn new_trust_dns_with_overrides(
67         overrides: HashMap<String, SocketAddr>,
68     ) -> crate::Result<HttpConnector> {
69         TrustDnsResolver::new()
70             .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
71             .map(hyper::client::HttpConnector::new_with_resolver)
72             .map(Self::TrustDnsWithOverrides)
73             .map_err(crate::error::builder)
74     }
75 }
76 
// Generates forwarding methods on `HttpConnector`.
//
// For every `fn name(&mut self, args...) [-> Ret];` signature passed in, this emits a
// private inherent method that matches on the active `HttpConnector` variant and calls
// the same-named method on the wrapped `hyper::client::HttpConnector`, passing every
// argument through unchanged and returning its result.
macro_rules! impl_http_connector {
    ($(fn $name:ident(&mut self, $($par_name:ident: $par_type:ty),*)$( -> $return:ty)?;)+) => {
        // Some forwarded setters (e.g. `enforce_http`) are only called under certain
        // feature combinations, so unused ones must not trigger warnings.
        #[allow(dead_code)]
        impl HttpConnector {
            $(
                fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
                    // Every variant wraps a hyper connector exposing the same method.
                    match self {
                        Self::Gai(resolver) => resolver.$name($($par_name),*),
                        Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
                        #[cfg(feature = "trust-dns")]
                        Self::TrustDns(resolver) => resolver.$name($($par_name),*),
                        #[cfg(feature = "trust-dns")]
                        Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
                    }
                }
            )+
        }
    };
}
96 
97 impl_http_connector! {
98     fn set_local_address(&mut self, addr: Option<IpAddr>);
99     fn enforce_http(&mut self, is_enforced: bool);
100     fn set_nodelay(&mut self, nodelay: bool);
101     fn set_keepalive(&mut self, dur: Option<Duration>);
102 }
103 
104 impl Service<Uri> for HttpConnector {
105     type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
106     type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
107     #[cfg(feature = "trust-dns")]
108     type Future =
109         Either<
110             Either<
111                 <hyper::client::HttpConnector as Service<Uri>>::Future,
112                 <hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
113                     Uri,
114                 >>::Future,
115             >,
116             Either<
117                     <hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
118                 <hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
119                  >
120         >;
121     #[cfg(not(feature = "trust-dns"))]
122     type Future =
123         Either<
124             Either<
125                 <hyper::client::HttpConnector as Service<Uri>>::Future,
126                 <hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
127                     Uri,
128                 >>::Future,
129             >,
130             Either<
131                 <hyper::client::HttpConnector as Service<Uri>>::Future,
132                 <hyper::client::HttpConnector as Service<Uri>>::Future,
133             >,
134         >;
135 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>136     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137         match self {
138             Self::Gai(resolver) => resolver.poll_ready(cx),
139             Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
140             #[cfg(feature = "trust-dns")]
141             Self::TrustDns(resolver) => resolver.poll_ready(cx),
142             #[cfg(feature = "trust-dns")]
143             Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
144         }
145     }
146 
call(&mut self, dst: Uri) -> Self::Future147     fn call(&mut self, dst: Uri) -> Self::Future {
148         match self {
149             Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
150             Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
151             #[cfg(feature = "trust-dns")]
152             Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
153             #[cfg(feature = "trust-dns")]
154             Self::TrustDnsWithOverrides(resolver) => {
155                 Either::Right(Either::Right(resolver.call(dst)))
156             }
157         }
158     }
159 }
160 
161 #[derive(Clone)]
162 pub(crate) struct Connector {
163     inner: Inner,
164     proxies: Arc<Vec<Proxy>>,
165     verbose: verbose::Wrapper,
166     timeout: Option<Duration>,
167     #[cfg(feature = "__tls")]
168     nodelay: bool,
169     #[cfg(feature = "__tls")]
170     user_agent: Option<HeaderValue>,
171 }
172 
173 #[derive(Clone)]
174 enum Inner {
175     #[cfg(not(feature = "__tls"))]
176     Http(HttpConnector),
177     #[cfg(feature = "default-tls")]
178     DefaultTls(HttpConnector, TlsConnector),
179     #[cfg(feature = "__rustls")]
180     RustlsTls {
181         http: HttpConnector,
182         tls: Arc<rustls::ClientConfig>,
183         tls_proxy: Arc<rustls::ClientConfig>,
184     },
185 }
186 
187 impl Connector {
188     #[cfg(not(feature = "__tls"))]
new<T>( mut http: HttpConnector, proxies: Arc<Vec<Proxy>>, local_addr: T, nodelay: bool, ) -> Connector where T: Into<Option<IpAddr>>,189     pub(crate) fn new<T>(
190         mut http: HttpConnector,
191         proxies: Arc<Vec<Proxy>>,
192         local_addr: T,
193         nodelay: bool,
194     ) -> Connector
195     where
196         T: Into<Option<IpAddr>>,
197     {
198         http.set_local_address(local_addr.into());
199         http.set_nodelay(nodelay);
200         Connector {
201             inner: Inner::Http(http),
202             verbose: verbose::OFF,
203             proxies,
204             timeout: None,
205         }
206     }
207 
208     #[cfg(feature = "default-tls")]
new_default_tls<T>( http: HttpConnector, tls: TlsConnectorBuilder, proxies: Arc<Vec<Proxy>>, user_agent: Option<HeaderValue>, local_addr: T, nodelay: bool, ) -> crate::Result<Connector> where T: Into<Option<IpAddr>>,209     pub(crate) fn new_default_tls<T>(
210         http: HttpConnector,
211         tls: TlsConnectorBuilder,
212         proxies: Arc<Vec<Proxy>>,
213         user_agent: Option<HeaderValue>,
214         local_addr: T,
215         nodelay: bool,
216     ) -> crate::Result<Connector>
217     where
218         T: Into<Option<IpAddr>>,
219     {
220         let tls = tls.build().map_err(crate::error::builder)?;
221         Ok(Self::from_built_default_tls(
222             http, tls, proxies, user_agent, local_addr, nodelay,
223         ))
224     }
225 
226     #[cfg(feature = "default-tls")]
from_built_default_tls<T>( mut http: HttpConnector, tls: TlsConnector, proxies: Arc<Vec<Proxy>>, user_agent: Option<HeaderValue>, local_addr: T, nodelay: bool, ) -> Connector where T: Into<Option<IpAddr>>,227     pub(crate) fn from_built_default_tls<T>(
228         mut http: HttpConnector,
229         tls: TlsConnector,
230         proxies: Arc<Vec<Proxy>>,
231         user_agent: Option<HeaderValue>,
232         local_addr: T,
233         nodelay: bool,
234     ) -> Connector
235     where
236         T: Into<Option<IpAddr>>,
237     {
238         http.set_local_address(local_addr.into());
239         http.enforce_http(false);
240 
241         Connector {
242             inner: Inner::DefaultTls(http, tls),
243             proxies,
244             verbose: verbose::OFF,
245             timeout: None,
246             nodelay,
247             user_agent,
248         }
249     }
250 
251     #[cfg(feature = "__rustls")]
new_rustls_tls<T>( mut http: HttpConnector, tls: rustls::ClientConfig, proxies: Arc<Vec<Proxy>>, user_agent: Option<HeaderValue>, local_addr: T, nodelay: bool, ) -> Connector where T: Into<Option<IpAddr>>,252     pub(crate) fn new_rustls_tls<T>(
253         mut http: HttpConnector,
254         tls: rustls::ClientConfig,
255         proxies: Arc<Vec<Proxy>>,
256         user_agent: Option<HeaderValue>,
257         local_addr: T,
258         nodelay: bool,
259     ) -> Connector
260     where
261         T: Into<Option<IpAddr>>,
262     {
263         http.set_local_address(local_addr.into());
264         http.enforce_http(false);
265 
266         let (tls, tls_proxy) = if proxies.is_empty() {
267             let tls = Arc::new(tls);
268             (tls.clone(), tls)
269         } else {
270             let mut tls_proxy = tls.clone();
271             tls_proxy.alpn_protocols.clear();
272             (Arc::new(tls), Arc::new(tls_proxy))
273         };
274 
275         Connector {
276             inner: Inner::RustlsTls {
277                 http,
278                 tls,
279                 tls_proxy,
280             },
281             proxies,
282             verbose: verbose::OFF,
283             timeout: None,
284             nodelay,
285             user_agent,
286         }
287     }
288 
set_timeout(&mut self, timeout: Option<Duration>)289     pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
290         self.timeout = timeout;
291     }
292 
set_verbose(&mut self, enabled: bool)293     pub(crate) fn set_verbose(&mut self, enabled: bool) {
294         self.verbose.0 = enabled;
295     }
296 
297     #[cfg(feature = "socks")]
connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError>298     async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
299         let dns = match proxy {
300             ProxyScheme::Socks5 {
301                 remote_dns: false, ..
302             } => socks::DnsResolve::Local,
303             ProxyScheme::Socks5 {
304                 remote_dns: true, ..
305             } => socks::DnsResolve::Proxy,
306             ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
307                 unreachable!("connect_socks is only called for socks proxies");
308             }
309         };
310 
311         match &self.inner {
312             #[cfg(feature = "default-tls")]
313             Inner::DefaultTls(_http, tls) => {
314                 if dst.scheme() == Some(&Scheme::HTTPS) {
315                     let host = dst.host().ok_or("no host in url")?.to_string();
316                     let conn = socks::connect(proxy, dst, dns).await?;
317                     let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
318                     let io = tls_connector.connect(&host, conn).await?;
319                     return Ok(Conn {
320                         inner: self.verbose.wrap(NativeTlsConn { inner: io }),
321                         is_proxy: false,
322                     });
323                 }
324             }
325             #[cfg(feature = "__rustls")]
326             Inner::RustlsTls { tls_proxy, .. } => {
327                 if dst.scheme() == Some(&Scheme::HTTPS) {
328                     use tokio_rustls::webpki::DNSNameRef;
329                     use tokio_rustls::TlsConnector as RustlsConnector;
330 
331                     let tls = tls_proxy.clone();
332                     let host = dst.host().ok_or("no host in url")?.to_string();
333                     let conn = socks::connect(proxy, dst, dns).await?;
334                     let dnsname = DNSNameRef::try_from_ascii_str(&host)
335                         .map(|dnsname| dnsname.to_owned())
336                         .map_err(|_| "Invalid DNS Name")?;
337                     let io = RustlsConnector::from(tls)
338                         .connect(dnsname.as_ref(), conn)
339                         .await?;
340                     return Ok(Conn {
341                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
342                         is_proxy: false,
343                     });
344                 }
345             }
346             #[cfg(not(feature = "__tls"))]
347             Inner::Http(_) => (),
348         }
349 
350         socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
351             inner: self.verbose.wrap(tcp),
352             is_proxy: false,
353         })
354     }
355 
connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError>356     async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
357         match self.inner {
358             #[cfg(not(feature = "__tls"))]
359             Inner::Http(mut http) => {
360                 let io = http.call(dst).await?;
361                 Ok(Conn {
362                     inner: self.verbose.wrap(io),
363                     is_proxy,
364                 })
365             }
366             #[cfg(feature = "default-tls")]
367             Inner::DefaultTls(http, tls) => {
368                 let mut http = http.clone();
369 
370                 // Disable Nagle's algorithm for TLS handshake
371                 //
372                 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
373                 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
374                     http.set_nodelay(true);
375                 }
376 
377                 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
378                 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
379                 let io = http.call(dst).await?;
380 
381                 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
382                     if !self.nodelay {
383                         stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
384                     }
385                     Ok(Conn {
386                         inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
387                         is_proxy,
388                     })
389                 } else {
390                     Ok(Conn {
391                         inner: self.verbose.wrap(io),
392                         is_proxy,
393                     })
394                 }
395             }
396             #[cfg(feature = "__rustls")]
397             Inner::RustlsTls { http, tls, .. } => {
398                 let mut http = http.clone();
399 
400                 // Disable Nagle's algorithm for TLS handshake
401                 //
402                 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
403                 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
404                     http.set_nodelay(true);
405                 }
406 
407                 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
408                 let io = http.call(dst).await?;
409 
410                 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
411                     if !self.nodelay {
412                         let (io, _) = stream.get_ref();
413                         io.set_nodelay(false)?;
414                     }
415                     Ok(Conn {
416                         inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
417                         is_proxy,
418                     })
419                 } else {
420                     Ok(Conn {
421                         inner: self.verbose.wrap(io),
422                         is_proxy,
423                     })
424                 }
425             }
426         }
427     }
428 
connect_via_proxy( self, dst: Uri, proxy_scheme: ProxyScheme, ) -> Result<Conn, BoxError>429     async fn connect_via_proxy(
430         self,
431         dst: Uri,
432         proxy_scheme: ProxyScheme,
433     ) -> Result<Conn, BoxError> {
434         log::debug!("proxy({:?}) intercepts '{:?}'", proxy_scheme, dst);
435 
436         let (proxy_dst, _auth) = match proxy_scheme {
437             ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
438             ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
439             #[cfg(feature = "socks")]
440             ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
441         };
442 
443         #[cfg(feature = "__tls")]
444         let auth = _auth;
445 
446         match &self.inner {
447             #[cfg(feature = "default-tls")]
448             Inner::DefaultTls(http, tls) => {
449                 if dst.scheme() == Some(&Scheme::HTTPS) {
450                     let host = dst.host().to_owned();
451                     let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
452                     let http = http.clone();
453                     let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
454                     let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
455                     let conn = http.call(proxy_dst).await?;
456                     log::trace!("tunneling HTTPS over proxy");
457                     let tunneled = tunnel(
458                         conn,
459                         host.ok_or("no host in url")?.to_string(),
460                         port,
461                         self.user_agent.clone(),
462                         auth,
463                     )
464                     .await?;
465                     let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
466                     let io = tls_connector
467                         .connect(&host.ok_or("no host in url")?, tunneled)
468                         .await?;
469                     return Ok(Conn {
470                         inner: self.verbose.wrap(NativeTlsConn { inner: io }),
471                         is_proxy: false,
472                     });
473                 }
474             }
475             #[cfg(feature = "__rustls")]
476             Inner::RustlsTls {
477                 http,
478                 tls,
479                 tls_proxy,
480             } => {
481                 if dst.scheme() == Some(&Scheme::HTTPS) {
482                     use tokio_rustls::webpki::DNSNameRef;
483                     use tokio_rustls::TlsConnector as RustlsConnector;
484 
485                     let host = dst.host().ok_or("no host in url")?.to_string();
486                     let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
487                     let http = http.clone();
488                     let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
489                     let tls = tls.clone();
490                     let conn = http.call(proxy_dst).await?;
491                     log::trace!("tunneling HTTPS over proxy");
492                     let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
493                         .map(|dnsname| dnsname.to_owned())
494                         .map_err(|_| "Invalid DNS Name");
495                     let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
496                     let dnsname = maybe_dnsname?;
497                     let io = RustlsConnector::from(tls)
498                         .connect(dnsname.as_ref(), tunneled)
499                         .await?;
500 
501                     return Ok(Conn {
502                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
503                         is_proxy: false,
504                     });
505                 }
506             }
507             #[cfg(not(feature = "__tls"))]
508             Inner::Http(_) => (),
509         }
510 
511         self.connect_with_maybe_proxy(proxy_dst, true).await
512     }
513 
set_keepalive(&mut self, dur: Option<Duration>)514     pub fn set_keepalive(&mut self, dur: Option<Duration>) {
515         match &mut self.inner {
516             #[cfg(feature = "default-tls")]
517             Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
518             #[cfg(feature = "__rustls")]
519             Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
520             #[cfg(not(feature = "__tls"))]
521             Inner::Http(http) => http.set_keepalive(dur),
522         }
523     }
524 }
525 
into_uri(scheme: Scheme, host: Authority) -> Uri526 fn into_uri(scheme: Scheme, host: Authority) -> Uri {
527     // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`?
528     http::Uri::builder()
529         .scheme(scheme)
530         .authority(host)
531         .path_and_query(http::uri::PathAndQuery::from_static("/"))
532         .build()
533         .expect("scheme and authority is valid Uri")
534 }
535 
with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError> where F: Future<Output = Result<T, BoxError>>,536 async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
537 where
538     F: Future<Output = Result<T, BoxError>>,
539 {
540     if let Some(to) = timeout {
541         match tokio::time::timeout(to, f).await {
542             Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
543             Ok(Ok(try_res)) => Ok(try_res),
544             Ok(Err(e)) => Err(e),
545         }
546     } else {
547         f.await
548     }
549 }
550 
551 impl Service<Uri> for Connector {
552     type Response = Conn;
553     type Error = BoxError;
554     type Future = Connecting;
555 
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>556     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
557         Poll::Ready(Ok(()))
558     }
559 
call(&mut self, dst: Uri) -> Self::Future560     fn call(&mut self, dst: Uri) -> Self::Future {
561         log::debug!("starting new connection: {:?}", dst);
562         let timeout = self.timeout;
563         for prox in self.proxies.iter() {
564             if let Some(proxy_scheme) = prox.intercept(&dst) {
565                 return Box::pin(with_timeout(
566                     self.clone().connect_via_proxy(dst, proxy_scheme),
567                     timeout,
568                 ));
569             }
570         }
571 
572         Box::pin(with_timeout(
573             self.clone().connect_with_maybe_proxy(dst, false),
574             timeout,
575         ))
576     }
577 }
578 
579 pub(crate) trait AsyncConn:
580     AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
581 {
582 }
583 
584 impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}
585 
586 type BoxConn = Box<dyn AsyncConn>;
587 
pin_project! {
    /// A boxed, type-erased connection produced by `Connector`.
    ///
    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
    /// This tells hyper whether the URI should be written in
    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
    pub(crate) struct Conn {
        // The underlying stream (plain or TLS, wrapped by the verbose wrapper).
        #[pin]
        inner: BoxConn,
        // Reported to hyper through `Connected::proxy` in `Connection::connected`.
        is_proxy: bool,
    }
}
599 
600 impl Connection for Conn {
connected(&self) -> Connected601     fn connected(&self) -> Connected {
602         self.inner.connected().proxy(self.is_proxy)
603     }
604 }
605 
606 impl AsyncRead for Conn {
poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>607     fn poll_read(
608         self: Pin<&mut Self>,
609         cx: &mut Context,
610         buf: &mut ReadBuf<'_>,
611     ) -> Poll<io::Result<()>> {
612         let this = self.project();
613         AsyncRead::poll_read(this.inner, cx, buf)
614     }
615 }
616 
617 impl AsyncWrite for Conn {
poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll<Result<usize, io::Error>>618     fn poll_write(
619         self: Pin<&mut Self>,
620         cx: &mut Context,
621         buf: &[u8],
622     ) -> Poll<Result<usize, io::Error>> {
623         let this = self.project();
624         AsyncWrite::poll_write(this.inner, cx, buf)
625     }
626 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>627     fn poll_write_vectored(
628         self: Pin<&mut Self>,
629         cx: &mut Context<'_>,
630         bufs: &[IoSlice<'_>],
631     ) -> Poll<Result<usize, io::Error>> {
632         let this = self.project();
633         AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
634     }
635 
is_write_vectored(&self) -> bool636     fn is_write_vectored(&self) -> bool {
637         self.inner.is_write_vectored()
638     }
639 
poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>>640     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
641         let this = self.project();
642         AsyncWrite::poll_flush(this.inner, cx)
643     }
644 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>>645     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
646         let this = self.project();
647         AsyncWrite::poll_shutdown(this.inner, cx)
648     }
649 }
650 
651 pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
652 
/// Opens an HTTP `CONNECT` tunnel to `host:port` over the proxy connection `conn`.
///
/// Sends the `CONNECT` request (with `User-Agent` and `Proxy-Authorization` headers
/// when provided), then reads the proxy's response until the header terminator.
/// On a `200` status the same `conn` is returned, ready for TLS to the destination.
///
/// # Errors
/// - I/O errors while writing the request or reading the response.
/// - The proxy closes the connection before a complete response arrives.
/// - `407` (authentication required) or any other non-`200` status.
/// - The response headers do not fit in the 8 KiB read buffer.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of the status-line prefix we classify on, e.g. `HTTP/1.1 200`.
    const STATUS_PREFIX_LEN: usize = 12;

    let mut buf = format!(
        "\
         CONNECT {0}:{1} HTTP/1.1\r\n\
         Host: {0}:{1}\r\n\
         ",
        host, port
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {}:{} using basic auth", host, port);
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];
        // A single read may return only part of the status line (e.g. `HTTP/1.1 2`).
        // Classifying such a fragment would wrongly report an unsuccessful tunnel,
        // so keep reading until the whole prefix is available.
        if recvd.len() < STATUS_PREFIX_LEN {
            continue;
        }

        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
        // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
722 
/// Error returned when the proxy closes the connection mid-`CONNECT` handshake.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
727 
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    //! Adapter exposing a `tokio_native_tls::TlsStream` as a hyper connection.

    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    pin_project! {
        /// A native-tls stream over a transport `T`, forwarding all I/O to `inner`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Reports the transport's metadata, marking HTTP/2 when ALPN negotiated `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let negotiated_h2 = self
                .inner
                .get_ref()
                .negotiated_alpn()
                .ok()
                .flatten()
                .map_or(false, |protocol| protocol == b"h2");
            let base = self.inner.get_ref().get_ref().get_ref().connected();
            if negotiated_h2 {
                base.negotiated_h2()
            } else {
                base
            }
        }

        /// Reports the transport's metadata unchanged (ALPN support disabled).
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            let transport = self.inner.get_ref().get_ref().get_ref();
            transport.connected()
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }
}
818 
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use rustls::Session;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    pin_project! {
        /// A rustls client stream that also reports hyper `Connection` metadata
        /// taken from the transport it wraps.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Reports the wrapped transport's `Connected`, flagged as HTTP/2 when the
        /// TLS session negotiated the `h2` ALPN protocol.
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.get_ref();
            let alpn_is_h2 = matches!(session.get_alpn_protocol(), Some(b"h2"));
            let connected = transport.connected();
            if alpn_is_h2 {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    // Every write-side method forwards straight to the pinned TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }
}
899 
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name is resolved for a SOCKS5 connection.
    pub(super) enum DnsResolve {
        /// Resolve the host locally and hand the proxy an IP address.
        Local,
        /// Hand the host name to the proxy unresolved.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy described by `proxy`.
    ///
    /// The destination port is taken from `dst`, defaulting to 443 for `https`
    /// URIs and 80 otherwise. With [`DnsResolve::Local`], the host is resolved
    /// first; if that yields at least one address, the first address's IP is sent
    /// to the proxy instead of the host name.
    ///
    /// # Errors
    /// Fails if `dst` has no host, if local resolution fails, or if the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    /// Hits `unreachable!` if `proxy` is not `ProxyScheme::Socks5`.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // `std::net::ToSocketAddrs` performs a blocking getaddrinfo call on the
            // current thread, which would stall the async executor. tokio's
            // `lookup_host` runs the lookup without blocking the worker thread.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            _ => unreachable!(),
        };

        // Get a Tokio TcpStream
        let stream = if let Some((username, password)) = auth {
            Socks5Stream::connect_with_password(
                socket_addr,
                (host.as_str(), port),
                &username,
                &password,
            )
            .await
            .map_err(|e| format!("socks connect error: {}", e))?
        } else {
            Socks5Stream::connect(socket_addr, (host.as_str(), port))
                .await
                .map_err(|e| format!("socks connect error: {}", e))?
        };

        Ok(stream.into_inner())
    }
}
964 
pub(crate) mod itertools {
    /// A two-variant sum type that lets one function return either of two iterator
    /// types; it iterates whichever variant it holds.
    pub(crate) enum Either<A, B> {
        Left(A),
        Right(B),
    }

    impl<A, B> Iterator for Either<A, B>
    where
        A: Iterator,
        B: Iterator<Item = <A as Iterator>::Item>,
    {
        type Item = <A as Iterator>::Item;

        fn next(&mut self) -> Option<Self::Item> {
            match self {
                Either::Left(a) => a.next(),
                Either::Right(b) => b.next(),
            }
        }

        // Forward the active iterator's bounds so consumers such as `collect`
        // can preallocate. The default `(0, None)` would discard that information.
        fn size_hint(&self) -> (usize, Option<usize>) {
            match self {
                Either::Left(a) => a.size_hint(),
                Either::Right(b) => b.size_hint(),
            }
        }
    }
}
986 
987 pin_project! {
988     pub(crate) struct WrappedResolverFuture<Fut> {
989         #[pin]
990         fut: Fut,
991     }
992 }
993 
994 impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
995 where
996     Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
997     FutOutput: Iterator<Item = SocketAddr>,
998 {
999     type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>;
1000 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>1001     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1002         let this = self.project();
1003         this.fut
1004             .poll(cx)
1005             .map(|result| result.map(itertools::Either::Left))
1006     }
1007 }
1008 
/// DNS resolver wrapper that answers from a fixed host-to-address table before
/// delegating to `dns_resolver`.
///
/// `overrides` sits behind an `Arc`, so cloning the resolver shares the table
/// instead of copying it.
// The `Clone` requirement lives on the impls that need it rather than on the type
// itself; `#[derive(Clone)]` still only implements `Clone` when `Resolver: Clone`.
#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver> {
    /// Resolver consulted for names that are not in `overrides`.
    dns_resolver: Resolver,
    /// Host name -> address returned instead of querying `dns_resolver`.
    overrides: Arc<HashMap<String, SocketAddr>>,
}
1017 
1018 impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self1019     fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self {
1020         DnsResolverWithOverrides {
1021             dns_resolver,
1022             overrides: Arc::new(overrides),
1023         }
1024     }
1025 }
1026 
1027 impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
1028 where
1029     Resolver: Service<Name, Response = Iter> + Clone,
1030     Iter: Iterator<Item = SocketAddr>,
1031 {
1032     type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>;
1033     type Error = <Resolver as Service<Name>>::Error;
1034     type Future = Either<
1035         WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
1036         futures_util::future::Ready<
1037             Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>,
1038         >,
1039     >;
1040 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>1041     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1042         self.dns_resolver.poll_ready(cx)
1043     }
1044 
call(&mut self, name: Name) -> Self::Future1045     fn call(&mut self, name: Name) -> Self::Future {
1046         match self.overrides.get(name.as_str()) {
1047             Some(dest) => {
1048                 let fut = futures_util::future::ready(Ok(itertools::Either::Right(
1049                     std::iter::once(dest.to_owned()),
1050                 )));
1051                 Either::Right(fut)
1052             }
1053             None => {
1054                 let resolver_fut = self.dns_resolver.call(name);
1055                 let y = WrappedResolverFuture { fut: resolver_fut };
1056                 Either::Left(y)
1057             }
1058         }
1059     }
1060 }
1061 
1062 mod verbose {
1063     use hyper::client::connect::{Connected, Connection};
1064     use std::fmt;
1065     use std::io::{self, IoSlice};
1066     use std::pin::Pin;
1067     use std::task::{Context, Poll};
1068     use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1069 
1070     pub(super) const OFF: Wrapper = Wrapper(false);
1071 
1072     #[derive(Clone, Copy)]
1073     pub(super) struct Wrapper(pub(super) bool);
1074 
1075     impl Wrapper {
wrap<T: super::AsyncConn>(&self, conn: T) -> super::BoxConn1076         pub(super) fn wrap<T: super::AsyncConn>(&self, conn: T) -> super::BoxConn {
1077             if self.0 && log::log_enabled!(log::Level::Trace) {
1078                 Box::new(Verbose {
1079                     // truncate is fine
1080                     id: crate::util::fast_random() as u32,
1081                     inner: conn,
1082                 })
1083             } else {
1084                 Box::new(conn)
1085             }
1086         }
1087     }
1088 
1089     struct Verbose<T> {
1090         id: u32,
1091         inner: T,
1092     }
1093 
1094     impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> {
connected(&self) -> Connected1095         fn connected(&self) -> Connected {
1096             self.inner.connected()
1097         }
1098     }
1099 
1100     impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>1101         fn poll_read(
1102             mut self: Pin<&mut Self>,
1103             cx: &mut Context,
1104             buf: &mut ReadBuf<'_>,
1105         ) -> Poll<std::io::Result<()>> {
1106             match Pin::new(&mut self.inner).poll_read(cx, buf) {
1107                 Poll::Ready(Ok(())) => {
1108                     log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
1109                     Poll::Ready(Ok(()))
1110                 }
1111                 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1112                 Poll::Pending => Poll::Pending,
1113             }
1114         }
1115     }
1116 
1117     impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll<Result<usize, std::io::Error>>1118         fn poll_write(
1119             mut self: Pin<&mut Self>,
1120             cx: &mut Context,
1121             buf: &[u8],
1122         ) -> Poll<Result<usize, std::io::Error>> {
1123             match Pin::new(&mut self.inner).poll_write(cx, buf) {
1124                 Poll::Ready(Ok(n)) => {
1125                     log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1126                     Poll::Ready(Ok(n))
1127                 }
1128                 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1129                 Poll::Pending => Poll::Pending,
1130             }
1131         }
1132 
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>1133         fn poll_write_vectored(
1134             mut self: Pin<&mut Self>,
1135             cx: &mut Context<'_>,
1136             bufs: &[IoSlice<'_>],
1137         ) -> Poll<Result<usize, io::Error>> {
1138             Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
1139         }
1140 
is_write_vectored(&self) -> bool1141         fn is_write_vectored(&self) -> bool {
1142             self.inner.is_write_vectored()
1143         }
1144 
poll_flush( mut self: Pin<&mut Self>, cx: &mut Context, ) -> Poll<Result<(), std::io::Error>>1145         fn poll_flush(
1146             mut self: Pin<&mut Self>,
1147             cx: &mut Context,
1148         ) -> Poll<Result<(), std::io::Error>> {
1149             Pin::new(&mut self.inner).poll_flush(cx)
1150         }
1151 
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context, ) -> Poll<Result<(), std::io::Error>>1152         fn poll_shutdown(
1153             mut self: Pin<&mut Self>,
1154             cx: &mut Context,
1155         ) -> Poll<Result<(), std::io::Error>> {
1156             Pin::new(&mut self.inner).poll_shutdown(cx)
1157         }
1158     }
1159 
1160     struct Escape<'a>(&'a [u8]);
1161 
1162     impl fmt::Debug for Escape<'_> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result1163         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1164             write!(f, "b\"")?;
1165             for &c in self.0 {
1166                 // https://doc.rust-lang.org/reference.html#byte-escapes
1167                 if c == b'\n' {
1168                     write!(f, "\\n")?;
1169                 } else if c == b'\r' {
1170                     write!(f, "\\r")?;
1171                 } else if c == b'\t' {
1172                     write!(f, "\\t")?;
1173                 } else if c == b'\\' || c == b'"' {
1174                     write!(f, "\\{}", c as char)?;
1175                 } else if c == b'\0' {
1176                     write!(f, "\\0")?;
1177                 // ASCII printable
1178                 } else if c >= 0x20 && c < 0x7f {
1179                     write!(f, "{}", c as char)?;
1180                 } else {
1181                     write!(f, "\\x{:02x}", c)?;
1182                 }
1183             }
1184             write!(f, "\"")?;
1185             Ok(())
1186         }
1187     }
1188 }
1189 
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    // Spawns a one-shot mock proxy on localhost. The proxy asserts that the first
    // request it reads is the expected CONNECT (with optional auth header line),
    // replies with `$write`, and the macro evaluates to the proxy's address.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let expected_request = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                let mut request = [0u8; 4096];
                let len = sock.read(&mut request).unwrap();
                assert_eq!(&request[..len], &expected_request[..]);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Single-threaded runtime used to drive each tunnel attempt.
    fn rt() -> runtime::Runtime {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt")
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();

        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            tunnel(tcp, addr.ip().to_string(), addr.port(), ua(), None).await
        };

        rt().block_on(f).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");

        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            tunnel(tcp, addr.ip().to_string(), addr.port(), ua(), None).await
        };

        rt().block_on(f).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");

        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            tunnel(tcp, addr.ip().to_string(), addr.port(), ua(), None).await
        };

        rt().block_on(f).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
        "
        );

        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            tunnel(tcp, addr.ip().to_string(), addr.port(), ua(), None).await
        };

        let error = rt().block_on(f).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let auth = proxy::encode_basic_auth("Aladdin", "open sesame");
            tunnel(tcp, addr.ip().to_string(), addr.port(), ua(), Some(auth)).await
        };

        rt().block_on(f).unwrap();
    }
}
1355