1 use std::fmt;
2 use std::io;
3 use std::net::{SocketAddr, TcpListener as StdTcpListener};
4 use std::time::Duration;
5
6 use tokio::net::TcpListener;
7 use tokio::time::Sleep;
8 use tracing::{debug, error, trace};
9
10 use crate::common::{task, Future, Pin, Poll};
11
12 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
13 pub use self::addr_stream::AddrStream;
14 use super::accept::Accept;
15
16 /// A stream of connections from binding to an address.
17 #[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    // Cached local address of `listener`, so `local_addr()` needs no syscall.
    addr: SocketAddr,
    // The non-blocking listener connections are accepted from.
    listener: TcpListener,
    // When `true`, non-connection accept errors are logged and retried
    // after a 1-second pause instead of being returned to the caller.
    sleep_on_errors: bool,
    // If set, TCP keepalive is enabled on accepted sockets with this
    // idle duration before probes are sent.
    tcp_keepalive_timeout: Option<Duration>,
    // Value of `TCP_NODELAY` applied to each accepted socket.
    tcp_nodelay: bool,
    // In-progress 1-second error backoff timer, if `poll_next_` started one.
    timeout: Option<Pin<Box<Sleep>>>,
}
26
impl AddrIncoming {
    /// Binds a blocking std listener to `addr` and wraps it.
    pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
        let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;

        AddrIncoming::from_std(std_listener)
    }

    /// Converts a std listener into a tokio-backed `AddrIncoming`,
    /// switching the socket into non-blocking mode first.
    pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
        // TcpListener::from_std doesn't set O_NONBLOCK
        std_listener
            .set_nonblocking(true)
            .map_err(crate::Error::new_listen)?;
        let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
        AddrIncoming::from_listener(listener)
    }

    /// Creates a new `AddrIncoming` binding to provided socket address.
    pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
        AddrIncoming::new(addr)
    }

    /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
    ///
    /// Errors if the listener's local address cannot be queried.
    pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
        let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
        Ok(AddrIncoming {
            listener,
            addr,
            // Defaults: retry-with-backoff on accept errors, no keepalive,
            // Nagle's algorithm left enabled (`TCP_NODELAY` off).
            sleep_on_errors: true,
            tcp_keepalive_timeout: None,
            tcp_nodelay: false,
            timeout: None,
        })
    }

    /// Get the local address bound to this listener.
    pub fn local_addr(&self) -> SocketAddr {
        self.addr
    }

    /// Set whether TCP keepalive messages are enabled on accepted connections.
    ///
    /// If `None` is specified, keepalive is disabled, otherwise the duration
    /// specified will be the time to remain idle before sending TCP keepalive
    /// probes.
    pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self {
        self.tcp_keepalive_timeout = keepalive;
        self
    }

    /// Set the value of `TCP_NODELAY` option for accepted connections.
    pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
        self.tcp_nodelay = enabled;
        self
    }

    /// Set whether to sleep on accept errors.
    ///
    /// A possible scenario is that the process has hit the max open files
    /// allowed, and so trying to accept a new connection will fail with
    /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
    /// the application will likely close some files (or connections), and try
    /// to accept the connection again. If this option is `true`, the error
    /// will be logged at the `error` level, since it is still a big deal,
    /// and then the listener will sleep for 1 second.
    ///
    /// In other cases, hitting the max open files should be treated similarly
    /// to being out-of-memory, and simply error (and shutdown). Setting
    /// this option to `false` will allow that.
    ///
    /// Default is `true`.
    pub fn set_sleep_on_errors(&mut self, val: bool) {
        self.sleep_on_errors = val;
    }

    /// Polls for the next accepted connection, applying the configured
    /// socket options (keepalive, `TCP_NODELAY`) to it before returning.
    ///
    /// Per-connection errors (refused/aborted/reset) are skipped; other
    /// errors either start a 1-second backoff (`sleep_on_errors`) or are
    /// returned to the caller.
    fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> {
        // Check if a previous timeout is active that was set by IO errors.
        // If it hasn't elapsed yet, `ready!` returns Pending here.
        if let Some(ref mut to) = self.timeout {
            ready!(Pin::new(to).poll(cx));
        }
        // The timeout (if any) has fired; drop it so it isn't polled again.
        self.timeout = None;

        loop {
            match ready!(self.listener.poll_accept(cx)) {
                Ok((socket, addr)) => {
                    if let Some(dur) = self.tcp_keepalive_timeout {
                        // Borrow the raw socket via socket2 to reach options
                        // tokio doesn't expose; failures are non-fatal.
                        let socket = socket2::SockRef::from(&socket);
                        let conf = socket2::TcpKeepalive::new().with_time(dur);
                        if let Err(e) = socket.set_tcp_keepalive(&conf) {
                            trace!("error trying to set TCP keepalive: {}", e);
                        }
                    }
                    if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
                        trace!("error trying to set TCP nodelay: {}", e);
                    }
                    return Poll::Ready(Ok(AddrStream::new(socket, addr)));
                }
                Err(e) => {
                    // Connection errors can be ignored directly, continue by
                    // accepting the next request.
                    if is_connection_error(&e) {
                        debug!("accepted connection already errored: {}", e);
                        continue;
                    }

                    if self.sleep_on_errors {
                        error!("accept error: {}", e);

                        // Sleep 1s.
                        let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));

                        // Poll once immediately: this registers the waker, and
                        // also handles the (unlikely) already-elapsed case.
                        match timeout.as_mut().poll(cx) {
                            Poll::Ready(()) => {
                                // Wow, it's been a second already? Ok then...
                                continue;
                            }
                            Poll::Pending => {
                                // Stash the timer; the next call resumes it.
                                self.timeout = Some(timeout);
                                return Poll::Pending;
                            }
                        }
                    } else {
                        return Poll::Ready(Err(e));
                    }
                }
            }
        }
    }
}
155
156 impl Accept for AddrIncoming {
157 type Conn = AddrStream;
158 type Error = io::Error;
159
poll_accept( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>160 fn poll_accept(
161 mut self: Pin<&mut Self>,
162 cx: &mut task::Context<'_>,
163 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
164 let result = ready!(self.poll_next_(cx));
165 Poll::Ready(Some(result))
166 }
167 }
168
/// This function defines errors that are per-connection. Which basically
/// means that if we get this error from `accept()` system call it means
/// next connection might be ready to be accepted.
///
/// All other errors will incur a timeout before next `accept()` is performed.
/// The timeout is useful to handle resource exhaustion errors like ENFILE
/// and EMFILE. Otherwise, could enter into tight loop.
fn is_connection_error(e: &io::Error) -> bool {
    // Kinds that affect only the single in-flight connection, not the
    // listener itself.
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    PER_CONNECTION_KINDS.contains(&e.kind())
}
181
182 impl fmt::Debug for AddrIncoming {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 f.debug_struct("AddrIncoming")
185 .field("addr", &self.addr)
186 .field("sleep_on_errors", &self.sleep_on_errors)
187 .field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout)
188 .field("tcp_nodelay", &self.tcp_nodelay)
189 .finish()
190 }
191 }
192
mod addr_stream {
    use std::io;
    use std::net::SocketAddr;
    #[cfg(unix)]
    use std::os::unix::io::{AsRawFd, RawFd};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio::net::TcpStream;

