1 use std::fmt;
2 use std::io;
3 use std::net::{SocketAddr, TcpListener as StdTcpListener};
4 use std::time::Duration;
5
6 use tokio::net::TcpListener;
7 use tokio::time::Sleep;
8
9 use crate::common::{task, Future, Pin, Poll};
10
11 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
12 pub use self::addr_stream::AddrStream;
13 use super::accept::Accept;
14
15 /// A stream of connections from binding to an address.
16 #[must_use = "streams do nothing unless polled"]
17 pub struct AddrIncoming {
18 addr: SocketAddr,
19 listener: TcpListener,
20 sleep_on_errors: bool,
21 tcp_keepalive_timeout: Option<Duration>,
22 tcp_nodelay: bool,
23 timeout: Option<Pin<Box<Sleep>>>,
24 }
25
26 impl AddrIncoming {
new(addr: &SocketAddr) -> crate::Result<Self>27 pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
28 let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
29
30 AddrIncoming::from_std(std_listener)
31 }
32
from_std(std_listener: StdTcpListener) -> crate::Result<Self>33 pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
34 // TcpListener::from_std doesn't set O_NONBLOCK
35 std_listener
36 .set_nonblocking(true)
37 .map_err(crate::Error::new_listen)?;
38 let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
39 AddrIncoming::from_listener(listener)
40 }
41
42 /// Creates a new `AddrIncoming` binding to provided socket address.
bind(addr: &SocketAddr) -> crate::Result<Self>43 pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
44 AddrIncoming::new(addr)
45 }
46
47 /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
from_listener(listener: TcpListener) -> crate::Result<Self>48 pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
49 let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
50 Ok(AddrIncoming {
51 listener,
52 addr,
53 sleep_on_errors: true,
54 tcp_keepalive_timeout: None,
55 tcp_nodelay: false,
56 timeout: None,
57 })
58 }
59
60 /// Get the local address bound to this listener.
local_addr(&self) -> SocketAddr61 pub fn local_addr(&self) -> SocketAddr {
62 self.addr
63 }
64
65 /// Set whether TCP keepalive messages are enabled on accepted connections.
66 ///
67 /// If `None` is specified, keepalive is disabled, otherwise the duration
68 /// specified will be the time to remain idle before sending TCP keepalive
69 /// probes.
set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self70 pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self {
71 self.tcp_keepalive_timeout = keepalive;
72 self
73 }
74
75 /// Set the value of `TCP_NODELAY` option for accepted connections.
set_nodelay(&mut self, enabled: bool) -> &mut Self76 pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
77 self.tcp_nodelay = enabled;
78 self
79 }
80
81 /// Set whether to sleep on accept errors.
82 ///
83 /// A possible scenario is that the process has hit the max open files
84 /// allowed, and so trying to accept a new connection will fail with
85 /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
86 /// the application will likely close some files (or connections), and try
87 /// to accept the connection again. If this option is `true`, the error
88 /// will be logged at the `error` level, since it is still a big deal,
89 /// and then the listener will sleep for 1 second.
90 ///
91 /// In other cases, hitting the max open files should be treat similarly
92 /// to being out-of-memory, and simply error (and shutdown). Setting
93 /// this option to `false` will allow that.
94 ///
95 /// Default is `true`.
set_sleep_on_errors(&mut self, val: bool)96 pub fn set_sleep_on_errors(&mut self, val: bool) {
97 self.sleep_on_errors = val;
98 }
99
poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>>100 fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<AddrStream>> {
101 // Check if a previous timeout is active that was set by IO errors.
102 if let Some(ref mut to) = self.timeout {
103 ready!(Pin::new(to).poll(cx));
104 }
105 self.timeout = None;
106
107 loop {
108 match ready!(self.listener.poll_accept(cx)) {
109 Ok((socket, addr)) => {
110 if let Some(dur) = self.tcp_keepalive_timeout {
111 let socket = socket2::SockRef::from(&socket);
112 let conf = socket2::TcpKeepalive::new().with_time(dur);
113 if let Err(e) = socket.set_tcp_keepalive(&conf) {
114 trace!("error trying to set TCP keepalive: {}", e);
115 }
116 }
117 if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
118 trace!("error trying to set TCP nodelay: {}", e);
119 }
120 return Poll::Ready(Ok(AddrStream::new(socket, addr)));
121 }
122 Err(e) => {
123 // Connection errors can be ignored directly, continue by
124 // accepting the next request.
125 if is_connection_error(&e) {
126 debug!("accepted connection already errored: {}", e);
127 continue;
128 }
129
130 if self.sleep_on_errors {
131 error!("accept error: {}", e);
132
133 // Sleep 1s.
134 let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
135
136 match timeout.as_mut().poll(cx) {
137 Poll::Ready(()) => {
138 // Wow, it's been a second already? Ok then...
139 continue;
140 }
141 Poll::Pending => {
142 self.timeout = Some(timeout);
143 return Poll::Pending;
144 }
145 }
146 } else {
147 return Poll::Ready(Err(e));
148 }
149 }
150 }
151 }
152 }
153 }
154
155 impl Accept for AddrIncoming {
156 type Conn = AddrStream;
157 type Error = io::Error;
158
poll_accept( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>159 fn poll_accept(
160 mut self: Pin<&mut Self>,
161 cx: &mut task::Context<'_>,
162 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
163 let result = ready!(self.poll_next_(cx));
164 Poll::Ready(Some(result))
165 }
166 }
167
/// Returns `true` when an `accept()` error only concerns the single
/// connection being accepted, so the next connection may already be ready.
///
/// Any other error is treated as a listener-level problem. Depending on
/// `sleep_on_errors`, `AddrIncoming` either backs off for a second before
/// calling `accept()` again, or returns the error. The back-off exists for
/// resource-exhaustion errors such as `ENFILE` and `EMFILE`, which would
/// otherwise make the accept loop spin.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}
180
181 impl fmt::Debug for AddrIncoming {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 f.debug_struct("AddrIncoming")
184 .field("addr", &self.addr)
185 .field("sleep_on_errors", &self.sleep_on_errors)
186 .field("tcp_keepalive_timeout", &self.tcp_keepalive_timeout)
187 .field("tcp_nodelay", &self.tcp_nodelay)
188 .finish()
189 }
190 }
191
192 mod addr_stream {
193 use std::io;
194 use std::net::SocketAddr;
195 #[cfg(unix)]
196 use std::os::unix::io::{AsRawFd, RawFd};
197 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
198 use tokio::net::TcpStream;
199
200 use crate::common::{task, Pin, Poll};
201
202 pin_project_lite::pin_project! {
203 /// A transport returned yieled by `AddrIncoming`.
204 #[derive(Debug)]
205 pub struct AddrStream {
206 #[pin]
207 inner: TcpStream,
208 pub(super) remote_addr: SocketAddr,
209 }
210 }
211
212 impl AddrStream {
new(tcp: TcpStream, addr: SocketAddr) -> AddrStream213 pub(super) fn new(tcp: TcpStream, addr: SocketAddr) -> AddrStream {
214 AddrStream {
215 inner: tcp,
216 remote_addr: addr,
217 }
218 }
219
220 /// Returns the remote (peer) address of this connection.
221 #[inline]
remote_addr(&self) -> SocketAddr222 pub fn remote_addr(&self) -> SocketAddr {
223 self.remote_addr
224 }
225
226 /// Consumes the AddrStream and returns the underlying IO object
227 #[inline]
into_inner(self) -> TcpStream228 pub fn into_inner(self) -> TcpStream {
229 self.inner
230 }
231
232 /// Attempt to receive data on the socket, without removing that data
233 /// from the queue, registering the current task for wakeup if data is
234 /// not yet available.
poll_peek( &mut self, cx: &mut task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<usize>>235 pub fn poll_peek(
236 &mut self,
237 cx: &mut task::Context<'_>,
238 buf: &mut tokio::io::ReadBuf<'_>,
239 ) -> Poll<io::Result<usize>> {
240 self.inner.poll_peek(cx, buf)
241 }
242 }
243
244 impl AsyncRead for AddrStream {
245 #[inline]
poll_read( self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>246 fn poll_read(
247 self: Pin<&mut Self>,
248 cx: &mut task::Context<'_>,
249 buf: &mut ReadBuf<'_>,
250 ) -> Poll<io::Result<()>> {
251 self.project().inner.poll_read(cx, buf)
252 }
253 }
254
255 impl AsyncWrite for AddrStream {
256 #[inline]
poll_write( self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>257 fn poll_write(
258 self: Pin<&mut Self>,
259 cx: &mut task::Context<'_>,
260 buf: &[u8],
261 ) -> Poll<io::Result<usize>> {
262 self.project().inner.poll_write(cx, buf)
263 }
264
265 #[inline]
poll_write_vectored( self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>266 fn poll_write_vectored(
267 self: Pin<&mut Self>,
268 cx: &mut task::Context<'_>,
269 bufs: &[io::IoSlice<'_>],
270 ) -> Poll<io::Result<usize>> {
271 self.project().inner.poll_write_vectored(cx, bufs)
272 }
273
274 #[inline]
poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>>275 fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
276 // TCP flush is a noop
277 Poll::Ready(Ok(()))
278 }
279
280 #[inline]
poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>>281 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
282 self.project().inner.poll_shutdown(cx)
283 }
284
285 #[inline]
is_write_vectored(&self) -> bool286 fn is_write_vectored(&self) -> bool {
287 // Note that since `self.inner` is a `TcpStream`, this could
288 // *probably* be hard-coded to return `true`...but it seems more
289 // correct to ask it anyway (maybe we're on some platform without
290 // scatter-gather IO?)
291 self.inner.is_write_vectored()
292 }
293 }
294
295 #[cfg(unix)]
296 impl AsRawFd for AddrStream {
as_raw_fd(&self) -> RawFd297 fn as_raw_fd(&self) -> RawFd {
298 self.inner.as_raw_fd()
299 }
300 }
301 }
302