1 use futures_util::future::Either;
2 #[cfg(feature = "__tls")]
3 use http::header::HeaderValue;
4 use http::uri::{Authority, Scheme};
5 use http::Uri;
6 use hyper::client::connect::{
7 dns::{GaiResolver, Name},
8 Connected, Connection,
9 };
10 use hyper::service::Service;
11 #[cfg(feature = "native-tls-crate")]
12 use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
13 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
15 use pin_project_lite::pin_project;
16 use std::io::IoSlice;
17 use std::net::IpAddr;
18 use std::pin::Pin;
19 use std::sync::Arc;
20 use std::task::{Context, Poll};
21 use std::time::Duration;
22 use std::{collections::HashMap, io};
23 use std::{future::Future, net::SocketAddr};
24
25 #[cfg(feature = "default-tls")]
26 use self::native_tls_conn::NativeTlsConn;
27 #[cfg(feature = "__rustls")]
28 use self::rustls_tls_conn::RustlsTlsConn;
29 #[cfg(feature = "trust-dns")]
30 use crate::dns::TrustDnsResolver;
31 use crate::error::BoxError;
32 use crate::proxy::{Proxy, ProxyScheme};
33
/// TCP connector used underneath the TLS/proxy layers of [`Connector`].
///
/// Every variant wraps a `hyper::client::HttpConnector`; they differ only in
/// which DNS resolver that connector uses. The `*WithOverrides` variants
/// consult a fixed host-name → address map before the wrapped resolver.
#[derive(Clone)]
pub(crate) enum HttpConnector {
    /// hyper's default connector (`GaiResolver`).
    Gai(hyper::client::HttpConnector),
    /// `GaiResolver` behind per-host overrides.
    GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
    /// trust-dns resolver (only with the `trust-dns` feature).
    #[cfg(feature = "trust-dns")]
    TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
    /// trust-dns resolver behind per-host overrides.
    #[cfg(feature = "trust-dns")]
    TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
}
43
44 impl HttpConnector {
new_gai() -> Self45 pub(crate) fn new_gai() -> Self {
46 Self::Gai(hyper::client::HttpConnector::new())
47 }
48
new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self49 pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self {
50 let gai = hyper::client::connect::dns::GaiResolver::new();
51 let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
52 Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
53 overridden_resolver,
54 ))
55 }
56
57 #[cfg(feature = "trust-dns")]
new_trust_dns() -> crate::Result<HttpConnector>58 pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
59 TrustDnsResolver::new()
60 .map(hyper::client::HttpConnector::new_with_resolver)
61 .map(Self::TrustDns)
62 .map_err(crate::error::builder)
63 }
64
65 #[cfg(feature = "trust-dns")]
new_trust_dns_with_overrides( overrides: HashMap<String, SocketAddr>, ) -> crate::Result<HttpConnector>66 pub(crate) fn new_trust_dns_with_overrides(
67 overrides: HashMap<String, SocketAddr>,
68 ) -> crate::Result<HttpConnector> {
69 TrustDnsResolver::new()
70 .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
71 .map(hyper::client::HttpConnector::new_with_resolver)
72 .map(Self::TrustDnsWithOverrides)
73 .map_err(crate::error::builder)
74 }
75 }
76
// Generates inherent `HttpConnector` methods that forward a call to the
// same-named method on whichever `hyper::client::HttpConnector` the active
// variant wraps.
//
// Input: one or more signatures of the form
// `fn name(&mut self, arg: Type, ...) [-> Ret];`. For each one, the expansion
// matches on `self` and calls `inner.name(arg, ...)` in every variant arm. The
// trust-dns arms are cfg-gated exactly like the enum's variants.
macro_rules! impl_http_connector {
    ($(fn $name:ident(&mut self, $($par_name:ident: $par_type:ty),*)$( -> $return:ty)?;)+) => {
        // Not every forwarded method is called under every feature set; e.g.
        // `enforce_http` is only called by the TLS constructors.
        #[allow(dead_code)]
        impl HttpConnector {
            $(
                fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
                    match self {
                        Self::Gai(resolver) => resolver.$name($($par_name),*),
                        Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
                        #[cfg(feature = "trust-dns")]
                        Self::TrustDns(resolver) => resolver.$name($($par_name),*),
                        #[cfg(feature = "trust-dns")]
                        Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
                    }
                }
            )+
        }
    };
}
96
// Forwarded setters. Each becomes a private `HttpConnector` method that
// dispatches to the wrapped hyper connector, whatever its resolver.
impl_http_connector! {
    fn set_local_address(&mut self, addr: Option<IpAddr>);
    fn enforce_http(&mut self, is_enforced: bool);
    fn set_nodelay(&mut self, nodelay: bool);
    fn set_keepalive(&mut self, dur: Option<Duration>);
}
103
/// Makes `HttpConnector` usable as a `Service<Uri>` by delegating to the
/// wrapped `hyper::client::HttpConnector`.
impl Service<Uri> for HttpConnector {
    // Every variant yields the same response/error types as hyper's default connector.
    type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
    type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
    // One nested `Either` arm per variant, matching `call` below:
    // Left/Left = Gai, Left/Right = GaiWithDnsOverrides,
    // Right/Left = TrustDns, Right/Right = TrustDnsWithOverrides.
    #[cfg(feature = "trust-dns")]
    type Future =
        Either<
            Either<
                <hyper::client::HttpConnector as Service<Uri>>::Future,
                <hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
                    Uri,
                >>::Future,
            >,
            Either<
                <hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
                <hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
            >
        >;
    // Without trust-dns, `call` never constructs the `Right` arm. It is filled
    // with the default connector's future only to keep the same shape.
    #[cfg(not(feature = "trust-dns"))]
    type Future =
        Either<
            Either<
                <hyper::client::HttpConnector as Service<Uri>>::Future,
                <hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
                    Uri,
                >>::Future,
            >,
            Either<
                <hyper::client::HttpConnector as Service<Uri>>::Future,
                <hyper::client::HttpConnector as Service<Uri>>::Future,
            >,
        >;

