1 use std::future::Future;
2 use std::io;
3 use std::pin::Pin;
4 use std::task::Poll;
5 use std::time::{Duration, Instant};
6 
7 /// `Incoming` is a stream of incoming sockets
8 /// Polling the stream may return a temporary io::Error (for instance if we can't open the connection because of "too many open files" limit)
9 /// we use for_each combinator which:
10 /// 1. Runs for every Ok(socket)
11 /// 2. Stops on the FIRST Err()
12 /// So any temporary io::Error will cause the entire server to terminate.
13 /// This wrapper type for tokio::Incoming stops accepting new connections
14 /// for a specified amount of time once an io::Error is encountered
15 pub struct SuspendableStream<S> {
16 	stream: S,
17 	next_delay: Duration,
18 	initial_delay: Duration,
19 	max_delay: Duration,
20 	suspended_until: Option<Instant>,
21 }
22 
23 impl<S> SuspendableStream<S> {
24 	/// construct a new Suspendable stream, given tokio::Incoming
25 	/// and the amount of time to pause for.
new(stream: S) -> Self26 	pub fn new(stream: S) -> Self {
27 		SuspendableStream {
28 			stream,
29 			next_delay: Duration::from_millis(20),
30 			initial_delay: Duration::from_millis(10),
31 			max_delay: Duration::from_secs(5),
32 			suspended_until: None,
33 		}
34 	}
35 }
36 
37 impl<S, I> futures::Stream for SuspendableStream<S>
38 where
39 	S: futures::Stream<Item = io::Result<I>> + Unpin,
40 {
41 	type Item = I;
42 
poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>>43 	fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
44 		loop {
45 			// If we encountered a connection error before then we suspend
46 			// polling from the underlying stream for a bit
47 			if let Some(deadline) = &mut self.suspended_until {
48 				let deadline = tokio::time::Instant::from_std(*deadline);
49 				let sleep = tokio::time::sleep_until(deadline);
50 				futures::pin_mut!(sleep);
51 				match sleep.poll(cx) {
52 					Poll::Pending => return Poll::Pending,
53 					Poll::Ready(()) => {
54 						self.suspended_until = None;
55 					}
56 				}
57 			}
58 
59 			match Pin::new(&mut self.stream).poll_next(cx) {
60 				Poll::Pending => return Poll::Pending,
61 				Poll::Ready(None) => {
62 					if self.next_delay > self.initial_delay {
63 						self.next_delay = self.initial_delay;
64 					}
65 					return Poll::Ready(None);
66 				}
67 				Poll::Ready(Some(Ok(item))) => {
68 					if self.next_delay > self.initial_delay {
69 						self.next_delay = self.initial_delay;
70 					}
71 
72 					return Poll::Ready(Some(item));
73 				}
74 				Poll::Ready(Some(Err(ref err))) => {
75 					if connection_error(err) {
76 						warn!("Connection Error: {:?}", err);
77 						continue;
78 					}
79 					self.next_delay = if self.next_delay < self.max_delay {
80 						self.next_delay * 2
81 					} else {
82 						self.next_delay
83 					};
84 					debug!("Error accepting connection: {}", err);
85 					debug!("The server will stop accepting connections for {:?}", self.next_delay);
86 					self.suspended_until = Some(Instant::now() + self.next_delay);
87 				}
88 			}
89 		}
90 	}
91 }
92 
/// Reports whether `e` is a per-connection failure: its kind is `ConnectionRefused`,
/// `ConnectionAborted` or `ConnectionReset`. `SuspendableStream` only logs such errors
/// and keeps accepting, instead of pausing.
fn connection_error(e: &io::Error) -> bool {
	use io::ErrorKind::{ConnectionAborted, ConnectionRefused, ConnectionReset};
	matches!(e.kind(), ConnectionRefused | ConnectionAborted | ConnectionReset)
}
99