1package retry 2 3import ( 4 "context" 5 "fmt" 6 "net/http" 7 "reflect" 8 "strconv" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/aws/aws-sdk-go-v2/aws" 14 "github.com/aws/aws-sdk-go-v2/internal/sdk" 15 "github.com/aws/smithy-go/middleware" 16 smithyhttp "github.com/aws/smithy-go/transport/http" 17 "github.com/google/go-cmp/cmp" 18) 19 20func TestMetricsHeaderMiddleware(t *testing.T) { 21 cases := []struct { 22 input middleware.FinalizeInput 23 ctx context.Context 24 expectedHeader string 25 expectedErr string 26 }{ 27 { 28 input: middleware.FinalizeInput{Request: &smithyhttp.Request{Request: &http.Request{Header: make(http.Header)}}}, 29 ctx: func() context.Context { 30 return setRetryMetadata(context.Background(), retryMetadata{ 31 AttemptNum: 0, 32 AttemptTime: time.Date(2020, 01, 02, 03, 04, 05, 0, time.UTC), 33 MaxAttempts: 5, 34 AttemptClockSkew: 0, 35 }) 36 }(), 37 expectedHeader: "attempt=0; max=5", 38 }, 39 { 40 input: middleware.FinalizeInput{Request: &smithyhttp.Request{Request: &http.Request{Header: make(http.Header)}}}, 41 ctx: func() context.Context { 42 attemptTime := time.Date(2020, 01, 02, 03, 04, 05, 0, time.UTC) 43 ctx, cancel := context.WithDeadline(context.Background(), attemptTime.Add(time.Minute)) 44 defer cancel() 45 return setRetryMetadata(ctx, retryMetadata{ 46 AttemptNum: 1, 47 AttemptTime: attemptTime, 48 MaxAttempts: 5, 49 AttemptClockSkew: time.Second * 1, 50 }) 51 }(), 52 expectedHeader: "attempt=1; max=5; ttl=20200102T030506Z", 53 }, 54 { 55 ctx: func() context.Context { 56 return setRetryMetadata(context.Background(), retryMetadata{}) 57 }(), 58 expectedErr: "unknown transport type", 59 }, 60 } 61 62 retryMiddleware := MetricsHeader{} 63 for i, tt := range cases { 64 t.Run(strconv.Itoa(i), func(t *testing.T) { 65 ctx := tt.ctx 66 _, _, err := retryMiddleware.HandleFinalize(ctx, tt.input, middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) ( 67 out middleware.FinalizeOutput, metadata middleware.Metadata, err error, 68 ) { 69 req := in.Request.(*smithyhttp.Request) 70 71 if e, a := tt.expectedHeader, req.Header.Get("amz-sdk-request"); e != a { 72 t.Errorf("expected %v, got %v", e, a) 73 } 74 75 return out, metadata, err 76 })) 77 if err != nil && len(tt.expectedErr) == 0 { 78 t.Fatalf("expected no error, got %q", err) 79 } else if err != nil && len(tt.expectedErr) != 0 { 80 if e, a := tt.expectedErr, err.Error(); !strings.Contains(a, e) { 81 t.Fatalf("expected %q, got %q", e, a) 82 } 83 } else if err == nil && len(tt.expectedErr) != 0 { 84 t.Fatalf("expected error, got nil") 85 } 86 }) 87 } 88} 89 90type retryProvider struct { 91 Retryer aws.Retryer 92} 93 94func (t retryProvider) GetRetryer() aws.Retryer { 95 return t.Retryer 96} 97 98type mockHandler func(context.Context, interface{}) (interface{}, middleware.Metadata, error) 99 100func (m mockHandler) Handle(ctx context.Context, input interface{}) (output interface{}, metadata middleware.Metadata, err error) { 101 return m(ctx, input) 102} 103 104func (m mockHandler) ID() string { 105 return fmt.Sprintf("%T", m) 106} 107 108type testRequest struct { 109 DisableRewind bool 110} 111 112func (r testRequest) RewindStream() error { 113 if r.DisableRewind { 114 return fmt.Errorf("rewind disabled") 115 } 116 return nil 117} 118 119type mockRetryableError struct{ b bool } 120 121func (m mockRetryableError) RetryableError() bool { return m.b } 122func (m mockRetryableError) Error() string { 123 return fmt.Sprintf("mock retryable %t", m.b) 124} 125 126func TestAttemptMiddleware(t *testing.T) { 127 restoreSleep := sdk.TestingUseNopSleep() 128 defer restoreSleep() 129 130 sdkTime := sdk.NowTime 131 defer func() { 132 sdk.NowTime = sdkTime 133 }() 134 135 cases := map[string]struct { 136 Request testRequest 137 Next func(retries *[]retryMetadata) middleware.FinalizeHandler 138 Expect []retryMetadata 139 Err error 140 ExpectResults AttemptResults 141 }{ 142 "no error, no response in a single attempt": { 143 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 144 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 145 m, ok := getRetryMetadata(ctx) 146 if ok { 147 *retries = append(*retries, m) 148 } 149 return out, metadata, err 150 }) 151 }, 152 Expect: []retryMetadata{ 153 { 154 AttemptNum: 1, 155 AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC), 156 MaxAttempts: 3, 157 }, 158 }, 159 ExpectResults: AttemptResults{Results: []AttemptResult{ 160 {}, 161 }}, 162 }, 163 "no error in a single attempt": { 164 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 165 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 166 m, ok := getRetryMetadata(ctx) 167 if ok { 168 *retries = append(*retries, m) 169 } 170 setMockRawResponse(&metadata, "mockResponse") 171 return out, metadata, err 172 }) 173 }, 174 Expect: []retryMetadata{ 175 { 176 AttemptNum: 1, 177 AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC), 178 MaxAttempts: 3, 179 }, 180 }, 181 ExpectResults: AttemptResults{Results: []AttemptResult{ 182 { 183 ResponseMetadata: func() middleware.Metadata { 184 m := middleware.Metadata{} 185 setMockRawResponse(&m, "mockResponse") 186 return m 187 }(), 188 }, 189 }}, 190 }, 191 "retries errors": { 192 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 193 num := 0 194 reqsErrs := []error{ 195 mockRetryableError{b: true}, 196 mockRetryableError{b: true}, 197 nil, 198 } 199 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 200 m, ok := getRetryMetadata(ctx) 201 if ok { 202 *retries = append(*retries, m) 203 } 204 if num >= len(reqsErrs) { 205 err = fmt.Errorf("more requests then expected") 206 } else { 207 err = reqsErrs[num] 208 num++ 209 } 210 return out, metadata, err 211 }) 212 }, 213 Expect: []retryMetadata{ 214 { 215 AttemptNum: 1, 216 AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC), 217 MaxAttempts: 3, 218 }, 219 { 220 AttemptNum: 2, 221 AttemptTime: time.Date(2020, 8, 19, 10, 21, 30, 0, time.UTC), 222 MaxAttempts: 3, 223 }, 224 { 225 AttemptNum: 3, 226 AttemptTime: time.Date(2020, 8, 19, 10, 22, 30, 0, time.UTC), 227 MaxAttempts: 3, 228 }, 229 }, 230 ExpectResults: AttemptResults{Results: []AttemptResult{ 231 { 232 Err: mockRetryableError{b: true}, 233 Retryable: true, 234 Retried: true, 235 }, 236 { 237 Err: mockRetryableError{b: true}, 238 Retryable: true, 239 Retried: true, 240 }, 241 {}, 242 }}, 243 }, 244 "stops after max attempts": { 245 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 246 num := 0 247 reqsErrs := []error{ 248 mockRetryableError{b: true}, 249 mockRetryableError{b: true}, 250 mockRetryableError{b: true}, 251 } 252 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 253 if num >= len(reqsErrs) { 254 err = fmt.Errorf("more requests then expected") 255 } else { 256 err = reqsErrs[num] 257 num++ 258 } 259 return out, metadata, err 260 }) 261 }, 262 Err: fmt.Errorf("exceeded maximum number of attempts"), 263 ExpectResults: AttemptResults{Results: []AttemptResult{ 264 { 265 Err: mockRetryableError{b: true}, 266 Retryable: true, 267 Retried: true, 268 }, 269 { 270 Err: mockRetryableError{b: true}, 271 Retryable: true, 272 Retried: true, 273 }, 274 { 275 Err: &MaxAttemptsError{Attempt: 3, Err: mockRetryableError{b: true}}, 276 Retryable: true, 277 }, 278 }}, 279 }, 280 "stops on rewind error": { 281 Request: testRequest{DisableRewind: true}, 282 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 283 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 284 m, ok := getRetryMetadata(ctx) 285 if ok { 286 *retries = append(*retries, m) 287 } 288 return out, metadata, mockRetryableError{b: true} 289 }) 290 }, 291 Expect: []retryMetadata{ 292 { 293 AttemptNum: 1, 294 AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC), 295 MaxAttempts: 3, 296 }, 297 }, 298 Err: fmt.Errorf("failed to rewind transport stream for retry"), 299 ExpectResults: AttemptResults{Results: []AttemptResult{ 300 { 301 Err: mockRetryableError{b: true}, 302 Retryable: true, 303 Retried: true, 304 }, 305 { 306 Err: fmt.Errorf( 307 "failed to rewind transport stream for retry, %w", 308 fmt.Errorf("rewind disabled"), 309 ), 310 }, 311 }}, 312 }, 313 "stops on non-retryable errors": { 314 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 315 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 316 m, ok := getRetryMetadata(ctx) 317 if ok { 318 *retries = append(*retries, m) 319 } 320 return out, metadata, fmt.Errorf("some error") 321 }) 322 }, 323 Expect: []retryMetadata{ 324 { 325 AttemptNum: 1, 326 AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC), 327 MaxAttempts: 3, 328 }, 329 }, 330 Err: fmt.Errorf("some error"), 331 ExpectResults: AttemptResults{Results: []AttemptResult{ 332 { 333 Err: fmt.Errorf("some error"), 334 }, 335 }}, 336 }, 337 "nested metadata and valid response": { 338 Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { 339 num := 0 340 reqsErrs := []error{ 341 mockRetryableError{b: true}, 342 nil, 343 } 344 return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { 345 m, ok := getRetryMetadata(ctx) 346 if ok { 347 *retries = append(*retries, m) 348 } 349 if num >= len(reqsErrs) { 350 err = fmt.Errorf("more requests then expected") 351 } else { 352 err = reqsErrs[num] 353 num++ 354 } 355 356 if err != nil { 357 metadata.Set("testKey", "testValue") 358 } else { 359 setMockRawResponse(&metadata, "mockResponse") 360 } 361 return out, metadata, err 362 }) 363 }, 364 Expect: []retryMetadata{ 365 { 366 AttemptNum: 1, 367 AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC), 368 MaxAttempts: 3, 369 }, 370 { 371 AttemptNum: 2, 372 AttemptTime: time.Date(2020, 8, 19, 10, 21, 30, 0, time.UTC), 373 MaxAttempts: 3, 374 }, 375 }, 376 ExpectResults: AttemptResults{Results: []AttemptResult{ 377 { 378 Err: mockRetryableError{b: true}, 379 Retryable: true, 380 Retried: true, 381 ResponseMetadata: func() middleware.Metadata { 382 m := middleware.Metadata{} 383 m.Set("testKey", "testValue") 384 return m 385 }(), 386 }, 387 { 388 ResponseMetadata: func() middleware.Metadata { 389 m := middleware.Metadata{} 390 setMockRawResponse(&m, "mockResponse") 391 return m 392 }(), 393 }, 394 }}, 395 }, 396 } 397 398 for name, tt := range cases { 399 t.Run(name, func(t *testing.T) { 400 sdk.NowTime = func() func() time.Time { 401 base := time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC) 402 num := 0 403 return func() time.Time { 404 t := base.Add(time.Minute * time.Duration(num)) 405 num++ 406 return t 407 } 408 }() 409 410 am := NewAttemptMiddleware(NewStandard(func(s *StandardOptions) { 411 s.MaxAttempts = 3 412 }), func(i interface{}) interface{} { 413 return i 414 }) 415 416 var recorded []retryMetadata 417 _, metadata, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: tt.Request}, tt.Next(&recorded)) 418 if err != nil && tt.Err == nil { 419 t.Errorf("expect no error, got %v", err) 420 } else if err == nil && tt.Err != nil { 421 t.Errorf("expect error, got none") 422 } else if err != nil && tt.Err != nil { 423 if !strings.Contains(err.Error(), tt.Err.Error()) { 424 t.Errorf("expect %v, got %v", tt.Err, err) 425 } 426 } 427 if diff := cmp.Diff(recorded, tt.Expect); len(diff) > 0 { 428 t.Error(diff) 429 } 430 431 attemptResults, ok := GetAttemptResults(metadata) 432 if !ok { 433 t.Fatalf("expected metadata to contain attempt results, got none") 434 } 435 if e, a := tt.ExpectResults, attemptResults; !reflect.DeepEqual(e, a) { 436 t.Fatalf("expected %v, got %v", e, a) 437 } 438 }) 439 } 440} 441 442// mockRawResponseKey is used to test the behavior when response metadata is 443// nested within the attempt request. 444type mockRawResponseKey struct{} 445 446func setMockRawResponse(m *middleware.Metadata, v interface{}) { 447 m.Set(mockRawResponseKey{}, v) 448} 449