    use crate::common::{task, Pin, Poll};

    pin_project_lite::pin_project! {
        /// A transport yielded by `AddrIncoming`.
        ///
        /// Wraps a `TcpStream` together with the peer address it was
        /// accepted from; all I/O is delegated to the inner stream.
        #[derive(Debug)]
        pub struct AddrStream {
            #[pin]
            inner: TcpStream,
            pub(super) remote_addr: SocketAddr,
        }
    }

    impl AddrStream {
        /// Pairs an accepted stream with its peer address.
        pub(super) fn new(tcp: TcpStream, addr: SocketAddr) -> AddrStream {
            AddrStream {
                inner: tcp,
                remote_addr: addr,
            }
        }

        /// Returns the remote (peer) address of this connection.
        #[inline]
        pub fn remote_addr(&self) -> SocketAddr {
            self.remote_addr
        }

        /// Consumes the AddrStream and returns the underlying IO object
        #[inline]
        pub fn into_inner(self) -> TcpStream {
            self.inner
        }

        /// Attempt to receive data on the socket, without removing that data
        /// from the queue, registering the current task for wakeup if data is
        /// not yet available.
        pub fn poll_peek(
            &mut self,
            cx: &mut task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<usize>> {
            self.inner.poll_peek(cx, buf)
        }
    }

    impl AsyncRead for AddrStream {
        // Delegates directly to the pinned inner `TcpStream`.
        #[inline]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.project().inner.poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AddrStream {
        // Delegates directly to the pinned inner `TcpStream`.
        #[inline]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write(cx, buf)
        }

        #[inline]
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[io::IoSlice<'_>],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }

        #[inline]
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            // TCP flush is a noop
            Poll::Ready(Ok(()))
        }

        // Shuts down the write half of the inner TCP stream.
        #[inline]
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            self.project().inner.poll_shutdown(cx)
        }

        #[inline]
        fn is_write_vectored(&self) -> bool {
            // Note that since `self.inner` is a `TcpStream`, this could
            // *probably* be hard-coded to return `true`...but it seems more
            // correct to ask it anyway (maybe we're on some platform without
            // scatter-gather IO?)
            self.inner.is_write_vectored()
        }
    }

    #[cfg(unix)]
    impl AsRawFd for AddrStream {
        // Exposes the inner socket's file descriptor (Unix only).
        fn as_raw_fd(&self) -> RawFd {
            self.inner.as_raw_fd()
        }
    }
}
303