1 use crate::Error;
2 use futures::FutureExt;
3 use std::{
4     cmp,
5     future::Future,
6     pin::Pin,
7     task::{Context, Poll},
8     time::Duration,
9 };
10 use tokio::time::{delay_for, Delay};
11 use tower03::{retry::Policy, timeout::error::Elapsed};
12 
/// Outcome of inspecting a service response, returned by
/// [`RetryLogic::should_retry_response`] to tell [`FixedRetryPolicy`] what to do next.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryAction {
    /// Indicate that this request should be retried with a reason
    Retry(String),
    /// Indicate that this request should not be retried with a reason
    DontRetry(String),
    /// Indicate that this request should not be retried but the request was successful
    Successful,
}
21 
22 pub trait RetryLogic: Clone {
23     type Error: std::error::Error + Send + Sync + 'static;
24     type Response;
25 
is_retriable_error(&self, error: &Self::Error) -> bool26     fn is_retriable_error(&self, error: &Self::Error) -> bool;
27 
should_retry_response(&self, _response: &Self::Response) -> RetryAction28     fn should_retry_response(&self, _response: &Self::Response) -> RetryAction {
29         // Treat the default as the request is successful
30         RetryAction::Successful
31     }
32 }
33 
/// Retry policy allowing a fixed number of retries, waiting between attempts with
/// a Fibonacci-style backoff (`previous + current`) capped at `max_duration`.
#[derive(Debug, Clone)]
pub struct FixedRetryPolicy<L> {
    /// Retries still permitted; no further retry is scheduled once this is zero.
    remaining_attempts: usize,
    /// Backoff used for the preceding step (zero for a freshly built policy).
    previous_duration: Duration,
    /// Backoff to wait before the next retry.
    current_duration: Duration,
    /// Upper bound applied to the growing backoff.
    max_duration: Duration,
    /// Service-specific classification of errors and responses.
    logic: L,
}
42 
43 pub struct RetryPolicyFuture<L: RetryLogic> {
44     delay: Delay,
45     policy: FixedRetryPolicy<L>,
46 }
47 
48 impl<L: RetryLogic> FixedRetryPolicy<L> {
new( remaining_attempts: usize, initial_backoff: Duration, max_duration: Duration, logic: L, ) -> Self49     pub fn new(
50         remaining_attempts: usize,
51         initial_backoff: Duration,
52         max_duration: Duration,
53         logic: L,
54     ) -> Self {
55         FixedRetryPolicy {
56             remaining_attempts,
57             previous_duration: Duration::from_secs(0),
58             current_duration: initial_backoff,
59             max_duration,
60             logic,
61         }
62     }
63 
advance(&self) -> FixedRetryPolicy<L>64     fn advance(&self) -> FixedRetryPolicy<L> {
65         let next_duration: Duration = self.previous_duration + self.current_duration;
66 
67         FixedRetryPolicy {
68             remaining_attempts: self.remaining_attempts - 1,
69             previous_duration: self.current_duration,
70             current_duration: cmp::min(next_duration, self.max_duration),
71             max_duration: self.max_duration,
72             logic: self.logic.clone(),
73         }
74     }
75 
backoff(&self) -> Duration76     fn backoff(&self) -> Duration {
77         self.current_duration
78     }
79 
build_retry(&self) -> RetryPolicyFuture<L>80     fn build_retry(&self) -> RetryPolicyFuture<L> {
81         let policy = self.advance();
82         let delay = delay_for(self.backoff());
83 
84         debug!(message = "retrying request.", delay_ms = %self.backoff().as_millis());
85         RetryPolicyFuture { delay, policy }
86     }
87 }
88 
89 impl<Req, Res, L> Policy<Req, Res, Error> for FixedRetryPolicy<L>
90 where
91     Req: Clone,
92     L: RetryLogic<Response = Res>,
93 {
94     type Future = RetryPolicyFuture<L>;
95 
retry(&self, _: &Req, result: Result<&Res, &Error>) -> Option<Self::Future>96     fn retry(&self, _: &Req, result: Result<&Res, &Error>) -> Option<Self::Future> {
97         match result {
98             Ok(response) => {
99                 if self.remaining_attempts == 0 {
100                     error!("retries exhausted");
101                     return None;
102                 }
103 
104                 match self.logic.should_retry_response(response) {
105                     RetryAction::Retry(reason) => {
106                         warn!(message = "retrying after response.", %reason);
107                         Some(self.build_retry())
108                     }
109 
110                     RetryAction::DontRetry(reason) => {
111                         warn!(message = "request is not retryable; dropping the request.", %reason);
112                         None
113                     }
114 
115                     RetryAction::Successful => None,
116                 }
117             }
118             Err(error) => {
119                 if self.remaining_attempts == 0 {
120                     error!(message = "retries exhausted.", %error);
121                     return None;
122                 }
123 
124                 if let Some(expected) = error.downcast_ref::<L::Error>() {
125                     if self.logic.is_retriable_error(expected) {
126                         warn!("retrying after error: {}", expected);
127                         Some(self.build_retry())
128                     } else {
129                         error!(message = "encountered non-retriable error.", %error);
130                         None
131                     }
132                 } else if error.downcast_ref::<Elapsed>().is_some() {
133                     warn!("request timedout.");
134                     Some(self.build_retry())
135                 } else {
136                     warn!(message = "unexpected error type.", %error);
137                     None
138                 }
139             }
140         }
141     }
142 
clone_request(&self, request: &Req) -> Option<Req>143     fn clone_request(&self, request: &Req) -> Option<Req> {
144         Some(request.clone())
145     }
146 }
147 
148 // Safety: `L` is never pinned and we use no unsafe pin projections
149 // therefore this safe.
150 impl<L: RetryLogic> Unpin for RetryPolicyFuture<L> {}
151 
152 impl<L: RetryLogic> Future for RetryPolicyFuture<L> {
153     type Output = FixedRetryPolicy<L>;
154 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>155     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
156         futures::ready!(self.delay.poll_unpin(cx));
157         Poll::Ready(self.policy.clone())
158     }
159 }
160 
161 impl RetryAction {
is_retryable(&self) -> bool162     pub fn is_retryable(&self) -> bool {
163         if let RetryAction::Retry(_) = &self {
164             true
165         } else {
166             false
167         }
168     }
169 
is_not_retryable(&self) -> bool170     pub fn is_not_retryable(&self) -> bool {
171         if let RetryAction::DontRetry(_) = &self {
172             true
173         } else {
174             false
175         }
176     }
177 
is_successful(&self) -> bool178     pub fn is_successful(&self) -> bool {
179         if let RetryAction::Successful = &self {
180             true
181         } else {
182             false
183         }
184     }
185 }
186 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::trace_init;
    use std::{fmt, time::Duration};
    use tokio::time;
    use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task};
    use tower03::retry::RetryLayer;
    use tower_test03::{assert_request_eq, mock};

    /// Policy shared by these tests: `attempts` retries, 1s initial backoff, 10s cap.
    fn test_policy(attempts: usize) -> FixedRetryPolicy<SvcRetryLogic> {
        FixedRetryPolicy::new(
            attempts,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
        )
    }

    #[tokio::test]
    async fn service_error_retry() {
        time::pause();
        trace_init();

        let layer = RetryLayer::new(test_policy(5));
        let (mut svc, mut handle) = mock::spawn_layer(layer);
        assert_ready_ok!(svc.poll_ready());

        let mut response = task::spawn(svc.call("hello"));

        // A retriable failure leaves the call pending while the backoff runs.
        assert_request_eq!(handle, "hello").send_error(Error(true));
        assert_pending!(response.poll());

        time::advance(Duration::from_secs(2)).await;
        assert_pending!(response.poll());

        // The retried request succeeds and its response reaches the caller.
        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(response.await.unwrap(), "world");
    }

    #[tokio::test]
    async fn service_error_no_retry() {
        trace_init();

        let layer = RetryLayer::new(test_policy(5));
        let (mut svc, mut handle) = mock::spawn_layer(layer);
        assert_ready_ok!(svc.poll_ready());

        let mut response = task::spawn(svc.call("hello"));

        // A non-retriable failure is surfaced to the caller immediately.
        assert_request_eq!(handle, "hello").send_error(Error(false));
        assert_ready_err!(response.poll());
    }

    #[tokio::test]
    async fn timeout_error() {
        time::pause();
        trace_init();

        let layer = RetryLayer::new(test_policy(5));
        let (mut svc, mut handle) = mock::spawn_layer(layer);
        assert_ready_ok!(svc.poll_ready());

        let mut response = task::spawn(svc.call("hello"));

        // Timeouts are always considered retriable.
        assert_request_eq!(handle, "hello").send_error(tower03::timeout::error::Elapsed::new());
        assert_pending!(response.poll());

        time::advance(Duration::from_secs(2)).await;
        assert_pending!(response.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(response.await.unwrap(), "world");
    }

    #[test]
    fn backoff_grows_to_max() {
        let mut policy = test_policy(10);
        assert_eq!(Duration::from_secs(1), policy.backoff());

        // Each step follows the Fibonacci sequence until the 10s cap is reached.
        for &secs in &[1u64, 2, 3, 5, 8, 10, 10] {
            policy = policy.advance();
            assert_eq!(Duration::from_secs(secs), policy.backoff());
        }
    }

    /// Retry logic whose verdict is carried inside the test `Error` itself.
    #[derive(Debug, Clone)]
    struct SvcRetryLogic;

    impl RetryLogic for SvcRetryLogic {
        type Error = Error;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            let Error(retriable) = *error;
            retriable
        }
    }

    /// Test error; the flag says whether it should be retried.
    #[derive(Debug)]
    struct Error(bool);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("error")
        }
    }

    impl std::error::Error for Error {}
}
329