1 use std::{
2     collections::VecDeque,
3     future::Future,
4     io,
5     net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
6     pin::Pin,
7     task::{Context, Poll},
8 };
9 
10 use actix_rt::net::{TcpSocket, TcpStream};
11 use actix_service::{Service, ServiceFactory};
12 use futures_core::{future::LocalBoxFuture, ready};
13 use log::{error, trace};
14 use tokio_util::sync::ReusableBoxFuture;
15 
16 use super::connect::{Address, Connect, ConnectAddrs, Connection};
17 use super::error::ConnectError;
18 
19 /// TCP connector service factory
20 #[derive(Debug, Copy, Clone)]
21 pub struct TcpConnectorFactory;
22 
23 impl TcpConnectorFactory {
24     /// Create TCP connector service
service(&self) -> TcpConnector25     pub fn service(&self) -> TcpConnector {
26         TcpConnector
27     }
28 }
29 
30 impl<T: Address> ServiceFactory<Connect<T>> for TcpConnectorFactory {
31     type Response = Connection<T, TcpStream>;
32     type Error = ConnectError;
33     type Config = ();
34     type Service = TcpConnector;
35     type InitError = ();
36     type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
37 
new_service(&self, _: ()) -> Self::Future38     fn new_service(&self, _: ()) -> Self::Future {
39         let service = self.service();
40         Box::pin(async move { Ok(service) })
41     }
42 }
43 
44 /// TCP connector service
45 #[derive(Debug, Copy, Clone)]
46 pub struct TcpConnector;
47 
48 impl<T: Address> Service<Connect<T>> for TcpConnector {
49     type Response = Connection<T, TcpStream>;
50     type Error = ConnectError;
51     type Future = TcpConnectorResponse<T>;
52 
53     actix_service::always_ready!();
54 
call(&self, req: Connect<T>) -> Self::Future55     fn call(&self, req: Connect<T>) -> Self::Future {
56         let port = req.port();
57         let Connect {
58             req,
59             addr,
60             local_addr,
61             ..
62         } = req;
63 
64         TcpConnectorResponse::new(req, port, local_addr, addr)
65     }
66 }
67 
68 /// TCP stream connector response future
69 pub enum TcpConnectorResponse<T> {
70     Response {
71         req: Option<T>,
72         port: u16,
73         local_addr: Option<IpAddr>,
74         addrs: Option<VecDeque<SocketAddr>>,
75         stream: ReusableBoxFuture<Result<TcpStream, io::Error>>,
76     },
77     Error(Option<ConnectError>),
78 }
79 
80 impl<T: Address> TcpConnectorResponse<T> {
new( req: T, port: u16, local_addr: Option<IpAddr>, addr: ConnectAddrs, ) -> TcpConnectorResponse<T>81     pub(crate) fn new(
82         req: T,
83         port: u16,
84         local_addr: Option<IpAddr>,
85         addr: ConnectAddrs,
86     ) -> TcpConnectorResponse<T> {
87         if addr.is_none() {
88             error!("TCP connector: unresolved connection address");
89             return TcpConnectorResponse::Error(Some(ConnectError::Unresolved));
90         }
91 
92         trace!(
93             "TCP connector: connecting to {} on port {}",
94             req.hostname(),
95             port
96         );
97 
98         match addr {
99             ConnectAddrs::None => unreachable!("none variant already checked"),
100 
101             ConnectAddrs::One(addr) => TcpConnectorResponse::Response {
102                 req: Some(req),
103                 port,
104                 local_addr,
105                 addrs: None,
106                 stream: ReusableBoxFuture::new(connect(addr, local_addr)),
107             },
108 
109             // when resolver returns multiple socket addr for request they would be popped from
110             // front end of queue and returns with the first successful tcp connection.
111             ConnectAddrs::Multi(mut addrs) => {
112                 let addr = addrs.pop_front().unwrap();
113 
114                 TcpConnectorResponse::Response {
115                     req: Some(req),
116                     port,
117                     local_addr,
118                     addrs: Some(addrs),
119                     stream: ReusableBoxFuture::new(connect(addr, local_addr)),
120                 }
121             }
122         }
123     }
124 }
125 
126 impl<T: Address> Future for TcpConnectorResponse<T> {
127     type Output = Result<Connection<T, TcpStream>, ConnectError>;
128 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>129     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130         match self.get_mut() {
131             TcpConnectorResponse::Error(err) => Poll::Ready(Err(err.take().unwrap())),
132 
133             TcpConnectorResponse::Response {
134                 req,
135                 port,
136                 local_addr,
137                 addrs,
138                 stream,
139             } => loop {
140                 match ready!(stream.poll(cx)) {
141                     Ok(sock) => {
142                         let req = req.take().unwrap();
143                         trace!(
144                             "TCP connector: successfully connected to {:?} - {:?}",
145                             req.hostname(),
146                             sock.peer_addr()
147                         );
148                         return Poll::Ready(Ok(Connection::new(sock, req)));
149                     }
150 
151                     Err(err) => {
152                         trace!(
153                             "TCP connector: failed to connect to {:?} port: {}",
154                             req.as_ref().unwrap().hostname(),
155                             port,
156                         );
157 
158                         if let Some(addr) = addrs.as_mut().and_then(|addrs| addrs.pop_front()) {
159                             stream.set(connect(addr, *local_addr));
160                         } else {
161                             return Poll::Ready(Err(ConnectError::Io(err)));
162                         }
163                     }
164                 }
165             },
166         }
167     }
168 }
169 
connect(addr: SocketAddr, local_addr: Option<IpAddr>) -> io::Result<TcpStream>170 async fn connect(addr: SocketAddr, local_addr: Option<IpAddr>) -> io::Result<TcpStream> {
171     // use local addr if connect asks for it.
172     match local_addr {
173         Some(ip_addr) => {
174             let socket = match ip_addr {
175                 IpAddr::V4(ip_addr) => {
176                     let socket = TcpSocket::new_v4()?;
177                     let addr = SocketAddr::V4(SocketAddrV4::new(ip_addr, 0));
178                     socket.bind(addr)?;
179                     socket
180                 }
181                 IpAddr::V6(ip_addr) => {
182                     let socket = TcpSocket::new_v6()?;
183                     let addr = SocketAddr::V6(SocketAddrV6::new(ip_addr, 0, 0, 0));
184                     socket.bind(addr)?;
185                     socket
186                 }
187             };
188 
189             socket.connect(addr).await
190         }
191 
192         None => TcpStream::connect(addr).await,
193     }
194 }
195