1package request_test 2 3import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "testing" 9 "time" 10 11 "github.com/aws/aws-sdk-go/aws" 12 "github.com/aws/aws-sdk-go/aws/awserr" 13 "github.com/aws/aws-sdk-go/aws/client" 14 "github.com/aws/aws-sdk-go/aws/request" 15 "github.com/aws/aws-sdk-go/awstesting" 16 "github.com/aws/aws-sdk-go/awstesting/unit" 17 "github.com/aws/aws-sdk-go/service/s3" 18) 19 20type mockClient struct { 21 *client.Client 22} 23type MockInput struct{} 24type MockOutput struct { 25 States []*MockState 26} 27type MockState struct { 28 State *string 29} 30 31func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) { 32 op := &request.Operation{ 33 Name: "Mock", 34 HTTPMethod: "POST", 35 HTTPPath: "/", 36 } 37 38 if input == nil { 39 input = &MockInput{} 40 } 41 42 output := &MockOutput{} 43 req := c.NewRequest(op, input, output) 44 req.Data = output 45 return req, output 46} 47 48func BuildNewMockRequest(c *mockClient, in *MockInput) func([]request.Option) (*request.Request, error) { 49 return func(opts []request.Option) (*request.Request, error) { 50 req, _ := c.MockRequest(in) 51 req.ApplyOptions(opts...) 52 return req, nil 53 } 54} 55 56func TestWaiterPathAll(t *testing.T) { 57 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 58 Region: aws.String("mock-region"), 59 })} 60 svc.Handlers.Send.Clear() // mock sending 61 svc.Handlers.Unmarshal.Clear() 62 svc.Handlers.UnmarshalMeta.Clear() 63 svc.Handlers.ValidateResponse.Clear() 64 65 reqNum := 0 66 resps := []*MockOutput{ 67 { // Request 1 68 States: []*MockState{ 69 {State: aws.String("pending")}, 70 {State: aws.String("pending")}, 71 }, 72 }, 73 { // Request 2 74 States: []*MockState{ 75 {State: aws.String("running")}, 76 {State: aws.String("pending")}, 77 }, 78 }, 79 { // Request 3 80 States: []*MockState{ 81 {State: aws.String("running")}, 82 {State: aws.String("running")}, 83 }, 84 }, 85 } 86 87 numBuiltReq := 0 88 svc.Handlers.Build.PushBack(func(r *request.Request) { 89 numBuiltReq++ 90 }) 91 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 92 if reqNum >= len(resps) { 93 t.Errorf("too many polling requests made") 94 return 95 } 96 r.Data = resps[reqNum] 97 reqNum++ 98 }) 99 100 w := request.Waiter{ 101 MaxAttempts: 10, 102 Delay: request.ConstantWaiterDelay(0), 103 SleepWithContext: aws.SleepWithContext, 104 Acceptors: []request.WaiterAcceptor{ 105 { 106 State: request.SuccessWaiterState, 107 Matcher: request.PathAllWaiterMatch, 108 Argument: "States[].State", 109 Expected: "running", 110 }, 111 }, 112 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 113 } 114 115 err := w.WaitWithContext(aws.BackgroundContext()) 116 if err != nil { 117 t.Errorf("expect nil, %v", err) 118 } 119 if e, a := 3, numBuiltReq; e != a { 120 t.Errorf("expect %v, got %v", e, a) 121 } 122 if e, a := 3, reqNum; e != a { 123 t.Errorf("expect %v, got %v", e, a) 124 } 125} 126 127func TestWaiterPath(t *testing.T) { 128 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 129 Region: aws.String("mock-region"), 130 })} 131 svc.Handlers.Send.Clear() // mock sending 132 svc.Handlers.Unmarshal.Clear() 133 svc.Handlers.UnmarshalMeta.Clear() 134 svc.Handlers.ValidateResponse.Clear() 135 136 reqNum := 0 137 resps := []*MockOutput{ 138 { // Request 1 139 States: []*MockState{ 140 {State: aws.String("pending")}, 141 {State: aws.String("pending")}, 142 }, 143 }, 144 { // Request 2 145 States: []*MockState{ 146 {State: aws.String("running")}, 147 {State: aws.String("pending")}, 148 }, 149 }, 150 { // Request 3 151 States: []*MockState{ 152 {State: aws.String("running")}, 153 {State: aws.String("running")}, 154 }, 155 }, 156 } 157 158 numBuiltReq := 0 159 svc.Handlers.Build.PushBack(func(r *request.Request) { 160 numBuiltReq++ 161 }) 162 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 163 if reqNum >= len(resps) { 164 t.Errorf("too many polling requests made") 165 return 166 } 167 r.Data = resps[reqNum] 168 reqNum++ 169 }) 170 171 w := request.Waiter{ 172 MaxAttempts: 10, 173 Delay: request.ConstantWaiterDelay(0), 174 SleepWithContext: aws.SleepWithContext, 175 Acceptors: []request.WaiterAcceptor{ 176 { 177 State: request.SuccessWaiterState, 178 Matcher: request.PathWaiterMatch, 179 Argument: "States[].State", 180 Expected: "running", 181 }, 182 }, 183 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 184 } 185 186 err := w.WaitWithContext(aws.BackgroundContext()) 187 if err != nil { 188 t.Errorf("expect nil, %v", err) 189 } 190 if e, a := 3, numBuiltReq; e != a { 191 t.Errorf("expect %v, got %v", e, a) 192 } 193 if e, a := 3, reqNum; e != a { 194 t.Errorf("expect %v, got %v", e, a) 195 } 196} 197 198func TestWaiterFailure(t *testing.T) { 199 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 200 Region: aws.String("mock-region"), 201 })} 202 svc.Handlers.Send.Clear() // mock sending 203 svc.Handlers.Unmarshal.Clear() 204 svc.Handlers.UnmarshalMeta.Clear() 205 svc.Handlers.ValidateResponse.Clear() 206 207 reqNum := 0 208 resps := []*MockOutput{ 209 { // Request 1 210 States: []*MockState{ 211 {State: aws.String("pending")}, 212 {State: aws.String("pending")}, 213 }, 214 }, 215 { // Request 2 216 States: []*MockState{ 217 {State: aws.String("running")}, 218 {State: aws.String("pending")}, 219 }, 220 }, 221 { // Request 3 222 States: []*MockState{ 223 {State: aws.String("running")}, 224 {State: aws.String("stopping")}, 225 }, 226 }, 227 } 228 229 numBuiltReq := 0 230 svc.Handlers.Build.PushBack(func(r *request.Request) { 231 numBuiltReq++ 232 }) 233 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 234 if reqNum >= len(resps) { 235 t.Errorf("too many polling requests made") 236 return 237 } 238 r.Data = resps[reqNum] 239 reqNum++ 240 }) 241 242 w := request.Waiter{ 243 MaxAttempts: 10, 244 Delay: request.ConstantWaiterDelay(0), 245 SleepWithContext: aws.SleepWithContext, 246 Acceptors: []request.WaiterAcceptor{ 247 { 248 State: request.SuccessWaiterState, 249 Matcher: request.PathAllWaiterMatch, 250 Argument: "States[].State", 251 Expected: "running", 252 }, 253 { 254 State: request.FailureWaiterState, 255 Matcher: request.PathAnyWaiterMatch, 256 Argument: "States[].State", 257 Expected: "stopping", 258 }, 259 }, 260 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 261 } 262 263 err := w.WaitWithContext(aws.BackgroundContext()).(awserr.Error) 264 if err == nil { 265 t.Errorf("expect error") 266 } 267 if e, a := request.WaiterResourceNotReadyErrorCode, err.Code(); e != a { 268 t.Errorf("expect %v, got %v", e, a) 269 } 270 if e, a := "failed waiting for successful resource state", err.Message(); e != a { 271 t.Errorf("expect %v, got %v", e, a) 272 } 273 if e, a := 3, numBuiltReq; e != a { 274 t.Errorf("expect %v, got %v", e, a) 275 } 276 if e, a := 3, reqNum; e != a { 277 t.Errorf("expect %v, got %v", e, a) 278 } 279} 280 281func TestWaiterError(t *testing.T) { 282 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 283 Region: aws.String("mock-region"), 284 })} 285 svc.Handlers.Send.Clear() // mock sending 286 svc.Handlers.Unmarshal.Clear() 287 svc.Handlers.UnmarshalMeta.Clear() 288 svc.Handlers.UnmarshalError.Clear() 289 svc.Handlers.ValidateResponse.Clear() 290 291 reqNum := 0 292 resps := []*MockOutput{ 293 { // Request 1 294 States: []*MockState{ 295 {State: aws.String("pending")}, 296 {State: aws.String("pending")}, 297 }, 298 }, 299 { // Request 1, error case retry 300 }, 301 { // Request 2, error case failure 302 }, 303 { // Request 3 304 States: []*MockState{ 305 {State: aws.String("running")}, 306 {State: aws.String("running")}, 307 }, 308 }, 309 } 310 reqErrs := make([]error, len(resps)) 311 reqErrs[1] = awserr.New("MockException", "mock exception message", nil) 312 reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil) 313 314 numBuiltReq := 0 315 svc.Handlers.Build.PushBack(func(r *request.Request) { 316 numBuiltReq++ 317 }) 318 svc.Handlers.Send.PushBack(func(r *request.Request) { 319 code := 200 320 if reqNum == 1 { 321 code = 400 322 } 323 r.HTTPResponse = &http.Response{ 324 StatusCode: code, 325 Status: http.StatusText(code), 326 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 327 } 328 }) 329 svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { 330 if reqNum >= len(resps) { 331 t.Errorf("too many polling requests made") 332 return 333 } 334 r.Data = resps[reqNum] 335 reqNum++ 336 }) 337 svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) { 338 // If there was an error unmarshal error will be called instead of unmarshal 339 // need to increment count here also 340 if err := reqErrs[reqNum]; err != nil { 341 r.Error = err 342 reqNum++ 343 } 344 }) 345 346 w := request.Waiter{ 347 MaxAttempts: 10, 348 Delay: request.ConstantWaiterDelay(0), 349 SleepWithContext: aws.SleepWithContext, 350 Acceptors: []request.WaiterAcceptor{ 351 { 352 State: request.SuccessWaiterState, 353 Matcher: request.PathAllWaiterMatch, 354 Argument: "States[].State", 355 Expected: "running", 356 }, 357 { 358 State: request.RetryWaiterState, 359 Matcher: request.ErrorWaiterMatch, 360 Argument: "", 361 Expected: "MockException", 362 }, 363 { 364 State: request.FailureWaiterState, 365 Matcher: request.ErrorWaiterMatch, 366 Argument: "", 367 Expected: "FailureException", 368 }, 369 }, 370 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 371 } 372 373 err := w.WaitWithContext(aws.BackgroundContext()) 374 if err == nil { 375 t.Fatalf("expected error, but did not get one") 376 } 377 aerr := err.(awserr.Error) 378 if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a { 379 t.Errorf("expect %q error code, got %q", e, a) 380 } 381 if e, a := 3, numBuiltReq; e != a { 382 t.Errorf("expect %d built requests got %d", e, a) 383 } 384 if e, a := 3, reqNum; e != a { 385 t.Errorf("expect %d reqNum got %d", e, a) 386 } 387} 388 389func TestWaiterStatus(t *testing.T) { 390 svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ 391 Region: aws.String("mock-region"), 392 })} 393 svc.Handlers.Send.Clear() // mock sending 394 svc.Handlers.Unmarshal.Clear() 395 svc.Handlers.UnmarshalMeta.Clear() 396 svc.Handlers.ValidateResponse.Clear() 397 398 reqNum := 0 399 svc.Handlers.Build.PushBack(func(r *request.Request) { 400 reqNum++ 401 }) 402 svc.Handlers.Send.PushBack(func(r *request.Request) { 403 code := 200 404 if reqNum == 3 { 405 code = 404 406 r.Error = awserr.New("NotFound", "Not Found", nil) 407 } 408 r.HTTPResponse = &http.Response{ 409 StatusCode: code, 410 Status: http.StatusText(code), 411 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 412 } 413 }) 414 415 w := request.Waiter{ 416 MaxAttempts: 10, 417 Delay: request.ConstantWaiterDelay(0), 418 SleepWithContext: aws.SleepWithContext, 419 Acceptors: []request.WaiterAcceptor{ 420 { 421 State: request.SuccessWaiterState, 422 Matcher: request.StatusWaiterMatch, 423 Argument: "", 424 Expected: 404, 425 }, 426 }, 427 NewRequest: BuildNewMockRequest(svc, &MockInput{}), 428 } 429 430 err := w.WaitWithContext(aws.BackgroundContext()) 431 if err != nil { 432 t.Errorf("expect nil, %v", err) 433 } 434 if e, a := 3, reqNum; e != a { 435 t.Errorf("expect %v, got %v", e, a) 436 } 437} 438 439func TestWaiter_ApplyOptions(t *testing.T) { 440 w := request.Waiter{} 441 442 logger := aws.NewDefaultLogger() 443 444 w.ApplyOptions( 445 request.WithWaiterLogger(logger), 446 request.WithWaiterRequestOptions(request.WithLogLevel(aws.LogDebug)), 447 request.WithWaiterMaxAttempts(2), 448 request.WithWaiterDelay(request.ConstantWaiterDelay(5*time.Second)), 449 ) 450 451 if e, a := logger, w.Logger; e != a { 452 t.Errorf("expect logger to be set, and match, was not, %v, %v", e, a) 453 } 454 455 if len(w.RequestOptions) != 1 { 456 t.Fatalf("expect request options to be set to only a single option, %v", w.RequestOptions) 457 } 458 r := request.Request{} 459 r.ApplyOptions(w.RequestOptions...) 460 if e, a := aws.LogDebug, r.Config.LogLevel.Value(); e != a { 461 t.Errorf("expect %v loglevel got %v", e, a) 462 } 463 464 if e, a := 2, w.MaxAttempts; e != a { 465 t.Errorf("expect %d retryer max attempts, got %d", e, a) 466 } 467 if e, a := 5*time.Second, w.Delay(0); e != a { 468 t.Errorf("expect %d retryer delay, got %d", e, a) 469 } 470} 471 472func TestWaiter_WithContextCanceled(t *testing.T) { 473 c := awstesting.NewClient() 474 475 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 476 reqCount := 0 477 478 w := request.Waiter{ 479 Name: "TestWaiter", 480 MaxAttempts: 10, 481 Delay: request.ConstantWaiterDelay(1 * time.Millisecond), 482 SleepWithContext: aws.SleepWithContext, 483 Acceptors: []request.WaiterAcceptor{ 484 { 485 State: request.SuccessWaiterState, 486 Matcher: request.StatusWaiterMatch, 487 Expected: 200, 488 }, 489 }, 490 Logger: aws.NewDefaultLogger(), 491 NewRequest: func(opts []request.Option) (*request.Request, error) { 492 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 493 req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound} 494 req.Handlers.Clear() 495 req.Data = struct{}{} 496 req.Handlers.Send.PushBack(func(r *request.Request) { 497 if reqCount == 1 { 498 ctx.Error = fmt.Errorf("context canceled") 499 close(ctx.DoneCh) 500 } 501 reqCount++ 502 }) 503 504 return req, nil 505 }, 506 } 507 508 w.SleepWithContext = func(c aws.Context, delay time.Duration) error { 509 context := c.(*awstesting.FakeContext) 510 select { 511 case <-context.DoneCh: 512 return context.Err() 513 default: 514 return nil 515 } 516 } 517 518 err := w.WaitWithContext(ctx) 519 520 if err == nil { 521 t.Fatalf("expect waiter to be canceled.") 522 } 523 aerr := err.(awserr.Error) 524 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 525 t.Errorf("expect %q error code, got %q", e, a) 526 } 527 if e, a := 2, reqCount; e != a { 528 t.Errorf("expect %d requests, got %d", e, a) 529 } 530} 531 532func TestWaiter_WithContext(t *testing.T) { 533 c := awstesting.NewClient() 534 535 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 536 reqCount := 0 537 538 statuses := []int{http.StatusNotFound, http.StatusOK} 539 540 w := request.Waiter{ 541 Name: "TestWaiter", 542 MaxAttempts: 10, 543 Delay: request.ConstantWaiterDelay(1 * time.Millisecond), 544 SleepWithContext: aws.SleepWithContext, 545 Acceptors: []request.WaiterAcceptor{ 546 { 547 State: request.SuccessWaiterState, 548 Matcher: request.StatusWaiterMatch, 549 Expected: 200, 550 }, 551 }, 552 Logger: aws.NewDefaultLogger(), 553 NewRequest: func(opts []request.Option) (*request.Request, error) { 554 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 555 req.HTTPResponse = &http.Response{StatusCode: statuses[reqCount]} 556 req.Handlers.Clear() 557 req.Data = struct{}{} 558 req.Handlers.Send.PushBack(func(r *request.Request) { 559 if reqCount == 1 { 560 ctx.Error = fmt.Errorf("context canceled") 561 close(ctx.DoneCh) 562 } 563 reqCount++ 564 }) 565 566 return req, nil 567 }, 568 } 569 570 err := w.WaitWithContext(ctx) 571 572 if err != nil { 573 t.Fatalf("expect no error, got %v", err) 574 } 575 if e, a := 2, reqCount; e != a { 576 t.Errorf("expect %d requests, got %d", e, a) 577 } 578} 579 580func TestWaiter_AttemptsExpires(t *testing.T) { 581 c := awstesting.NewClient() 582 583 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 584 reqCount := 0 585 586 w := request.Waiter{ 587 Name: "TestWaiter", 588 MaxAttempts: 2, 589 Delay: request.ConstantWaiterDelay(1 * time.Millisecond), 590 SleepWithContext: aws.SleepWithContext, 591 Acceptors: []request.WaiterAcceptor{ 592 { 593 State: request.SuccessWaiterState, 594 Matcher: request.StatusWaiterMatch, 595 Expected: 200, 596 }, 597 }, 598 Logger: aws.NewDefaultLogger(), 599 NewRequest: func(opts []request.Option) (*request.Request, error) { 600 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 601 req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound} 602 req.Handlers.Clear() 603 req.Data = struct{}{} 604 req.Handlers.Send.PushBack(func(r *request.Request) { 605 reqCount++ 606 }) 607 608 return req, nil 609 }, 610 } 611 612 err := w.WaitWithContext(ctx) 613 614 if err == nil { 615 t.Fatalf("expect error did not get one") 616 } 617 aerr := err.(awserr.Error) 618 if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a { 619 t.Errorf("expect %q error code, got %q", e, a) 620 } 621 if e, a := 2, reqCount; e != a { 622 t.Errorf("expect %d requests, got %d", e, a) 623 } 624} 625 626func TestWaiterNilInput(t *testing.T) { 627 // Code generation doesn't have a great way to verify the code is correct 628 // other than being run via unit tests in the SDK. This should be fixed 629 // So code generation can be validated independently. 630 631 client := s3.New(unit.Session) 632 client.Handlers.Validate.Clear() 633 client.Handlers.Send.Clear() // mock sending 634 client.Handlers.Send.PushBack(func(r *request.Request) { 635 r.HTTPResponse = &http.Response{ 636 StatusCode: http.StatusOK, 637 } 638 }) 639 client.Handlers.Unmarshal.Clear() 640 client.Handlers.UnmarshalMeta.Clear() 641 client.Handlers.ValidateResponse.Clear() 642 client.Config.SleepDelay = func(dur time.Duration) {} 643 644 // Ensure waiters do not panic on nil input. It doesn't make sense to 645 // call a waiter without an input, Validation will 646 err := client.WaitUntilBucketExists(nil) 647 if err != nil { 648 t.Fatalf("expect no error, but got %v", err) 649 } 650} 651 652func TestWaiterWithContextNilInput(t *testing.T) { 653 // Code generation doesn't have a great way to verify the code is correct 654 // other than being run via unit tests in the SDK. This should be fixed 655 // So code generation can be validated independently. 656 657 client := s3.New(unit.Session) 658 client.Handlers.Validate.Clear() 659 client.Handlers.Send.Clear() // mock sending 660 client.Handlers.Send.PushBack(func(r *request.Request) { 661 r.HTTPResponse = &http.Response{ 662 StatusCode: http.StatusOK, 663 } 664 }) 665 client.Handlers.Unmarshal.Clear() 666 client.Handlers.UnmarshalMeta.Clear() 667 client.Handlers.ValidateResponse.Clear() 668 669 // Ensure waiters do not panic on nil input 670 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 671 err := client.WaitUntilBucketExistsWithContext(ctx, nil, 672 request.WithWaiterDelay(request.ConstantWaiterDelay(0)), 673 request.WithWaiterMaxAttempts(1), 674 ) 675 if err != nil { 676 t.Fatalf("expect no error, but got %v", err) 677 } 678} 679