    /// Forwards readiness to the wrapped connector.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self {
            Self::Gai(resolver) => resolver.poll_ready(cx),
            Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
            #[cfg(feature = "trust-dns")]
            Self::TrustDns(resolver) => resolver.poll_ready(cx),
            #[cfg(feature = "trust-dns")]
            Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
        }
    }

    /// Starts a connection to `dst` and tags the wrapped connector's future
    /// with the `Either` path that matches its variant.
    fn call(&mut self, dst: Uri) -> Self::Future {
        match self {
            Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
            Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
            #[cfg(feature = "trust-dns")]
            Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
            #[cfg(feature = "trust-dns")]
            Self::TrustDnsWithOverrides(resolver) => {
                Either::Right(Either::Right(resolver.call(dst)))
            }
        }
    }
}
160
/// `Service<Uri>` that turns a destination URI into a [`Conn`].
///
/// It handles proxy interception, the TLS backend selected by features, and
/// an optional overall connect timeout.
#[derive(Clone)]
pub(crate) struct Connector {
    /// Transport/TLS backend.
    inner: Inner,
    /// Checked in order on every connect; the first proxy that intercepts wins.
    proxies: Arc<Vec<Proxy>>,
    /// Wrapper applied to every established connection; toggled by `set_verbose`.
    verbose: verbose::Wrapper,
    /// Limit applied to the whole connect future, if set.
    timeout: Option<Duration>,
    /// Requested TCP_NODELAY for the TLS backends.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    /// `User-Agent` sent with CONNECT requests to HTTP proxies.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
172
/// Transport backend, selected at compile time by the enabled TLS features.
#[derive(Clone)]
enum Inner {
    /// Plain TCP; only present when no TLS backend is compiled in.
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    /// native-tls backend: TCP connector plus an already-built `TlsConnector`.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    /// rustls backend.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        http: HttpConnector,
        // Config for destination TLS on direct connections and inside HTTP
        // CONNECT tunnels.
        tls: Arc<rustls::ClientConfig>,
        // Copy of `tls` with ALPN protocols cleared when proxies are configured
        // (otherwise the same `Arc`). Used for the TLS hop to an HTTPS proxy
        // and for destination TLS over SOCKS.
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
186
187 impl Connector {
188 #[cfg(not(feature = "__tls"))]
new<T>( mut http: HttpConnector, proxies: Arc<Vec<Proxy>>, local_addr: T, nodelay: bool, ) -> Connector where T: Into<Option<IpAddr>>,189 pub(crate) fn new<T>(
190 mut http: HttpConnector,
191 proxies: Arc<Vec<Proxy>>,
192 local_addr: T,
193 nodelay: bool,
194 ) -> Connector
195 where
196 T: Into<Option<IpAddr>>,
197 {
198 http.set_local_address(local_addr.into());
199 http.set_nodelay(nodelay);
200 Connector {
201 inner: Inner::Http(http),
202 verbose: verbose::OFF,
203 proxies,
204 timeout: None,
205 }
206 }
207
208 #[cfg(feature = "default-tls")]
new_default_tls<T>( http: HttpConnector, tls: TlsConnectorBuilder, proxies: Arc<Vec<Proxy>>, user_agent: Option<HeaderValue>, local_addr: T, nodelay: bool, ) -> crate::Result<Connector> where T: Into<Option<IpAddr>>,209 pub(crate) fn new_default_tls<T>(
210 http: HttpConnector,
211 tls: TlsConnectorBuilder,
212 proxies: Arc<Vec<Proxy>>,
213 user_agent: Option<HeaderValue>,
214 local_addr: T,
215 nodelay: bool,
216 ) -> crate::Result<Connector>
217 where
218 T: Into<Option<IpAddr>>,
219 {
220 let tls = tls.build().map_err(crate::error::builder)?;
221 Ok(Self::from_built_default_tls(
222 http, tls, proxies, user_agent, local_addr, nodelay,
223 ))
224 }
225
226 #[cfg(feature = "default-tls")]
from_built_default_tls<T>( mut http: HttpConnector, tls: TlsConnector, proxies: Arc<Vec<Proxy>>, user_agent: Option<HeaderValue>, local_addr: T, nodelay: bool, ) -> Connector where T: Into<Option<IpAddr>>,227 pub(crate) fn from_built_default_tls<T>(
228 mut http: HttpConnector,
229 tls: TlsConnector,
230 proxies: Arc<Vec<Proxy>>,
231 user_agent: Option<HeaderValue>,
232 local_addr: T,
233 nodelay: bool,
234 ) -> Connector
235 where
236 T: Into<Option<IpAddr>>,
237 {
238 http.set_local_address(local_addr.into());
239 http.enforce_http(false);
240
241 Connector {
242 inner: Inner::DefaultTls(http, tls),
243 proxies,
244 verbose: verbose::OFF,
245 timeout: None,
246 nodelay,
247 user_agent,
248 }
249 }
250
251 #[cfg(feature = "__rustls")]
new_rustls_tls<T>( mut http: HttpConnector, tls: rustls::ClientConfig, proxies: Arc<Vec<Proxy>>, user_agent: Option<HeaderValue>, local_addr: T, nodelay: bool, ) -> Connector where T: Into<Option<IpAddr>>,252 pub(crate) fn new_rustls_tls<T>(
253 mut http: HttpConnector,
254 tls: rustls::ClientConfig,
255 proxies: Arc<Vec<Proxy>>,
256 user_agent: Option<HeaderValue>,
257 local_addr: T,
258 nodelay: bool,
259 ) -> Connector
260 where
261 T: Into<Option<IpAddr>>,
262 {
263 http.set_local_address(local_addr.into());
264 http.enforce_http(false);
265
266 let (tls, tls_proxy) = if proxies.is_empty() {
267 let tls = Arc::new(tls);
268 (tls.clone(), tls)
269 } else {
270 let mut tls_proxy = tls.clone();
271 tls_proxy.alpn_protocols.clear();
272 (Arc::new(tls), Arc::new(tls_proxy))
273 };
274
275 Connector {
276 inner: Inner::RustlsTls {
277 http,
278 tls,
279 tls_proxy,
280 },
281 proxies,
282 verbose: verbose::OFF,
283 timeout: None,
284 nodelay,
285 user_agent,
286 }
287 }
288
set_timeout(&mut self, timeout: Option<Duration>)289 pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
290 self.timeout = timeout;
291 }
292
set_verbose(&mut self, enabled: bool)293 pub(crate) fn set_verbose(&mut self, enabled: bool) {
294 self.verbose.0 = enabled;
295 }
296
297 #[cfg(feature = "socks")]
connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError>298 async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
299 let dns = match proxy {
300 ProxyScheme::Socks5 {
301 remote_dns: false, ..
302 } => socks::DnsResolve::Local,
303 ProxyScheme::Socks5 {
304 remote_dns: true, ..
305 } => socks::DnsResolve::Proxy,
306 ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
307 unreachable!("connect_socks is only called for socks proxies");
308 }
309 };
310
311 match &self.inner {
312 #[cfg(feature = "default-tls")]
313 Inner::DefaultTls(_http, tls) => {
314 if dst.scheme() == Some(&Scheme::HTTPS) {
315 let host = dst.host().ok_or("no host in url")?.to_string();
316 let conn = socks::connect(proxy, dst, dns).await?;
317 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
318 let io = tls_connector.connect(&host, conn).await?;
319 return Ok(Conn {
320 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
321 is_proxy: false,
322 });
323 }
324 }
325 #[cfg(feature = "__rustls")]
326 Inner::RustlsTls { tls_proxy, .. } => {
327 if dst.scheme() == Some(&Scheme::HTTPS) {
328 use tokio_rustls::webpki::DNSNameRef;
329 use tokio_rustls::TlsConnector as RustlsConnector;
330
331 let tls = tls_proxy.clone();
332 let host = dst.host().ok_or("no host in url")?.to_string();
333 let conn = socks::connect(proxy, dst, dns).await?;
334 let dnsname = DNSNameRef::try_from_ascii_str(&host)
335 .map(|dnsname| dnsname.to_owned())
336 .map_err(|_| "Invalid DNS Name")?;
337 let io = RustlsConnector::from(tls)
338 .connect(dnsname.as_ref(), conn)
339 .await?;
340 return Ok(Conn {
341 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
342 is_proxy: false,
343 });
344 }
345 }
346 #[cfg(not(feature = "__tls"))]
347 Inner::Http(_) => (),
348 }
349
350 socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
351 inner: self.verbose.wrap(tcp),
352 is_proxy: false,
353 })
354 }
355
connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError>356 async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
357 match self.inner {
358 #[cfg(not(feature = "__tls"))]
359 Inner::Http(mut http) => {
360 let io = http.call(dst).await?;
361 Ok(Conn {
362 inner: self.verbose.wrap(io),
363 is_proxy,
364 })
365 }
366 #[cfg(feature = "default-tls")]
367 Inner::DefaultTls(http, tls) => {
368 let mut http = http.clone();
369
370 // Disable Nagle's algorithm for TLS handshake
371 //
372 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
373 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
374 http.set_nodelay(true);
375 }
376
377 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
378 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
379 let io = http.call(dst).await?;
380
381 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
382 if !self.nodelay {
383 stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
384 }
385 Ok(Conn {
386 inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
387 is_proxy,
388 })
389 } else {
390 Ok(Conn {
391 inner: self.verbose.wrap(io),
392 is_proxy,
393 })
394 }
395 }
396 #[cfg(feature = "__rustls")]
397 Inner::RustlsTls { http, tls, .. } => {
398 let mut http = http.clone();
399
400 // Disable Nagle's algorithm for TLS handshake
401 //
402 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
403 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
404 http.set_nodelay(true);
405 }
406
407 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
408 let io = http.call(dst).await?;
409
410 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
411 if !self.nodelay {
412 let (io, _) = stream.get_ref();
413 io.set_nodelay(false)?;
414 }
415 Ok(Conn {
416 inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
417 is_proxy,
418 })
419 } else {
420 Ok(Conn {
421 inner: self.verbose.wrap(io),
422 is_proxy,
423 })
424 }
425 }
426 }
427 }
428
connect_via_proxy( self, dst: Uri, proxy_scheme: ProxyScheme, ) -> Result<Conn, BoxError>429 async fn connect_via_proxy(
430 self,
431 dst: Uri,
432 proxy_scheme: ProxyScheme,
433 ) -> Result<Conn, BoxError> {
434 log::debug!("proxy({:?}) intercepts '{:?}'", proxy_scheme, dst);
435
436 let (proxy_dst, _auth) = match proxy_scheme {
437 ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
438 ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
439 #[cfg(feature = "socks")]
440 ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
441 };
442
443 #[cfg(feature = "__tls")]
444 let auth = _auth;
445
446 match &self.inner {
447 #[cfg(feature = "default-tls")]
448 Inner::DefaultTls(http, tls) => {
449 if dst.scheme() == Some(&Scheme::HTTPS) {
450 let host = dst.host().to_owned();
451 let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
452 let http = http.clone();
453 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
454 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
455 let conn = http.call(proxy_dst).await?;
456 log::trace!("tunneling HTTPS over proxy");
457 let tunneled = tunnel(
458 conn,
459 host.ok_or("no host in url")?.to_string(),
460 port,
461 self.user_agent.clone(),
462 auth,
463 )
464 .await?;
465 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
466 let io = tls_connector
467 .connect(&host.ok_or("no host in url")?, tunneled)
468 .await?;
469 return Ok(Conn {
470 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
471 is_proxy: false,
472 });
473 }
474 }
475 #[cfg(feature = "__rustls")]
476 Inner::RustlsTls {
477 http,
478 tls,
479 tls_proxy,
480 } => {
481 if dst.scheme() == Some(&Scheme::HTTPS) {
482 use tokio_rustls::webpki::DNSNameRef;
483 use tokio_rustls::TlsConnector as RustlsConnector;
484
485 let host = dst.host().ok_or("no host in url")?.to_string();
486 let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
487 let http = http.clone();
488 let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
489 let tls = tls.clone();
490 let conn = http.call(proxy_dst).await?;
491 log::trace!("tunneling HTTPS over proxy");
492 let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
493 .map(|dnsname| dnsname.to_owned())
494 .map_err(|_| "Invalid DNS Name");
495 let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
496 let dnsname = maybe_dnsname?;
497 let io = RustlsConnector::from(tls)
498 .connect(dnsname.as_ref(), tunneled)
499 .await?;
500
501 return Ok(Conn {
502 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
503 is_proxy: false,
504 });
505 }
506 }
507 #[cfg(not(feature = "__tls"))]
508 Inner::Http(_) => (),
509 }
510
511 self.connect_with_maybe_proxy(proxy_dst, true).await
512 }
513
set_keepalive(&mut self, dur: Option<Duration>)514 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
515 match &mut self.inner {
516 #[cfg(feature = "default-tls")]
517 Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
518 #[cfg(feature = "__rustls")]
519 Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
520 #[cfg(not(feature = "__tls"))]
521 Inner::Http(http) => http.set_keepalive(dur),
522 }
523 }
524 }
525
into_uri(scheme: Scheme, host: Authority) -> Uri526 fn into_uri(scheme: Scheme, host: Authority) -> Uri {
527 // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`?
528 http::Uri::builder()
529 .scheme(scheme)
530 .authority(host)
531 .path_and_query(http::uri::PathAndQuery::from_static("/"))
532 .build()
533 .expect("scheme and authority is valid Uri")
534 }
535
with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError> where F: Future<Output = Result<T, BoxError>>,536 async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
537 where
538 F: Future<Output = Result<T, BoxError>>,
539 {
540 if let Some(to) = timeout {
541 match tokio::time::timeout(to, f).await {
542 Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
543 Ok(Ok(try_res)) => Ok(try_res),
544 Ok(Err(e)) => Err(e),
545 }
546 } else {
547 f.await
548 }
549 }
550
551 impl Service<Uri> for Connector {
552 type Response = Conn;
553 type Error = BoxError;
554 type Future = Connecting;
555
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>556 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
557 Poll::Ready(Ok(()))
558 }
559
call(&mut self, dst: Uri) -> Self::Future560 fn call(&mut self, dst: Uri) -> Self::Future {
561 log::debug!("starting new connection: {:?}", dst);
562 let timeout = self.timeout;
563 for prox in self.proxies.iter() {
564 if let Some(proxy_scheme) = prox.intercept(&dst) {
565 return Box::pin(with_timeout(
566 self.clone().connect_via_proxy(dst, proxy_scheme),
567 timeout,
568 ));
569 }
570 }
571
572 Box::pin(with_timeout(
573 self.clone().connect_with_maybe_proxy(dst, false),
574 timeout,
575 ))
576 }
577 }
578
/// Bundle of the I/O traits and hyper's `Connection` trait, so transports of
/// different concrete types (TCP, native-tls, rustls, SOCKS) can be boxed
/// behind one trait object.
pub(crate) trait AsyncConn:
    AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type meeting the bounds is an `AsyncConn`.
impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}

/// Type-erased connection stored inside [`Conn`].
type BoxConn = Box<dyn AsyncConn>;
587
pin_project! {
    /// A boxed, type-erased connection returned by [`Connector`].
    ///
    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
    /// This tells hyper whether the URI should be written in
    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
    pub(crate) struct Conn {
        #[pin]
        inner: BoxConn,
        is_proxy: bool,
    }
}
599
600 impl Connection for Conn {
connected(&self) -> Connected601 fn connected(&self) -> Connected {
602 self.inner.connected().proxy(self.is_proxy)
603 }
604 }
605
606 impl AsyncRead for Conn {
poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>607 fn poll_read(
608 self: Pin<&mut Self>,
609 cx: &mut Context,
610 buf: &mut ReadBuf<'_>,
611 ) -> Poll<io::Result<()>> {
612 let this = self.project();
613 AsyncRead::poll_read(this.inner, cx, buf)
614 }
615 }
616
617 impl AsyncWrite for Conn {
poll_write( self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll<Result<usize, io::Error>>618 fn poll_write(
619 self: Pin<&mut Self>,
620 cx: &mut Context,
621 buf: &[u8],
622 ) -> Poll<Result<usize, io::Error>> {
623 let this = self.project();
624 AsyncWrite::poll_write(this.inner, cx, buf)
625 }
626
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>627 fn poll_write_vectored(
628 self: Pin<&mut Self>,
629 cx: &mut Context<'_>,
630 bufs: &[IoSlice<'_>],
631 ) -> Poll<Result<usize, io::Error>> {
632 let this = self.project();
633 AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
634 }
635
is_write_vectored(&self) -> bool636 fn is_write_vectored(&self) -> bool {
637 self.inner.is_write_vectored()
638 }
639
poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>>640 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
641 let this = self.project();
642 AsyncWrite::poll_flush(this.inner, cx)
643 }
644
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>>645 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
646 let this = self.project();
647 AsyncWrite::poll_shutdown(this.inner, cx)
648 }
649 }
650
/// Boxed `Send` future returned by `Connector::call`.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
652
/// Sends `CONNECT host:port` over `conn` (an open connection to an HTTP
/// proxy) and waits for a `200` response. On success, returns `conn` ready to
/// carry the tunneled bytes.
///
/// `user_agent` and `auth` are sent as `User-Agent` / `Proxy-Authorization`
/// headers when present.
///
/// # Errors
/// - I/O errors while writing the request or reading the response.
/// - EOF before the response header block ends (`tunnel_eof`).
/// - A `407` status: "proxy authentication required".
/// - Any other status: "unsuccessful tunnel".
/// - A response header block larger than the 8 KiB read buffer.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Status-line prefixes that mean "tunnel established" / "auth required".
    const OK_PREFIXES: [&[u8]; 2] = [b"HTTP/1.1 200", b"HTTP/1.0 200"];
    const AUTH_PREFIXES: [&[u8]; 2] = [b"HTTP/1.1 407", b"HTTP/1.0 407"];

    let mut buf = format!(
        "\
         CONNECT {0}:{1} HTTP/1.1\r\n\
         Host: {0}:{1}\r\n\
         ",
        host, port
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {}:{} using basic auth", host, port);
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];
        if OK_PREFIXES.iter().any(|&prefix| recvd.starts_with(prefix)) {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if AUTH_PREFIXES.iter().any(|&prefix| recvd.starts_with(prefix)) {
            return Err("proxy authentication required".into());
        } else if OK_PREFIXES
            .iter()
            .chain(AUTH_PREFIXES.iter())
            .any(|&prefix| prefix.starts_with(recvd))
        {
            // The status line was split across reads (e.g. only "HTTP/1.1 2"
            // has arrived). It may still be a success, so read more instead
            // of failing.
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
722
/// Error for a proxy that closed the connection before completing its
/// response to a CONNECT request.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
727
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    pin_project! {
        /// A `tokio_native_tls` stream that reports hyper connection metadata
        /// taken from the transport it wraps.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Reports the transport's `Connected`, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let tls = self.inner.get_ref();
            let is_h2 = matches!(
                tls.negotiated_alpn().ok(),
                Some(Some(alpn_protocol)) if alpn_protocol == b"h2"
            );
            let transport = tls.get_ref().get_ref().connected();
            if is_h2 {
                transport.negotiated_h2()
            } else {
                transport
            }
        }

        /// Reports the transport's `Connected` unchanged (ALPN support is not
        /// compiled in).
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            let tls = self.inner.get_ref();
            tls.get_ref().get_ref().connected()
        }
    }

    /// Reads go straight through the TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    /// Writes go straight through the TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            let tls = &self.inner;
            tls.is_write_vectored()
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }
}
818
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use rustls::Session;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    pin_project! {
        /// A `tokio_rustls` client stream that reports hyper connection
        /// metadata taken from the transport it wraps.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Reports the transport's `Connected`, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.get_ref();
            let is_h2 = session.get_alpn_protocol() == Some(b"h2");
            let base = transport.connected();
            if is_h2 {
                base.negotiated_h2()
            } else {
                base
            }
        }
    }

    /// Reads go straight through the TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    /// Writes go straight through the TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            let tls = &self.inner;
            tls.is_write_vectored()
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }
}
899
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::{lookup_host, TcpStream};
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name gets resolved.
    pub(super) enum DnsResolve {
        /// Resolve here and send the proxy an IP address.
        Local,
        /// Send the host name and let the proxy resolve it.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy described by `proxy`.
    ///
    /// The port defaults to 443 for `https` and 80 otherwise. Username/password
    /// authentication is used when the proxy scheme carries credentials.
    ///
    /// # Errors
    /// Fails if `dst` has no host, if local resolution fails, or if the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    /// Panics if `proxy` is not `ProxyScheme::Socks5`.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        // `Uri::host()` keeps the brackets around IPv6 literals ("[::1]"). Strip
        // them so the literal parses as an address, both for local resolution
        // and as the SOCKS target.
        let mut host = original_host
            .trim_matches(|c| c == '[' || c == ']')
            .to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Async lookup. The previous `std::net::ToSocketAddrs` call blocked
            // the runtime worker thread for the whole DNS query.
            let maybe_new_target = lookup_host((host.as_str(), port)).await?.next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            _ => unreachable!("socks::connect is only called for socks proxies"),
        };

        // Get a Tokio TcpStream
        let stream = if let Some((username, password)) = auth {
            Socks5Stream::connect_with_password(
                socket_addr,
                (host.as_str(), port),
                &username,
                &password,
            )
            .await
            .map_err(|e| format!("socks connect error: {}", e))?
        } else {
            Socks5Stream::connect(socket_addr, (host.as_str(), port))
                .await
                .map_err(|e| format!("socks connect error: {}", e))?
        };

        Ok(stream.into_inner())
    }
}
964
pub(crate) mod itertools {
    /// Holds one of two iterators with the same item type and iterates
    /// whichever one is present.
    pub(crate) enum Either<A, B> {
        Left(A),
        Right(B),
    }

