use std::fmt;
use std::io;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::time::Duration;

use tokio::net::TcpListener;
use tokio::time::Sleep;

use crate::common::{task, Future, Pin, Poll};

#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
pub use self::addr_stream::AddrStream;
use super::accept::Accept;

/// A stream of connections from binding to an address.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    addr: SocketAddr,
    listener: TcpListener,
    sleep_on_errors: bool,
    tcp_keepalive_timeout: Option<Duration>,
    tcp_nodelay: bool,
    timeout: Option<Pin<Box<Sleep>>>,
}

impl AddrIncoming {
    pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
        let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;

        AddrIncoming::from_std(std_listener)
    }

    pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
        // TcpListener::from_std doesn't set O_NONBLOCK
        std_listener
            .set_nonblocking(true)
            .map_err(crate::Error::new_listen)?;
        let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
        AddrIncoming::from_listener(listener)
    }

    /// Creates a new `AddrIncoming` binding to the provided socket address.
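    ///
    /// A minimal usage sketch, assuming this type is exposed publicly as
    /// `hyper::server::conn::AddrIncoming`; the address below is only an
    /// example.
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    /// use hyper::server::conn::AddrIncoming;
    ///
    /// // Note: binding registers the listener with the Tokio runtime's
    /// // reactor, so this must be called from within a runtime.
    /// let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
    /// let incoming = AddrIncoming::bind(&addr).expect("failed to bind");
    /// println!("listening on {}", incoming.local_addr());
    /// ```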
    pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
        AddrIncoming::new(addr)
    }

    /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
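    ///
    /// A sketch of wrapping a listener that was bound and configured by the
    /// caller, again assuming the `hyper::server::conn::AddrIncoming` path:
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// use hyper::server::conn::AddrIncoming;
    /// use tokio::net::TcpListener;
    ///
    /// // The local address is captured at construction time.
    /// let listener = TcpListener::bind("127.0.0.1:0").await?;
    /// let incoming = AddrIncoming::from_listener(listener)?;
    /// println!("accepting connections on {}", incoming.local_addr());
    /// # Ok(())
    /// # }
    /// ```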
    pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
        let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
        Ok(AddrIncoming {
            listener,
            addr,
            sleep_on_errors: true,
            tcp_keepalive_timeout: None,
            tcp_nodelay: false,
            timeout: None,
        })
    }

    /// Get the local address bound to this listener.
    pub fn local_addr(&self) -> SocketAddr {
        self.addr
    }

    /// Set whether TCP keepalive messages are enabled on accepted connections.
    ///
    /// If `None` is specified, keepalive is disabled, otherwise the duration
    /// specified will be the time to remain idle before sending TCP keepalive
    /// probes.
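    ///
    /// A configuration sketch, with the same path assumption as the `bind`
    /// example above; the idle duration is only illustrative:
    ///
    /// ```no_run
    /// use std::time::Duration;
    /// use hyper::server::conn::AddrIncoming;
    ///
    /// let addr = ([127, 0, 0, 1], 3000).into();
    /// let mut incoming = AddrIncoming::bind(&addr).expect("failed to bind");
    /// // Probe connections idle for 60 seconds, and also disable Nagle's
    /// // algorithm on accepted connections.
    /// incoming
    ///     .set_keepalive(Some(Duration::from_secs(60)))
    ///     .set_nodelay(true);
    /// ```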
    pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self {
        self.tcp_keepalive_timeout = keepalive;
        self
    }

    /// Set the value of `TCP_NODELAY` option for accepted connections.
    pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
        self.tcp_nodelay = enabled;
        self
    }

    /// Set whether to sleep on accept errors.
    ///
    /// A possible scenario is that the process has hit the max open files
    /// allowed, and so trying to accept a new connection will fail with
    /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
    /// the application will likely close some files (or connections), and try
    /// to accept the connection again. If this option is `true`, the error
    /// will be logged at the `error` level, since it is still a big deal,
    /// and then the listener will sleep for 1 second.
    ///
    /// In other cases, hitting the max open files should be treated similarly
    /// to being out-of-memory, and simply error (and shutdown). Setting
    /// this option to `false` will allow that.
    ///
    /// Default is `true`.
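    ///
    /// A sketch of opting out of the sleep-and-retry behavior, so that accept
    /// errors are surfaced to the caller instead (same path assumption as the
    /// `bind` example above):
    ///
    /// ```no_run
    /// use hyper::server::conn::AddrIncoming;
    ///
    /// let addr = ([127, 0, 0, 1], 3000).into();
    /// let mut incoming = AddrIncoming::bind(&addr).expect("failed to bind");
    /// // Treat resource exhaustion like any other fatal accept error.
    /// incoming.set_sleep_on_errors(false);
    /// ```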
    pub fn set_sleep_on_errors(&mut self, val: bool) {
        self.sleep_on_errors = val;
    }

    fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> {
        // Check if a previous timeout is active that was set by IO errors.
        if let Some(ref mut to) = self.timeout {
            ready!(Pin::new(to).poll(cx));
        }
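        // The error-backoff sleep (if any) has elapsed, so clear it before
        // trying to accept again.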
        self.timeout = None;

        loop {
            match ready!(self.listener.poll_accept(cx)) {
                Ok((socket, addr)) => {
                    if let Some(dur) = self.tcp_keepalive_timeout {
                        let socket = socket2::SockRef::from(&socket);
                        let conf = socket2::TcpKeepalive::new().with_time(dur);
                        if let Err(e) = socket.set_tcp_keepalive(&conf) {
                            trace!("error trying to set TCP keepalive: {}", e);
                        }
                    }
                    if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
                        trace!("error trying to set TCP nodelay: {}", e);
                    }
                    return Poll::Ready(Ok(AddrStream::new(socket, addr)));
                }
                Err(e) => {
                    // Connection errors can be ignored directly, continue by
                    // accepting the next connection.
                    if is_connection_error(&e) {
                        debug!("accepted connection already errored: {}", e);
                        continue;
                    }

                    if self.sleep_on_errors {
                        error!("accept error: {}", e);

                        // Sleep 1s.
                        let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));

                        match timeout.as_mut().poll(cx) {
                            Poll::Ready(()) => {
                                // Wow, it's been a second already? Ok then...
                                continue;
                            }
                            Poll::Pending => {
                                self.timeout = Some(timeout);
                                return Poll::Pending;
                            }
                        }
                    } else {
                        return Poll::Ready(Err(e));
                    }
                }
            }
        }
    }
}

impl Accept for AddrIncoming {
    type Conn = AddrStream;
    type Error = io::Error;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
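        // The incoming stream never ends, so the result is always `Some`.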
        let result = ready!(self.poll_next_(cx));
        Poll::Ready(Some(result))
    }
}

/// This function defines errors that are per-connection: if we get one of
/// these errors from the `accept()` system call, the next connection might
/// still be ready to be accepted.
///
/// All other errors will incur a timeout before the next `accept()` is
/// performed. The timeout is useful to handle resource exhaustion errors
/// like ENFILE and EMFILE. Otherwise, we could enter into a tight loop.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(e.kind(), io::ErrorKind::ConnectionRefused
        | io::ErrorKind::ConnectionAborted
        | io::ErrorKind::ConnectionReset)
}

impl fmt::Debug for AddrIncoming {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AddrIncoming")
            .field("addr", &self.addr)
            .field("sleep_on_errors", &self.sleep_on_errors)
            .field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout)
            .field("tcp_nodelay", &self.tcp_nodelay)
            .finish()
    }
}

mod addr_stream {
    use std::io;
    use std::net::SocketAddr;
    #[cfg(unix)]
    use std::os::unix::io::{AsRawFd, RawFd};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio::net::TcpStream;

    use crate::common::{task, Pin, Poll};

    pin_project_lite::pin_project! {
        /// A transport yielded by `AddrIncoming`.
        #[derive(Debug)]
        pub struct AddrStream {
            #[pin]
            inner: TcpStream,
            pub(super) remote_addr: SocketAddr,
        }
    }

    impl AddrStream {
        pub(super) fn new(tcp: TcpStream, addr: SocketAddr) -> AddrStream {
            AddrStream {
                inner: tcp,
                remote_addr: addr,
            }
        }

        /// Returns the remote (peer) address of this connection.
        #[inline]
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }

        /// Consumes the AddrStream and returns the underlying IO object
        #[inline]
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Attempt to receive data on the socket, without removing that data
        /// from the queue, registering the current task for wakeup if data is
        /// not yet available.
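        ///
        /// A sketch of adapting this poll-style API to `async`/`await` with
        /// `std::future::poll_fn` (Rust 1.64+); the buffer size and function
        /// name are arbitrary, and the `AddrStream` path is assumed to be
        /// `hyper::server::conn::AddrStream`:
        ///
        /// ```no_run
        /// use hyper::server::conn::AddrStream;
        /// use tokio::io::ReadBuf;
        ///
        /// async fn peek_start(stream: &mut AddrStream) -> std::io::Result<usize> {
        ///     // Look at the first bytes without consuming them from the
        ///     // socket's receive queue.
        ///     let mut buf = [0u8; 8];
        ///     let mut read_buf = ReadBuf::new(&mut buf);
        ///     std::future::poll_fn(|cx| stream.poll_peek(cx, &mut read_buf)).await
        /// }
        /// ```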
        pub fn poll_peek(
            &mut self,
            cx: &mut task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<usize>> {
            self.inner.poll_peek(cx, buf)
        }
    }

    impl AsyncRead for AddrStream {
        #[inline]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.project().inner.poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AddrStream {
        #[inline]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        #[inline]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }

        #[inline]
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            // TCP flush is a noop
            Poll::Ready(Ok(()))
        }

        #[inline]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            self.project().inner.poll_shutdown(cx)
        }

        #[inline]
        fn is_write_vectored(&self) -> bool {
            // Note that since `self.inner` is a `TcpStream`, this could
            // *probably* be hard-coded to return `true`...but it seems more
            // correct to ask it anyway (maybe we're on some platform without
            // scatter-gather IO?)
            self.inner.is_write_vectored()
        }
    }

    #[cfg(unix)]
    impl AsRawFd for AddrStream {
        fn as_raw_fd(&self) -> RawFd {
            self.inner.as_raw_fd()
        }
    }
}