1package request_test 2 3import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "testing" 9 "time" 10 11 "github.com/stretchr/testify/assert" 12 13 "github.com/aws/aws-sdk-go/aws" 14 "github.com/aws/aws-sdk-go/aws/awserr" 15 "github.com/aws/aws-sdk-go/aws/client" 16 "github.com/aws/aws-sdk-go/aws/request" 17 "github.com/aws/aws-sdk-go/awstesting" 18 "github.com/aws/aws-sdk-go/awstesting/unit" 19 "github.com/aws/aws-sdk-go/service/s3" 20) 21 22type mockClient struct { 23 *client.Client 24} 25type MockInput struct{} 26type MockOutput struct { 27 States []*MockState 28} 29type MockState struct { 30 State *string 31} 32 33func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) { 34 op := &request.Operation{ 35 Name: "Mock", 36 HTTPMethod: "POST", 37 HTTPPath: "/", 38 } 39 40 if input == nil { 41 input = &MockInput{} 42 } 43 44 output := &MockOutput{} 45 req := c.NewRequest(op, input, output) 46 req.Data = output 47 return req, output 48} 49 50func BuildNewMockRequest(c *mockClient, in *MockInput) func([]request.Option) (*request.Request, error) { 51 return func(opts []request.Option) (*request.Request, error) { 52 req, _ := c.MockRequest(in) 53 req.ApplyOptions(opts...) 54 return req, nil 55 } 56} 57 58func TestWaiterPathAll(t *testing.T) { 59 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 60 Region: aws.String("mock-region"), 61 })} 62 svc.Handlers.Send.Clear() // mock sending 63 svc.Handlers.Unmarshal.Clear() 64 svc.Handlers.UnmarshalMeta.Clear() 65 svc.Handlers.ValidateResponse.Clear() 66 67 reqNum := 0 68 resps := []*MockOutput{ 69 { // Request 1 70 States: []*MockState{ 71 {State: aws.String("pending")}, 72 {State: aws.String("pending")}, 73 }, 74 }, 75 { // Request 2 76 States: []*MockState{ 77 {State: aws.String("running")}, 78 {State: aws.String("pending")}, 79 }, 80 }, 81 { // Request 3 82 States: []*MockState{ 83 {State: aws.String("running")}, 84 {State: aws.String("running")}, 85 }, 86 }, 87 } 88 89 numBuiltReq := 0 90 svc.Handlers.Build.PushBack(func(r *request.Request) { 91 numBuiltReq++ 92 }) 93 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 94 if reqNum >= len(resps) { 95 assert.Fail(t, "too many polling requests made") 96 return 97 } 98 r.Data = resps[reqNum] 99 reqNum++ 100 }) 101 102 w := request.Waiter{ 103 MaxAttempts: 10, 104 Delay: request.ConstantWaiterDelay(0), 105 SleepWithContext: aws.SleepWithContext, 106 Acceptors: []request.WaiterAcceptor{ 107 { 108 State: request.SuccessWaiterState, 109 Matcher: request.PathAllWaiterMatch, 110 Argument: "States[].State", 111 Expected: "running", 112 }, 113 }, 114 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 115 } 116 117 err := w.WaitWithContext(aws.BackgroundContext()) 118 assert.NoError(t, err) 119 assert.Equal(t, 3, numBuiltReq) 120 assert.Equal(t, 3, reqNum) 121} 122 123func TestWaiterPath(t *testing.T) { 124 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 125 Region: aws.String("mock-region"), 126 })} 127 svc.Handlers.Send.Clear() // mock sending 128 svc.Handlers.Unmarshal.Clear() 129 svc.Handlers.UnmarshalMeta.Clear() 130 svc.Handlers.ValidateResponse.Clear() 131 132 reqNum := 0 133 resps := []*MockOutput{ 134 { // Request 1 135 States: []*MockState{ 136 {State: aws.String("pending")}, 137 {State: aws.String("pending")}, 138 }, 139 }, 140 { // Request 2 141 States: []*MockState{ 142 {State: aws.String("running")}, 143 {State: aws.String("pending")}, 144 }, 145 }, 146 { // Request 3 147 States: []*MockState{ 148 {State: aws.String("running")}, 149 {State: aws.String("running")}, 150 }, 151 }, 152 } 153 154 numBuiltReq := 0 155 svc.Handlers.Build.PushBack(func(r *request.Request) { 156 numBuiltReq++ 157 }) 158 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 159 if reqNum >= len(resps) { 160 assert.Fail(t, "too many polling requests made") 161 return 162 } 163 r.Data = resps[reqNum] 164 reqNum++ 165 }) 166 167 w := request.Waiter{ 168 MaxAttempts: 10, 169 Delay: request.ConstantWaiterDelay(0), 170 SleepWithContext: aws.SleepWithContext, 171 Acceptors: []request.WaiterAcceptor{ 172 { 173 State: request.SuccessWaiterState, 174 Matcher: request.PathWaiterMatch, 175 Argument: "States[].State", 176 Expected: "running", 177 }, 178 }, 179 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 180 } 181 182 err := w.WaitWithContext(aws.BackgroundContext()) 183 assert.NoError(t, err) 184 assert.Equal(t, 3, numBuiltReq) 185 assert.Equal(t, 3, reqNum) 186} 187 188func TestWaiterFailure(t *testing.T) { 189 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 190 Region: aws.String("mock-region"), 191 })} 192 svc.Handlers.Send.Clear() // mock sending 193 svc.Handlers.Unmarshal.Clear() 194 svc.Handlers.UnmarshalMeta.Clear() 195 svc.Handlers.ValidateResponse.Clear() 196 197 reqNum := 0 198 resps := []*MockOutput{ 199 { // Request 1 200 States: []*MockState{ 201 {State: aws.String("pending")}, 202 {State: aws.String("pending")}, 203 }, 204 }, 205 { // Request 2 206 States: []*MockState{ 207 {State: aws.String("running")}, 208 {State: aws.String("pending")}, 209 }, 210 }, 211 { // Request 3 212 States: []*MockState{ 213 {State: aws.String("running")}, 214 {State: aws.String("stopping")}, 215 }, 216 }, 217 } 218 219 numBuiltReq := 0 220 svc.Handlers.Build.PushBack(func(r *request.Request) { 221 numBuiltReq++ 222 }) 223 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 224 if reqNum >= len(resps) { 225 assert.Fail(t, "too many polling requests made") 226 return 227 } 228 r.Data = resps[reqNum] 229 reqNum++ 230 }) 231 232 w := request.Waiter{ 233 MaxAttempts: 10, 234 Delay: request.ConstantWaiterDelay(0), 235 SleepWithContext: aws.SleepWithContext, 236 Acceptors: []request.WaiterAcceptor{ 237 { 238 State: request.SuccessWaiterState, 239 Matcher: request.PathAllWaiterMatch, 240 Argument: "States[].State", 241 Expected: "running", 242 }, 243 { 244 State: request.FailureWaiterState, 245 Matcher: request.PathAnyWaiterMatch, 246 Argument: "States[].State", 247 Expected: "stopping", 248 }, 249 }, 250 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 251 } 252 253 err := w.WaitWithContext(aws.BackgroundContext()).(awserr.Error) 254 assert.Error(t, err) 255 assert.Equal(t, request.WaiterResourceNotReadyErrorCode, err.Code()) 256 assert.Equal(t, "failed waiting for successful resource state", err.Message()) 257 assert.Equal(t, 3, numBuiltReq) 258 assert.Equal(t, 3, reqNum) 259} 260 261func TestWaiterError(t *testing.T) { 262 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 263 Region: aws.String("mock-region"), 264 })} 265 svc.Handlers.Send.Clear() // mock sending 266 svc.Handlers.Unmarshal.Clear() 267 svc.Handlers.UnmarshalMeta.Clear() 268 svc.Handlers.UnmarshalError.Clear() 269 svc.Handlers.ValidateResponse.Clear() 270 271 reqNum := 0 272 resps := []*MockOutput{ 273 { // Request 1 274 States: []*MockState{ 275 {State: aws.String("pending")}, 276 {State: aws.String("pending")}, 277 }, 278 }, 279 { // Request 1, error case retry 280 }, 281 { // Request 2, error case failure 282 }, 283 { // Request 3 284 States: []*MockState{ 285 {State: aws.String("running")}, 286 {State: aws.String("running")}, 287 }, 288 }, 289 } 290 reqErrs := make([]error, len(resps)) 291 reqErrs[1] = awserr.New("MockException", "mock exception message", nil) 292 reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil) 293 294 numBuiltReq := 0 295 svc.Handlers.Build.PushBack(func(r *request.Request) { 296 numBuiltReq++ 297 }) 298 svc.Handlers.Send.PushBack(func(r *request.Request) { 299 code := 200 300 if reqNum == 1 { 301 code = 400 302 } 303 r.HTTPResponse = &http.Response{ 304 StatusCode: code, 305 Status: http.StatusText(code), 306 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 307 } 308 }) 309 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 310 if reqNum >= len(resps) { 311 assert.Fail(t, "too many polling requests made") 312 return 313 } 314 r.Data = resps[reqNum] 315 reqNum++ 316 }) 317 svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) { 318 // If there was an error unmarshal error will be called instead of unmarshal 319 // need to increment count here also 320 if err := reqErrs[reqNum]; err != nil { 321 r.Error = err 322 reqNum++ 323 } 324 }) 325 326 w := request.Waiter{ 327 MaxAttempts: 10, 328 Delay: request.ConstantWaiterDelay(0), 329 SleepWithContext: aws.SleepWithContext, 330 Acceptors: []request.WaiterAcceptor{ 331 { 332 State: request.SuccessWaiterState, 333 Matcher: request.PathAllWaiterMatch, 334 Argument: "States[].State", 335 Expected: "running", 336 }, 337 { 338 State: request.RetryWaiterState, 339 Matcher: request.ErrorWaiterMatch, 340 Argument: "", 341 Expected: "MockException", 342 }, 343 { 344 State: request.FailureWaiterState, 345 Matcher: request.ErrorWaiterMatch, 346 Argument: "", 347 Expected: "FailureException", 348 }, 349 }, 350 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 351 } 352 353 err := w.WaitWithContext(aws.BackgroundContext()) 354 if err == nil { 355 t.Fatalf("expected error, but did not get one") 356 } 357 aerr := err.(awserr.Error) 358 if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a { 359 t.Errorf("expect %q error code, got %q", e, a) 360 } 361 if e, a := 3, numBuiltReq; e != a { 362 t.Errorf("expect %d built requests got %d", e, a) 363 } 364 if e, a := 3, reqNum; e != a { 365 t.Errorf("expect %d reqNum got %d", e, a) 366 } 367} 368 369func TestWaiterStatus(t *testing.T) { 370 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 371 Region: aws.String("mock-region"), 372 })} 373 svc.Handlers.Send.Clear() // mock sending 374 svc.Handlers.Unmarshal.Clear() 375 svc.Handlers.UnmarshalMeta.Clear() 376 svc.Handlers.ValidateResponse.Clear() 377 378 reqNum := 0 379 svc.Handlers.Build.PushBack(func(r *request.Request) { 380 reqNum++ 381 }) 382 svc.Handlers.Send.PushBack(func(r *request.Request) { 383 code := 200 384 if reqNum == 3 { 385 code = 404 386 r.Error = awserr.New("NotFound", "Not Found", nil) 387 } 388 r.HTTPResponse = &http.Response{ 389 StatusCode: code, 390 Status: http.StatusText(code), 391 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 392 } 393 }) 394 395 w := request.Waiter{ 396 MaxAttempts: 10, 397 Delay: request.ConstantWaiterDelay(0), 398 SleepWithContext: aws.SleepWithContext, 399 Acceptors: []request.WaiterAcceptor{ 400 { 401 State: request.SuccessWaiterState, 402 Matcher: request.StatusWaiterMatch, 403 Argument: "", 404 Expected: 404, 405 }, 406 }, 407 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 408 } 409 410 err := w.WaitWithContext(aws.BackgroundContext()) 411 assert.NoError(t, err) 412 assert.Equal(t, 3, reqNum) 413} 414 415func TestWaiter_ApplyOptions(t *testing.T) { 416 w := request.Waiter{} 417 418 logger := aws.NewDefaultLogger() 419 420 w.ApplyOptions( 421 request.WithWaiterLogger(logger), 422 request.WithWaiterRequestOptions(request.WithLogLevel(aws.LogDebug)), 423 request.WithWaiterMaxAttempts(2), 424 request.WithWaiterDelay(request.ConstantWaiterDelay(5*time.Second)), 425 ) 426 427 if e, a := logger, w.Logger; e != a { 428 t.Errorf("expect logger to be set, and match, was not, %v, %v", e, a) 429 } 430 431 if len(w.RequestOptions) != 1 { 432 t.Fatalf("expect request options to be set to only a single option, %v", w.RequestOptions) 433 } 434 r := request.Request{} 435 r.ApplyOptions(w.RequestOptions...) 436 if e, a := aws.LogDebug, r.Config.LogLevel.Value(); e != a { 437 t.Errorf("expect %v loglevel got %v", e, a) 438 } 439 440 if e, a := 2, w.MaxAttempts; e != a { 441 t.Errorf("expect %d retryer max attempts, got %d", e, a) 442 } 443 if e, a := 5*time.Second, w.Delay(0); e != a { 444 t.Errorf("expect %d retryer delay, got %d", e, a) 445 } 446} 447 448func TestWaiter_WithContextCanceled(t *testing.T) { 449 c := awstesting.NewClient() 450 451 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 452 reqCount := 0 453 454 w := request.Waiter{ 455 Name: "TestWaiter", 456 MaxAttempts: 10, 457 Delay: request.ConstantWaiterDelay(1 * time.Millisecond), 458 SleepWithContext: aws.SleepWithContext, 459 Acceptors: []request.WaiterAcceptor{ 460 { 461 State: request.SuccessWaiterState, 462 Matcher: request.StatusWaiterMatch, 463 Expected: 200, 464 }, 465 }, 466 Logger: aws.NewDefaultLogger(), 467 NewRequest: func(opts []request.Option) (*request.Request, error) { 468 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 469 req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound} 470 req.Handlers.Clear() 471 req.Data = struct{}{} 472 req.Handlers.Send.PushBack(func(r *request.Request) { 473 if reqCount == 1 { 474 ctx.Error = fmt.Errorf("context canceled") 475 close(ctx.DoneCh) 476 } 477 reqCount++ 478 }) 479 480 return req, nil 481 }, 482 } 483 484 w.SleepWithContext = func(c aws.Context, delay time.Duration) error { 485 context := c.(*awstesting.FakeContext) 486 select { 487 case <-context.DoneCh: 488 return context.Err() 489 default: 490 return nil 491 } 492 } 493 494 err := w.WaitWithContext(ctx) 495 496 if err == nil { 497 t.Fatalf("expect waiter to be canceled.") 498 } 499 aerr := err.(awserr.Error) 500 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 501 t.Errorf("expect %q error code, got %q", e, a) 502 } 503 if e, a := 2, reqCount; e != a { 504 t.Errorf("expect %d requests, got %d", e, a) 505 } 506} 507 508func TestWaiter_WithContext(t *testing.T) { 509 c := awstesting.NewClient() 510 511 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 512 reqCount := 0 513 514 statuses := []int{http.StatusNotFound, http.StatusOK} 515 516 w := request.Waiter{ 517 Name: "TestWaiter", 518 MaxAttempts: 10, 519 Delay: request.ConstantWaiterDelay(1 * time.Millisecond), 520 SleepWithContext: aws.SleepWithContext, 521 Acceptors: []request.WaiterAcceptor{ 522 { 523 State: request.SuccessWaiterState, 524 Matcher: request.StatusWaiterMatch, 525 Expected: 200, 526 }, 527 }, 528 Logger: aws.NewDefaultLogger(), 529 NewRequest: func(opts []request.Option) (*request.Request, error) { 530 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 531 req.HTTPResponse = &http.Response{StatusCode: statuses[reqCount]} 532 req.Handlers.Clear() 533 req.Data = struct{}{} 534 req.Handlers.Send.PushBack(func(r *request.Request) { 535 if reqCount == 1 { 536 ctx.Error = fmt.Errorf("context canceled") 537 close(ctx.DoneCh) 538 } 539 reqCount++ 540 }) 541 542 return req, nil 543 }, 544 } 545 546 err := w.WaitWithContext(ctx) 547 548 if err != nil { 549 t.Fatalf("expect no error, got %v", err) 550 } 551 if e, a := 2, reqCount; e != a { 552 t.Errorf("expect %d requests, got %d", e, a) 553 } 554} 555 556func TestWaiter_AttemptsExpires(t *testing.T) { 557 c := awstesting.NewClient() 558 559 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 560 reqCount := 0 561 562 w := request.Waiter{ 563 Name: "TestWaiter", 564 MaxAttempts: 2, 565 Delay: request.ConstantWaiterDelay(1 * time.Millisecond), 566 SleepWithContext: aws.SleepWithContext, 567 Acceptors: []request.WaiterAcceptor{ 568 { 569 State: request.SuccessWaiterState, 570 Matcher: request.StatusWaiterMatch, 571 Expected: 200, 572 }, 573 }, 574 Logger: aws.NewDefaultLogger(), 575 NewRequest: func(opts []request.Option) (*request.Request, error) { 576 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 577 req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound} 578 req.Handlers.Clear() 579 req.Data = struct{}{} 580 req.Handlers.Send.PushBack(func(r *request.Request) { 581 reqCount++ 582 }) 583 584 return req, nil 585 }, 586 } 587 588 err := w.WaitWithContext(ctx) 589 590 if err == nil { 591 t.Fatalf("expect error did not get one") 592 } 593 aerr := err.(awserr.Error) 594 if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a { 595 t.Errorf("expect %q error code, got %q", e, a) 596 } 597 if e, a := 2, reqCount; e != a { 598 t.Errorf("expect %d requests, got %d", e, a) 599 } 600} 601 602func TestWaiterNilInput(t *testing.T) { 603 // Code generation doesn't have a great way to verify the code is correct 604 // other than being run via unit tests in the SDK. This should be fixed 605 // So code generation can be validated independently. 606 607 client := s3.New(unit.Session) 608 client.Handlers.Validate.Clear() 609 client.Handlers.Send.Clear() // mock sending 610 client.Handlers.Send.PushBack(func(r *request.Request) { 611 r.HTTPResponse = &http.Response{ 612 StatusCode: http.StatusOK, 613 } 614 }) 615 client.Handlers.Unmarshal.Clear() 616 client.Handlers.UnmarshalMeta.Clear() 617 client.Handlers.ValidateResponse.Clear() 618 client.Config.SleepDelay = func(dur time.Duration) {} 619 620 // Ensure waiters do not panic on nil input. It doesn't make sense to 621 // call a waiter without an input, Validation will 622 err := client.WaitUntilBucketExists(nil) 623 if err != nil { 624 t.Fatalf("expect no error, but got %v", err) 625 } 626} 627 628func TestWaiterWithContextNilInput(t *testing.T) { 629 // Code generation doesn't have a great way to verify the code is correct 630 // other than being run via unit tests in the SDK. This should be fixed 631 // So code generation can be validated independently. 632 633 client := s3.New(unit.Session) 634 client.Handlers.Validate.Clear() 635 client.Handlers.Send.Clear() // mock sending 636 client.Handlers.Send.PushBack(func(r *request.Request) { 637 r.HTTPResponse = &http.Response{ 638 StatusCode: http.StatusOK, 639 } 640 }) 641 client.Handlers.Unmarshal.Clear() 642 client.Handlers.UnmarshalMeta.Clear() 643 client.Handlers.ValidateResponse.Clear() 644 645 // Ensure waiters do not panic on nil input 646 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 647 err := client.WaitUntilBucketExistsWithContext(ctx, nil, 648 request.WithWaiterDelay(request.ConstantWaiterDelay(0)), 649 request.WithWaiterMaxAttempts(1), 650 ) 651 if err != nil { 652 t.Fatalf("expect no error, but got %v", err) 653 } 654} 655