    impl<A, B> Iterator for Either<A, B>
    where
        A: Iterator,
        B: Iterator<Item = A::Item>,
    {
        type Item = A::Item;

        /// Advances the active side.
        fn next(&mut self) -> Option<Self::Item> {
            match self {
                Self::Left(left) => left.next(),
                Self::Right(right) => right.next(),
            }
        }
    }
}
986
987 pin_project! {
988 pub(crate) struct WrappedResolverFuture<Fut> {
989 #[pin]
990 fut: Fut,
991 }
992 }
993
994 impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
995 where
996 Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
997 FutOutput: Iterator<Item = SocketAddr>,
998 {
999 type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>;
1000
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>1001 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1002 let this = self.project();
1003 this.fut
1004 .poll(cx)
1005 .map(|result| result.map(itertools::Either::Left))
1006 }
1007 }
1008
/// A DNS resolver wrapper that answers from a fixed hostname → address table
/// before falling back to `dns_resolver`.
///
/// The table lives behind an `Arc`, so clones of the wrapper share one map.
/// No trait bounds are placed on the struct itself; the impls that need
/// `Resolver: Clone` (or `Service<Name>`) declare them.
#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver> {
    // Resolver consulted for names that have no override entry.
    dns_resolver: Resolver,
    // Hostname → socket address overrides, looked up by exact string match.
    overrides: Arc<HashMap<String, SocketAddr>>,
}
1017
1018 impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self1019 fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self {
1020 DnsResolverWithOverrides {
1021 dns_resolver,
1022 overrides: Arc::new(overrides),
1023 }
1024 }
1025 }
1026
1027 impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
1028 where
1029 Resolver: Service<Name, Response = Iter> + Clone,
1030 Iter: Iterator<Item = SocketAddr>,
1031 {
1032 type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>;
1033 type Error = <Resolver as Service<Name>>::Error;
1034 type Future = Either<
1035 WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
1036 futures_util::future::Ready<
1037 Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>,
1038 >,
1039 >;
1040
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>1041 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1042 self.dns_resolver.poll_ready(cx)
1043 }
1044
call(&mut self, name: Name) -> Self::Future1045 fn call(&mut self, name: Name) -> Self::Future {
1046 match self.overrides.get(name.as_str()) {
1047 Some(dest) => {
1048 let fut = futures_util::future::ready(Ok(itertools::Either::Right(
1049 std::iter::once(dest.to_owned()),
1050 )));
1051 Either::Right(fut)
1052 }
1053 None => {
1054 let resolver_fut = self.dns_resolver.call(name);
1055 let y = WrappedResolverFuture { fut: resolver_fut };
1056 Either::Left(y)
1057 }
1058 }
1059 }
1060 }
1061
1062 mod verbose {
1063 use hyper::client::connect::{Connected, Connection};
1064 use std::fmt;
1065 use std::io::{self, IoSlice};
1066 use std::pin::Pin;
1067 use std::task::{Context, Poll};
1068 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1069
1070 pub(super) const OFF: Wrapper = Wrapper(false);
1071
1072 #[derive(Clone, Copy)]
1073 pub(super) struct Wrapper(pub(super) bool);
1074
1075 impl Wrapper {
wrap<T: super::AsyncConn>(&self, conn: T) -> super::BoxConn1076 pub(super) fn wrap<T: super::AsyncConn>(&self, conn: T) -> super::BoxConn {
1077 if self.0 && log::log_enabled!(log::Level::Trace) {
1078 Box::new(Verbose {
1079 // truncate is fine
1080 id: crate::util::fast_random() as u32,
1081 inner: conn,
1082 })
1083 } else {
1084 Box::new(conn)
1085 }
1086 }
1087 }
1088
1089 struct Verbose<T> {
1090 id: u32,
1091 inner: T,
1092 }
1093
1094 impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> {
connected(&self) -> Connected1095 fn connected(&self) -> Connected {
1096 self.inner.connected()
1097 }
1098 }
1099
1100 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>1101 fn poll_read(
1102 mut self: Pin<&mut Self>,
1103 cx: &mut Context,
1104 buf: &mut ReadBuf<'_>,
1105 ) -> Poll<std::io::Result<()>> {
1106 match Pin::new(&mut self.inner).poll_read(cx, buf) {
1107 Poll::Ready(Ok(())) => {
1108 log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
1109 Poll::Ready(Ok(()))
1110 }
1111 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1112 Poll::Pending => Poll::Pending,
1113 }
1114 }
1115 }
1116
1117 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll<Result<usize, std::io::Error>>1118 fn poll_write(
1119 mut self: Pin<&mut Self>,
1120 cx: &mut Context,
1121 buf: &[u8],
1122 ) -> Poll<Result<usize, std::io::Error>> {
1123 match Pin::new(&mut self.inner).poll_write(cx, buf) {
1124 Poll::Ready(Ok(n)) => {
1125 log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1126 Poll::Ready(Ok(n))
1127 }
1128 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1129 Poll::Pending => Poll::Pending,
1130 }
1131 }
1132
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>1133 fn poll_write_vectored(
1134 mut self: Pin<&mut Self>,
1135 cx: &mut Context<'_>,
1136 bufs: &[IoSlice<'_>],
1137 ) -> Poll<Result<usize, io::Error>> {
1138 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
1139 }
1140
is_write_vectored(&self) -> bool1141 fn is_write_vectored(&self) -> bool {
1142 self.inner.is_write_vectored()
1143 }
1144
poll_flush( mut self: Pin<&mut Self>, cx: &mut Context, ) -> Poll<Result<(), std::io::Error>>1145 fn poll_flush(
1146 mut self: Pin<&mut Self>,
1147 cx: &mut Context,
1148 ) -> Poll<Result<(), std::io::Error>> {
1149 Pin::new(&mut self.inner).poll_flush(cx)
1150 }
1151
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context, ) -> Poll<Result<(), std::io::Error>>1152 fn poll_shutdown(
1153 mut self: Pin<&mut Self>,
1154 cx: &mut Context,
1155 ) -> Poll<Result<(), std::io::Error>> {
1156 Pin::new(&mut self.inner).poll_shutdown(cx)
1157 }
1158 }
1159
1160 struct Escape<'a>(&'a [u8]);
1161
1162 impl fmt::Debug for Escape<'_> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result1163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1164 write!(f, "b\"")?;
1165 for &c in self.0 {
1166 // https://doc.rust-lang.org/reference.html#byte-escapes
1167 if c == b'\n' {
1168 write!(f, "\\n")?;
1169 } else if c == b'\r' {
1170 write!(f, "\\r")?;
1171 } else if c == b'\t' {
1172 write!(f, "\\t")?;
1173 } else if c == b'\\' || c == b'"' {
1174 write!(f, "\\{}", c as char)?;
1175 } else if c == b'\0' {
1176 write!(f, "\\0")?;
1177 // ASCII printable
1178 } else if c >= 0x20 && c < 0x7f {
1179 write!(f, "{}", c as char)?;
1180 } else {
1181 write!(f, "\\x{:02x}", c)?;
1182 }
1183 }
1184 write!(f, "\"")?;
1185 Ok(())
1186 }
1187 }
1188 }
1189
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    // Starts a one-shot mock proxy on a loopback port.
    //
    // The spawned thread accepts one connection, asserts that the client sent
    // exactly the expected CONNECT request (with the optional `$auth` header
    // line), then replies with `$write`. Evaluates to `(addr, handle)`. Callers
    // must `join` the handle: a panic in a spawned thread is otherwise lost,
    // and a wrong request would go unnoticed in the tests expecting errors.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            let server = thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                // The request may arrive in several segments; keep reading
                // until the blank line that ends the header block (or EOF).
                let mut received: Vec<u8> = Vec::new();
                let mut chunk = [0u8; 4096];
                while !received.ends_with(b"\r\n\r\n") {
                    let n = sock.read(&mut chunk).unwrap();
                    if n == 0 {
                        break;
                    }
                    received.extend_from_slice(&chunk[..n]);
                }
                assert_eq!(received, connect_expected);

                sock.write_all($write).unwrap();
            });
            (addr, server)
        }};
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Builds the single-threaded runtime each test drives its client with.
    fn rt() -> runtime::Runtime {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt")
    }

    #[test]
    fn test_tunnel() {
        let (addr, server) = mock_tunnel!();

        let rt = rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap();
        server.join().expect("mock proxy thread panicked");
    }

    #[test]
    fn test_tunnel_eof() {
        let (addr, server) = mock_tunnel!(b"HTTP/1.1 200 OK");

        let rt = rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap_err();
        server.join().expect("mock proxy thread panicked");
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let (addr, server) = mock_tunnel!(b"foo bar baz hallo");

        let rt = rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap_err();
        server.join().expect("mock proxy thread panicked");
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let (addr, server) = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            "
        );

        let rt = rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        let error = rt.block_on(f).unwrap_err();
        server.join().expect("mock proxy thread panicked");
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let (addr, server) = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        let rt = rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(
                tcp,
                host,
                port,
                ua(),
                Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
            )
            .await
        };

        rt.block_on(f).unwrap();
        server.join().expect("mock proxy thread panicked");
    }
}
1355