1 use std::fmt;
2 use std::io;
3 use std::net::{SocketAddr, TcpListener as StdTcpListener};
4 use std::time::Duration;
5 
6 use tokio::net::TcpListener;
7 use tokio::time::Sleep;
8 use tracing::{debug, error, trace};
9 
10 use crate::common::{task, Future, Pin, Poll};
11 
12 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
13 pub use self::addr_stream::AddrStream;
14 use super::accept::Accept;
15 
/// A stream of connections from binding to an address.
///
/// Wraps a tokio `TcpListener` and hands out accepted sockets as
/// `AddrStream`s, applying the configured TCP options to each one.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    /// Local address the listener is bound to, captured at construction.
    addr: SocketAddr,
    /// Listener that connections are accepted from.
    listener: TcpListener,
    /// When true, accept errors that are not per-connection are logged and
    /// followed by a back-off sleep instead of being returned.
    sleep_on_errors: bool,
    /// Keepalive idle time applied to accepted sockets; `None` means no
    /// keepalive option is set on them.
    tcp_keepalive_timeout: Option<Duration>,
    /// `TCP_NODELAY` value applied to every accepted socket.
    tcp_nodelay: bool,
    /// Back-off timer armed after an accept error; it must elapse before the
    /// next accept attempt.
    timeout: Option<Pin<Box<Sleep>>>,
}
26 
27 impl AddrIncoming {
new(addr: &SocketAddr) -> crate::Result<Self>28     pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
29         let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
30 
31         AddrIncoming::from_std(std_listener)
32     }
33 
from_std(std_listener: StdTcpListener) -> crate::Result<Self>34     pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
35         // TcpListener::from_std doesn't set O_NONBLOCK
36         std_listener
37             .set_nonblocking(true)
38             .map_err(crate::Error::new_listen)?;
39         let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
40         AddrIncoming::from_listener(listener)
41     }
42 
43     /// Creates a new `AddrIncoming` binding to provided socket address.
bind(addr: &SocketAddr) -> crate::Result<Self>44     pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
45         AddrIncoming::new(addr)
46     }
47 
48     /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
from_listener(listener: TcpListener) -> crate::Result<Self>49     pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
50         let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
51         Ok(AddrIncoming {
52             listener,
53             addr,
54             sleep_on_errors: true,
55             tcp_keepalive_timeout: None,
56             tcp_nodelay: false,
57             timeout: None,
58         })
59     }
60 
61     /// Get the local address bound to this listener.
local_addr(&self) -> SocketAddr62     pub fn local_addr(&self) -> SocketAddr {
63         self.addr
64     }
65 
66     /// Set whether TCP keepalive messages are enabled on accepted connections.
67     ///
68     /// If `None` is specified, keepalive is disabled, otherwise the duration
69     /// specified will be the time to remain idle before sending TCP keepalive
70     /// probes.
set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self71     pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self {
72         self.tcp_keepalive_timeout = keepalive;
73         self
74     }
75 
76     /// Set the value of `TCP_NODELAY` option for accepted connections.
set_nodelay(&mut self, enabled: bool) -> &mut Self77     pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
78         self.tcp_nodelay = enabled;
79         self
80     }
81 
82     /// Set whether to sleep on accept errors.
83     ///
84     /// A possible scenario is that the process has hit the max open files
85     /// allowed, and so trying to accept a new connection will fail with
86     /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
87     /// the application will likely close some files (or connections), and try
88     /// to accept the connection again. If this option is `true`, the error
89     /// will be logged at the `error` level, since it is still a big deal,
90     /// and then the listener will sleep for 1 second.
91     ///
92     /// In other cases, hitting the max open files should be treat similarly
93     /// to being out-of-memory, and simply error (and shutdown). Setting
94     /// this option to `false` will allow that.
95     ///
96     /// Default is `true`.
set_sleep_on_errors(&mut self, val: bool)97     pub fn set_sleep_on_errors(&mut self, val: bool) {
98         self.sleep_on_errors = val;
99     }
100 
poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>>101     fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> {
102         // Check if a previous timeout is active that was set by IO errors.
103         if let Some(ref mut to) = self.timeout {
104             ready!(Pin::new(to).poll(cx));
105         }
106         self.timeout = None;
107 
108         loop {
109             match ready!(self.listener.poll_accept(cx)) {
110                 Ok((socket, addr)) => {
111                     if let Some(dur) = self.tcp_keepalive_timeout {
112                         let socket = socket2::SockRef::from(&socket);
113                         let conf = socket2::TcpKeepalive::new().with_time(dur);
114                         if let Err(e) = socket.set_tcp_keepalive(&conf) {
115                             trace!("error trying to set TCP keepalive: {}", e);
116                         }
117                     }
118                     if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
119                         trace!("error trying to set TCP nodelay: {}", e);
120                     }
121                     return Poll::Ready(Ok(AddrStream::new(socket, addr)));
122                 }
123                 Err(e) => {
124                     // Connection errors can be ignored directly, continue by
125                     // accepting the next request.
126                     if is_connection_error(&e) {
127                         debug!("accepted connection already errored: {}", e);
128                         continue;
129                     }
130 
131                     if self.sleep_on_errors {
132                         error!("accept error: {}", e);
133 
134                         // Sleep 1s.
135                         let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
136 
137                         match timeout.as_mut().poll(cx) {
138                             Poll::Ready(()) => {
139                                 // Wow, it's been a second already? Ok then...
140                                 continue;
141                             }
142                             Poll::Pending => {
143                                 self.timeout = Some(timeout);
144                                 return Poll::Pending;
145                             }
146                         }
147                     } else {
148                         return Poll::Ready(Err(e));
149                     }
150                 }
151             }
152         }
153     }
154 }
155 
156 impl Accept for AddrIncoming {
157     type Conn = AddrStream;
158     type Error = io::Error;
159 
poll_accept( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>160     fn poll_accept(
161         mut self: Pin<&mut Self>,
162         cx: &mut task::Context<'_>,
163     ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
164         let result = ready!(self.poll_next_(cx));
165         Poll::Ready(Some(result))
166     }
167 }
168 
/// Reports whether an `accept()` error is specific to a single connection.
///
/// Such an error means only that one pending connection failed, and the next
/// connection may already be ready to accept, so the accept loop retries
/// immediately.
///
/// Any other error is treated as a listener-level problem. When
/// `sleep_on_errors` is enabled, it triggers a timeout before the next
/// `accept()`. That timeout matters for resource exhaustion errors like
/// `ENFILE` and `EMFILE`, which would otherwise cause a tight loop. When the
/// option is disabled, the error is returned to the caller.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    PER_CONNECTION_KINDS.contains(&e.kind())
}
181 
182 impl fmt::Debug for AddrIncoming {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result183     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184         f.debug_struct("AddrIncoming")
185             .field("addr", &self.addr)
186             .field("sleep_on_errors", &self.sleep_on_errors)
187             .field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout)
188             .field("tcp_nodelay", &self.tcp_nodelay)
189             .finish()
190     }
191 }
192 
193 mod addr_stream {
194     use std::io;
195     use std::net::SocketAddr;
196     #[cfg(unix)]
197     use std::os::unix::io::{AsRawFd, RawFd};
198     use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
199     use tokio::net::TcpStream;
200 
201     use crate::common::{task, Pin, Poll};
202 
203     pin_project_lite::pin_project! {
204         /// A transport returned yieled by `AddrIncoming`.
205         #[derive(Debug)]
206         pub struct AddrStream {
207             #[pin]
208             inner: TcpStream,
209             pub(super) remote_addr: SocketAddr,
210         }
211     }
212 
213     impl AddrStream {
new(tcp: TcpStream, addr: SocketAddr) -> AddrStream214         pub(super) fn new(tcp: TcpStream, addr: SocketAddr) -> AddrStream {
215             AddrStream {
216                 inner: tcp,
217                 remote_addr: addr,
218             }
219         }
220 
221         /// Returns the remote (peer) address of this connection.
222         #[inline]
remote_addr(&self) -> SocketAddr223         pub fn remote_addr(&self) -> SocketAddr {
224             self.remote_addr
225         }
226 
227         /// Consumes the AddrStream and returns the underlying IO object
228         #[inline]
into_inner(self) -> TcpStream229         pub fn into_inner(self) -> TcpStream {
230             self.inner
231         }
232 
233         /// Attempt to receive data on the socket, without removing that data
234         /// from the queue, registering the current task for wakeup if data is
235         /// not yet available.
poll_peek( &mut self, cx: &mut task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<usize>>236         pub fn poll_peek(
237             &mut self,
238             cx: &mut task::Context<'_>,
239             buf: &mut tokio::io::ReadBuf<'_>,
240         ) -> Poll<io::Result<usize>> {
241             self.inner.poll_peek(cx, buf)
242         }
243     }
244 
245     impl AsyncRead for AddrStream {
246         #[inline]
poll_read( self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>247         fn poll_read(
248             self: Pin<&mut Self>,
249             cx: &mut task::Context<'_>,
250             buf: &mut ReadBuf<'_>,
251         ) -> Poll<io::Result<()>> {
252             self.project().inner.poll_read(cx, buf)
253         }
254     }
255 
256     impl AsyncWrite for AddrStream {
257         #[inline]
poll_write( self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>258         fn poll_write(
259             self: Pin<&mut Self>,
260             cx: &mut task::Context<'_>,
261             buf: &[u8],
262         ) -> Poll<io::Result<usize>> {
263             self.project().inner.poll_write(cx, buf)
264         }
265 
266         #[inline]
poll_write_vectored( self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>267         fn poll_write_vectored(
268             self: Pin<&mut Self>,
269             cx: &mut task::Context<'_>,
270             bufs: &[io::IoSlice<'_>],
271         ) -> Poll<io::Result<usize>> {
272             self.project().inner.poll_write_vectored(cx, bufs)
273         }
274 
275         #[inline]
poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>>276         fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
277             // TCP flush is a noop
278             Poll::Ready(Ok(()))
279         }
280 
281         #[inline]
poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>>282         fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
283             self.project().inner.poll_shutdown(cx)
284         }
285 
286         #[inline]
is_write_vectored(&self) -> bool287         fn is_write_vectored(&self) -> bool {
288             // Note that since `self.inner` is a `TcpStream`, this could
289             // *probably* be hard-coded to return `true`...but it seems more
290             // correct to ask it anyway (maybe we're on some platform without
291             // scatter-gather IO?)
292             self.inner.is_write_vectored()
293         }
294     }
295 
296     #[cfg(unix)]
297     impl AsRawFd for AddrStream {
as_raw_fd(&self) -> RawFd298         fn as_raw_fd(&self) -> RawFd {
299             self.inner.as_raw_fd()
300         }
301     }
302 }
303