1 use crate::Error; 2 use futures::FutureExt; 3 use std::{ 4 cmp, 5 future::Future, 6 pin::Pin, 7 task::{Context, Poll}, 8 time::Duration, 9 }; 10 use tokio::time::{delay_for, Delay}; 11 use tower03::{retry::Policy, timeout::error::Elapsed}; 12 13 pub enum RetryAction { 14 /// Indicate that this request should be retried with a reason 15 Retry(String), 16 /// Indicate that this request should not be retried with a reason 17 DontRetry(String), 18 /// Indicate that this request should not be retried but the request was successful 19 Successful, 20 } 21 22 pub trait RetryLogic: Clone { 23 type Error: std::error::Error + Send + Sync + 'static; 24 type Response; 25 is_retriable_error(&self, error: &Self::Error) -> bool26 fn is_retriable_error(&self, error: &Self::Error) -> bool; 27 should_retry_response(&self, _response: &Self::Response) -> RetryAction28 fn should_retry_response(&self, _response: &Self::Response) -> RetryAction { 29 // Treat the default as the request is successful 30 RetryAction::Successful 31 } 32 } 33 34 #[derive(Debug, Clone)] 35 pub struct FixedRetryPolicy<L> { 36 remaining_attempts: usize, 37 previous_duration: Duration, 38 current_duration: Duration, 39 max_duration: Duration, 40 logic: L, 41 } 42 43 pub struct RetryPolicyFuture<L: RetryLogic> { 44 delay: Delay, 45 policy: FixedRetryPolicy<L>, 46 } 47 48 impl<L: RetryLogic> FixedRetryPolicy<L> { new( remaining_attempts: usize, initial_backoff: Duration, max_duration: Duration, logic: L, ) -> Self49 pub fn new( 50 remaining_attempts: usize, 51 initial_backoff: Duration, 52 max_duration: Duration, 53 logic: L, 54 ) -> Self { 55 FixedRetryPolicy { 56 remaining_attempts, 57 previous_duration: Duration::from_secs(0), 58 current_duration: initial_backoff, 59 max_duration, 60 logic, 61 } 62 } 63 advance(&self) -> FixedRetryPolicy<L>64 fn advance(&self) -> FixedRetryPolicy<L> { 65 let next_duration: Duration = self.previous_duration + self.current_duration; 66 67 FixedRetryPolicy { 68 remaining_attempts: self.remaining_attempts - 1, 69 previous_duration: self.current_duration, 70 current_duration: cmp::min(next_duration, self.max_duration), 71 max_duration: self.max_duration, 72 logic: self.logic.clone(), 73 } 74 } 75 backoff(&self) -> Duration76 fn backoff(&self) -> Duration { 77 self.current_duration 78 } 79 build_retry(&self) -> RetryPolicyFuture<L>80 fn build_retry(&self) -> RetryPolicyFuture<L> { 81 let policy = self.advance(); 82 let delay = delay_for(self.backoff()); 83 84 debug!(message = "retrying request.", delay_ms = %self.backoff().as_millis()); 85 RetryPolicyFuture { delay, policy } 86 } 87 } 88 89 impl<Req, Res, L> Policy<Req, Res, Error> for FixedRetryPolicy<L> 90 where 91 Req: Clone, 92 L: RetryLogic<Response = Res>, 93 { 94 type Future = RetryPolicyFuture<L>; 95 retry(&self, _: &Req, result: Result<&Res, &Error>) -> Option<Self::Future>96 fn retry(&self, _: &Req, result: Result<&Res, &Error>) -> Option<Self::Future> { 97 match result { 98 Ok(response) => { 99 if self.remaining_attempts == 0 { 100 error!("retries exhausted"); 101 return None; 102 } 103 104 match self.logic.should_retry_response(response) { 105 RetryAction::Retry(reason) => { 106 warn!(message = "retrying after response.", %reason); 107 Some(self.build_retry()) 108 } 109 110 RetryAction::DontRetry(reason) => { 111 warn!(message = "request is not retryable; dropping the request.", %reason); 112 None 113 } 114 115 RetryAction::Successful => None, 116 } 117 } 118 Err(error) => { 119 if self.remaining_attempts == 0 { 120 error!(message = "retries exhausted.", %error); 121 return None; 122 } 123 124 if let Some(expected) = error.downcast_ref::<L::Error>() { 125 if self.logic.is_retriable_error(expected) { 126 warn!("retrying after error: {}", expected); 127 Some(self.build_retry()) 128 } else { 129 error!(message = "encountered non-retriable error.", %error); 130 None 131 } 132 } else if error.downcast_ref::<Elapsed>().is_some() { 133 warn!("request timedout."); 134 Some(self.build_retry()) 135 } else { 136 warn!(message = "unexpected error type.", %error); 137 None 138 } 139 } 140 } 141 } 142 clone_request(&self, request: &Req) -> Option<Req>143 fn clone_request(&self, request: &Req) -> Option<Req> { 144 Some(request.clone()) 145 } 146 } 147 148 // Safety: `L` is never pinned and we use no unsafe pin projections 149 // therefore this safe. 150 impl<L: RetryLogic> Unpin for RetryPolicyFuture<L> {} 151 152 impl<L: RetryLogic> Future for RetryPolicyFuture<L> { 153 type Output = FixedRetryPolicy<L>; 154 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>155 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 156 futures::ready!(self.delay.poll_unpin(cx)); 157 Poll::Ready(self.policy.clone()) 158 } 159 } 160 161 impl RetryAction { is_retryable(&self) -> bool162 pub fn is_retryable(&self) -> bool { 163 if let RetryAction::Retry(_) = &self { 164 true 165 } else { 166 false 167 } 168 } 169 is_not_retryable(&self) -> bool170 pub fn is_not_retryable(&self) -> bool { 171 if let RetryAction::DontRetry(_) = &self { 172 true 173 } else { 174 false 175 } 176 } 177 is_successful(&self) -> bool178 pub fn is_successful(&self) -> bool { 179 if let RetryAction::Successful = &self { 180 true 181 } else { 182 false 183 } 184 } 185 } 186 187 #[cfg(test)] 188 mod tests { 189 use super::*; 190 use crate::test_util::trace_init; 191 use std::{fmt, time::Duration}; 192 use tokio::time; 193 use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task}; 194 use tower03::retry::RetryLayer; 195 use tower_test03::{assert_request_eq, mock}; 196 197 #[tokio::test] service_error_retry()198 async fn service_error_retry() { 199 time::pause(); 200 trace_init(); 201 202 let policy = FixedRetryPolicy::new( 203 5, 204 Duration::from_secs(1), 205 Duration::from_secs(10), 206 SvcRetryLogic, 207 ); 208 209 let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy)); 210 211 assert_ready_ok!(svc.poll_ready()); 212 213 let fut = svc.call("hello"); 214 let mut fut = task::spawn(fut); 215 216 assert_request_eq!(handle, "hello").send_error(Error(true)); 217 218 assert_pending!(fut.poll()); 219 220 time::advance(Duration::from_secs(2)).await; 221 assert_pending!(fut.poll()); 222 223 assert_request_eq!(handle, "hello").send_response("world"); 224 assert_eq!(fut.await.unwrap(), "world"); 225 } 226 227 #[tokio::test] service_error_no_retry()228 async fn service_error_no_retry() { 229 trace_init(); 230 231 let policy = FixedRetryPolicy::new( 232 5, 233 Duration::from_secs(1), 234 Duration::from_secs(10), 235 SvcRetryLogic, 236 ); 237 238 let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy)); 239 240 assert_ready_ok!(svc.poll_ready()); 241 242 let mut fut = task::spawn(svc.call("hello")); 243 assert_request_eq!(handle, "hello").send_error(Error(false)); 244 assert_ready_err!(fut.poll()); 245 } 246 247 #[tokio::test] timeout_error()248 async fn timeout_error() { 249 time::pause(); 250 trace_init(); 251 252 let policy = FixedRetryPolicy::new( 253 5, 254 Duration::from_secs(1), 255 Duration::from_secs(10), 256 SvcRetryLogic, 257 ); 258 259 let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy)); 260 261 assert_ready_ok!(svc.poll_ready()); 262 263 let mut fut = task::spawn(svc.call("hello")); 264 assert_request_eq!(handle, "hello").send_error(tower03::timeout::error::Elapsed::new()); 265 assert_pending!(fut.poll()); 266 267 time::advance(Duration::from_secs(2)).await; 268 assert_pending!(fut.poll()); 269 270 assert_request_eq!(handle, "hello").send_response("world"); 271 assert_eq!(fut.await.unwrap(), "world"); 272 } 273 274 #[test] backoff_grows_to_max()275 fn backoff_grows_to_max() { 276 let mut policy = FixedRetryPolicy::new( 277 10, 278 Duration::from_secs(1), 279 Duration::from_secs(10), 280 SvcRetryLogic, 281 ); 282 assert_eq!(Duration::from_secs(1), policy.backoff()); 283 284 policy = policy.advance(); 285 assert_eq!(Duration::from_secs(1), policy.backoff()); 286 287 policy = policy.advance(); 288 assert_eq!(Duration::from_secs(2), policy.backoff()); 289 290 policy = policy.advance(); 291 assert_eq!(Duration::from_secs(3), policy.backoff()); 292 293 policy = policy.advance(); 294 assert_eq!(Duration::from_secs(5), policy.backoff()); 295 296 policy = policy.advance(); 297 assert_eq!(Duration::from_secs(8), policy.backoff()); 298 299 policy = policy.advance(); 300 assert_eq!(Duration::from_secs(10), policy.backoff()); 301 302 policy = policy.advance(); 303 assert_eq!(Duration::from_secs(10), policy.backoff()); 304 } 305 306 #[derive(Debug, Clone)] 307 struct SvcRetryLogic; 308 309 impl RetryLogic for SvcRetryLogic { 310 type Error = Error; 311 type Response = &'static str; 312 is_retriable_error(&self, error: &Self::Error) -> bool313 fn is_retriable_error(&self, error: &Self::Error) -> bool { 314 error.0 315 } 316 } 317 318 #[derive(Debug)] 319 struct Error(bool); 320 321 impl fmt::Display for Error { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result322 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 323 write!(f, "error") 324 } 325 } 326 327 impl std::error::Error for Error {} 328 } 329