1package request_test
2
3import (
4	"bytes"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"testing"
9	"time"
10
11	"github.com/aws/aws-sdk-go/aws"
12	"github.com/aws/aws-sdk-go/aws/awserr"
13	"github.com/aws/aws-sdk-go/aws/client"
14	"github.com/aws/aws-sdk-go/aws/request"
15	"github.com/aws/aws-sdk-go/awstesting"
16	"github.com/aws/aws-sdk-go/awstesting/unit"
17	"github.com/aws/aws-sdk-go/service/s3"
18)
19
20type mockClient struct {
21	*client.Client
22}
23type MockInput struct{}
24type MockOutput struct {
25	States []*MockState
26}
27type MockState struct {
28	State *string
29}
30
31func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) {
32	op := &request.Operation{
33		Name:       "Mock",
34		HTTPMethod: "POST",
35		HTTPPath:   "/",
36	}
37
38	if input == nil {
39		input = &MockInput{}
40	}
41
42	output := &MockOutput{}
43	req := c.NewRequest(op, input, output)
44	req.Data = output
45	return req, output
46}
47
48func BuildNewMockRequest(c *mockClient, in *MockInput) func([]request.Option) (*request.Request, error) {
49	return func(opts []request.Option) (*request.Request, error) {
50		req, _ := c.MockRequest(in)
51		req.ApplyOptions(opts...)
52		return req, nil
53	}
54}
55
// TestWaiterPathAll checks that a success acceptor using PathAllWaiterMatch
// only matches once every States[].State value equals "running". The mocked
// responses reach that condition on the third poll.
func TestWaiterPathAll(t *testing.T) {
	mc := &mockClient{Client: awstesting.NewClient(&aws.Config{
		Region: aws.String("mock-region"),
	})}
	// No real round trip: responses are injected by the Unmarshal handler
	// registered below.
	for _, hl := range []*request.HandlerList{
		&mc.Handlers.Send,
		&mc.Handlers.Unmarshal,
		&mc.Handlers.UnmarshalMeta,
		&mc.Handlers.ValidateResponse,
	} {
		hl.Clear()
	}

	mkOutput := func(vals ...string) *MockOutput {
		out := &MockOutput{}
		for _, v := range vals {
			out.States = append(out.States, &MockState{State: aws.String(v)})
		}
		return out
	}
	responses := []*MockOutput{
		mkOutput("pending", "pending"), // Request 1
		mkOutput("running", "pending"), // Request 2
		mkOutput("running", "running"), // Request 3
	}

	built, served := 0, 0
	mc.Handlers.Build.PushBack(func(*request.Request) {
		built++
	})
	mc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
		if served >= len(responses) {
			t.Errorf("too many polling requests made")
			return
		}
		r.Data = responses[served]
		served++
	})

	waiter := request.Waiter{
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(0),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{{
			State:    request.SuccessWaiterState,
			Matcher:  request.PathAllWaiterMatch,
			Argument: "States[].State",
			Expected: "running",
		}},
		NewRequest: BuildNewMockRequest(mc, &MockInput{}),
	}

	if err := waiter.WaitWithContext(aws.BackgroundContext()); err != nil {
		t.Errorf("expect nil, %v", err)
	}
	if e, a := 3, built; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
	if e, a := 3, served; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}
126
// TestWaiterPath checks a success acceptor using PathWaiterMatch against the
// "States[].State" path with expected value "running". With the mocked
// responses below the waiter succeeds on the third poll.
func TestWaiterPath(t *testing.T) {
	mc := &mockClient{Client: awstesting.NewClient(&aws.Config{
		Region: aws.String("mock-region"),
	})}
	// Responses are supplied by the Unmarshal handler, not the network.
	mc.Handlers.Send.Clear()
	mc.Handlers.Unmarshal.Clear()
	mc.Handlers.UnmarshalMeta.Clear()
	mc.Handlers.ValidateResponse.Clear()

	pair := func(first, second string) *MockOutput {
		return &MockOutput{States: []*MockState{
			{State: aws.String(first)},
			{State: aws.String(second)},
		}}
	}
	replies := []*MockOutput{
		pair("pending", "pending"), // Request 1
		pair("running", "pending"), // Request 2
		pair("running", "running"), // Request 3
	}

	builds, polls := 0, 0
	mc.Handlers.Build.PushBack(func(*request.Request) { builds++ })
	mc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
		if polls >= len(replies) {
			t.Errorf("too many polling requests made")
			return
		}
		r.Data = replies[polls]
		polls++
	})

	acceptor := request.WaiterAcceptor{
		State:    request.SuccessWaiterState,
		Matcher:  request.PathWaiterMatch,
		Argument: "States[].State",
		Expected: "running",
	}
	waiter := request.Waiter{
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(0),
		SleepWithContext: aws.SleepWithContext,
		Acceptors:        []request.WaiterAcceptor{acceptor},
		NewRequest:       BuildNewMockRequest(mc, &MockInput{}),
	}

	if err := waiter.WaitWithContext(aws.BackgroundContext()); err != nil {
		t.Errorf("expect nil, %v", err)
	}
	for _, c := range []struct {
		want, got int
	}{
		{3, builds},
		{3, polls},
	} {
		if c.want != c.got {
			t.Errorf("expect %v, got %v", c.want, c.got)
		}
	}
}
197
// TestWaiterFailure verifies that a FailureWaiterState acceptor using
// PathAnyWaiterMatch ends the wait with a WaiterResourceNotReady error once any
// States[].State value is "stopping" (the third mocked response), even though a
// success acceptor is also configured.
func TestWaiterFailure(t *testing.T) {
	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
		Region: aws.String("mock-region"),
	})}
	svc.Handlers.Send.Clear() // mock sending
	svc.Handlers.Unmarshal.Clear()
	svc.Handlers.UnmarshalMeta.Clear()
	svc.Handlers.ValidateResponse.Clear()

	reqNum := 0
	resps := []*MockOutput{
		{ // Request 1
			States: []*MockState{
				{State: aws.String("pending")},
				{State: aws.String("pending")},
			},
		},
		{ // Request 2
			States: []*MockState{
				{State: aws.String("running")},
				{State: aws.String("pending")},
			},
		},
		{ // Request 3
			States: []*MockState{
				{State: aws.String("running")},
				{State: aws.String("stopping")},
			},
		},
	}

	numBuiltReq := 0
	svc.Handlers.Build.PushBack(func(r *request.Request) {
		numBuiltReq++
	})
	svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
		if reqNum >= len(resps) {
			t.Errorf("too many polling requests made")
			return
		}
		r.Data = resps[reqNum]
		reqNum++
	})

	w := request.Waiter{
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(0),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{
			{
				State:    request.SuccessWaiterState,
				Matcher:  request.PathAllWaiterMatch,
				Argument: "States[].State",
				Expected: "running",
			},
			{
				State:    request.FailureWaiterState,
				Matcher:  request.PathAnyWaiterMatch,
				Argument: "States[].State",
				Expected: "stopping",
			},
		},
		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
	}

	// Check for nil before asserting the concrete type: a single-value type
	// assertion on a nil interface panics, which would hide the real failure.
	err := w.WaitWithContext(aws.BackgroundContext())
	if err == nil {
		t.Fatalf("expect error")
	}
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("expect awserr.Error, got %T: %v", err, err)
	}
	if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
	if e, a := "failed waiting for successful resource state", aerr.Message(); e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
	if e, a := 3, numBuiltReq; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
	if e, a := 3, reqNum; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}
280
// TestWaiterError verifies error-matching acceptors. A "MockException" error
// response is retried, and the following "FailureException" error response
// ends the wait with a WaiterResourceNotReady error. Three requests are built
// in total.
func TestWaiterError(t *testing.T) {
	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
		Region: aws.String("mock-region"),
	})}
	svc.Handlers.Send.Clear() // mock sending
	svc.Handlers.Unmarshal.Clear()
	svc.Handlers.UnmarshalMeta.Clear()
	svc.Handlers.UnmarshalError.Clear()
	svc.Handlers.ValidateResponse.Clear()

	reqNum := 0
	resps := []*MockOutput{
		{ // Request 1
			States: []*MockState{
				{State: aws.String("pending")},
				{State: aws.String("pending")},
			},
		},
		{ // Request 1, error case retry
		},
		{ // Request 2, error case failure
		},
		{ // Request 3
			States: []*MockState{
				{State: aws.String("running")},
				{State: aws.String("running")},
			},
		},
	}
	// reqErrs[i] is the error injected for the i-th response slot (nil = none).
	reqErrs := make([]error, len(resps))
	reqErrs[1] = awserr.New("MockException", "mock exception message", nil)
	reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil)

	numBuiltReq := 0
	svc.Handlers.Build.PushBack(func(r *request.Request) {
		numBuiltReq++
	})
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		code := 200
		if reqNum == 1 {
			code = 400
		}
		r.HTTPResponse = &http.Response{
			StatusCode: code,
			Status:     http.StatusText(code),
			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
		}
	})
	svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
		if reqNum >= len(resps) {
			t.Errorf("too many polling requests made")
			return
		}
		r.Data = resps[reqNum]
		reqNum++
	})
	svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) {
		// Guard the index so an unexpected extra poll is reported by the
		// Unmarshal handler ("too many polling requests made") instead of
		// panicking here.
		if reqNum >= len(reqErrs) {
			return
		}
		// Once r.Error is set, the error-unmarshal path runs instead of the
		// Unmarshal handler, so the response counter is advanced here too.
		if err := reqErrs[reqNum]; err != nil {
			r.Error = err
			reqNum++
		}
	})

	w := request.Waiter{
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(0),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{
			{
				State:    request.SuccessWaiterState,
				Matcher:  request.PathAllWaiterMatch,
				Argument: "States[].State",
				Expected: "running",
			},
			{
				State:    request.RetryWaiterState,
				Matcher:  request.ErrorWaiterMatch,
				Argument: "",
				Expected: "MockException",
			},
			{
				State:    request.FailureWaiterState,
				Matcher:  request.ErrorWaiterMatch,
				Argument: "",
				Expected: "FailureException",
			},
		},
		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
	}

	err := w.WaitWithContext(aws.BackgroundContext())
	if err == nil {
		t.Fatalf("expected error, but did not get one")
	}
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("expect awserr.Error, got %T: %v", err, err)
	}
	if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
		t.Errorf("expect %q error code, got %q", e, a)
	}
	if e, a := 3, numBuiltReq; e != a {
		t.Errorf("expect %d built requests got %d", e, a)
	}
	if e, a := 3, reqNum; e != a {
		t.Errorf("expect %d reqNum got %d", e, a)
	}
}
388
// TestWaiterStatus checks a StatusWaiterMatch success acceptor that expects a
// 404. The first two mocked sends answer 200, and the third answers 404 with a
// "NotFound" error, so exactly three requests are built.
func TestWaiterStatus(t *testing.T) {
	mc := &mockClient{Client: awstesting.NewClient(&aws.Config{
		Region: aws.String("mock-region"),
	})}
	mc.Handlers.Send.Clear() // responses are synthesized below
	mc.Handlers.Unmarshal.Clear()
	mc.Handlers.UnmarshalMeta.Clear()
	mc.Handlers.ValidateResponse.Clear()

	attempts := 0
	mc.Handlers.Build.PushBack(func(*request.Request) { attempts++ })
	mc.Handlers.Send.PushBack(func(r *request.Request) {
		status := http.StatusOK
		if attempts == 3 {
			status = http.StatusNotFound
			r.Error = awserr.New("NotFound", "Not Found", nil)
		}
		r.HTTPResponse = &http.Response{
			StatusCode: status,
			Status:     http.StatusText(status),
			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
		}
	})

	waiter := request.Waiter{
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(0),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{{
			State:    request.SuccessWaiterState,
			Matcher:  request.StatusWaiterMatch,
			Argument: "",
			Expected: 404,
		}},
		NewRequest: BuildNewMockRequest(mc, &MockInput{}),
	}

	if err := waiter.WaitWithContext(aws.BackgroundContext()); err != nil {
		t.Errorf("expect nil, %v", err)
	}
	if e, a := 3, attempts; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}
438
// TestWaiter_ApplyOptions verifies that each functional waiter option sets the
// matching Waiter field: logger, per-request options, max attempts, and delay.
func TestWaiter_ApplyOptions(t *testing.T) {
	w := request.Waiter{}

	logger := aws.NewDefaultLogger()

	w.ApplyOptions(
		request.WithWaiterLogger(logger),
		request.WithWaiterRequestOptions(request.WithLogLevel(aws.LogDebug)),
		request.WithWaiterMaxAttempts(2),
		request.WithWaiterDelay(request.ConstantWaiterDelay(5*time.Second)),
	)

	if e, a := logger, w.Logger; e != a {
		t.Errorf("expect logger to be set, and match, was not, %v, %v", e, a)
	}

	if len(w.RequestOptions) != 1 {
		t.Fatalf("expect request options to be set to only a single option, %v", w.RequestOptions)
	}
	// Apply the stored request options to a bare request to observe their effect.
	r := request.Request{}
	r.ApplyOptions(w.RequestOptions...)
	if e, a := aws.LogDebug, r.Config.LogLevel.Value(); e != a {
		t.Errorf("expect %v loglevel got %v", e, a)
	}

	if e, a := 2, w.MaxAttempts; e != a {
		t.Errorf("expect %d waiter max attempts, got %d", e, a)
	}
	// %v renders durations readably (e.g. "5s") instead of raw nanoseconds.
	if e, a := 5*time.Second, w.Delay(0); e != a {
		t.Errorf("expect %v waiter delay, got %v", e, a)
	}
}
471
// TestWaiter_WithContextCanceled verifies that the waiter stops with a
// request.CanceledErrorCode error when its context is done before any acceptor
// matches. The fake context is marked done inside the second request's Send
// handler, and exactly two requests are expected.
func TestWaiter_WithContextCanceled(t *testing.T) {
	c := awstesting.NewClient()

	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
	reqCount := 0

	w := request.Waiter{
		Name:             "TestWaiter",
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(1 * time.Millisecond),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{
			{
				State:    request.SuccessWaiterState,
				Matcher:  request.StatusWaiterMatch,
				Expected: 200,
			},
		},
		Logger: aws.NewDefaultLogger(),
		NewRequest: func(opts []request.Option) (*request.Request, error) {
			req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
			// 404 never satisfies the 200 acceptor, so only cancellation ends the wait.
			req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
			req.Handlers.Clear()
			req.Data = struct{}{}
			req.Handlers.Send.PushBack(func(r *request.Request) {
				if reqCount == 1 {
					ctx.Error = fmt.Errorf("context canceled")
					close(ctx.DoneCh)
				}
				reqCount++
			})

			return req, nil
		},
	}

	// Sleep without delaying: return the context's error once DoneCh has been
	// closed, otherwise return nil immediately.
	w.SleepWithContext = func(sleepCtx aws.Context, _ time.Duration) error {
		fakeCtx := sleepCtx.(*awstesting.FakeContext)
		select {
		case <-fakeCtx.DoneCh:
			return fakeCtx.Err()
		default:
			return nil
		}
	}

	err := w.WaitWithContext(ctx)

	if err == nil {
		t.Fatalf("expect waiter to be canceled.")
	}
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("expect awserr.Error, got %T: %v", err, err)
	}
	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
		t.Errorf("expect %q error code, got %q", e, a)
	}
	if e, a := 2, reqCount; e != a {
		t.Errorf("expect %d requests, got %d", e, a)
	}
}
531
// TestWaiter_WithContext verifies that a waiter driven by a caller-supplied
// context succeeds when the second response matches the 200 status acceptor.
// The context is marked done during that second send, and the wait is still
// expected to report success after exactly two requests.
func TestWaiter_WithContext(t *testing.T) {
	c := awstesting.NewClient()

	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
	reqCount := 0

	// statuses[i] is the HTTP status returned by the i-th request.
	statuses := []int{http.StatusNotFound, http.StatusOK}

	w := request.Waiter{
		Name:             "TestWaiter",
		MaxAttempts:      10,
		Delay:            request.ConstantWaiterDelay(1 * time.Millisecond),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{
			{
				State:    request.SuccessWaiterState,
				Matcher:  request.StatusWaiterMatch,
				Expected: 200,
			},
		},
		Logger: aws.NewDefaultLogger(),
		NewRequest: func(opts []request.Option) (*request.Request, error) {
			// Fail with a clear error, rather than an index-out-of-range
			// panic, if the waiter polls more often than scripted.
			if reqCount >= len(statuses) {
				return nil, fmt.Errorf("unexpected request %d, only %d scripted", reqCount+1, len(statuses))
			}
			req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
			req.HTTPResponse = &http.Response{StatusCode: statuses[reqCount]}
			req.Handlers.Clear()
			req.Data = struct{}{}
			req.Handlers.Send.PushBack(func(r *request.Request) {
				if reqCount == 1 {
					ctx.Error = fmt.Errorf("context canceled")
					close(ctx.DoneCh)
				}
				reqCount++
			})

			return req, nil
		},
	}

	err := w.WaitWithContext(ctx)

	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}
	if e, a := 2, reqCount; e != a {
		t.Errorf("expect %d requests, got %d", e, a)
	}
}
579
// TestWaiter_AttemptsExpires verifies that when no acceptor ever matches, the
// waiter gives up after MaxAttempts (2) requests and returns a
// WaiterResourceNotReady error.
func TestWaiter_AttemptsExpires(t *testing.T) {
	c := awstesting.NewClient()

	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
	reqCount := 0

	w := request.Waiter{
		Name:             "TestWaiter",
		MaxAttempts:      2,
		Delay:            request.ConstantWaiterDelay(1 * time.Millisecond),
		SleepWithContext: aws.SleepWithContext,
		Acceptors: []request.WaiterAcceptor{
			{
				State:    request.SuccessWaiterState,
				Matcher:  request.StatusWaiterMatch,
				Expected: 200,
			},
		},
		Logger: aws.NewDefaultLogger(),
		NewRequest: func(opts []request.Option) (*request.Request, error) {
			req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
			// Always 404, so the 200 acceptor never matches.
			req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
			req.Handlers.Clear()
			req.Data = struct{}{}
			req.Handlers.Send.PushBack(func(r *request.Request) {
				reqCount++
			})

			return req, nil
		},
	}

	err := w.WaitWithContext(ctx)

	if err == nil {
		t.Fatalf("expect error did not get one")
	}
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("expect awserr.Error, got %T: %v", err, err)
	}
	if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
		t.Errorf("expect %q error code, got %q", e, a)
	}
	if e, a := 2, reqCount; e != a {
		t.Errorf("expect %d requests, got %d", e, a)
	}
}
625
// TestWaiterNilInput ensures a generated waiter (S3 WaitUntilBucketExists)
// does not panic when called with a nil input.
func TestWaiterNilInput(t *testing.T) {
	// Code generation doesn't have a great way to verify the code is correct
	// other than being run via unit tests in the SDK. This should be fixed
	// so code generation can be validated independently.

	// Named svc rather than client to avoid shadowing the imported
	// aws/client package.
	svc := s3.New(unit.Session)
	svc.Handlers.Validate.Clear()
	svc.Handlers.Send.Clear() // mock sending
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		r.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
		}
	})
	svc.Handlers.Unmarshal.Clear()
	svc.Handlers.UnmarshalMeta.Clear()
	svc.Handlers.ValidateResponse.Clear()
	svc.Config.SleepDelay = func(dur time.Duration) {}

	// Ensure waiters do not panic on nil input. Calling a waiter without an
	// input is not meaningful, and parameter validation would normally reject
	// it. The Validate handlers are cleared above so the nil input reaches the
	// waiter itself.
	err := svc.WaitUntilBucketExists(nil)
	if err != nil {
		t.Fatalf("expect no error, but got %v", err)
	}
}
651
// TestWaiterWithContextNilInput ensures the context-aware generated waiter
// (S3 WaitUntilBucketExistsWithContext) does not panic on a nil input when
// waiter options are supplied.
func TestWaiterWithContextNilInput(t *testing.T) {
	// Code generation doesn't have a great way to verify the code is correct
	// other than being run via unit tests in the SDK. This should be fixed
	// so code generation can be validated independently.

	// Named svc rather than client to avoid shadowing the imported
	// aws/client package.
	svc := s3.New(unit.Session)
	svc.Handlers.Validate.Clear()
	svc.Handlers.Send.Clear() // mock sending
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		r.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
		}
	})
	svc.Handlers.Unmarshal.Clear()
	svc.Handlers.UnmarshalMeta.Clear()
	svc.Handlers.ValidateResponse.Clear()

	// Ensure waiters do not panic on nil input. Zero delay and a single
	// attempt keep the test fast.
	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
	err := svc.WaitUntilBucketExistsWithContext(ctx, nil,
		request.WithWaiterDelay(request.ConstantWaiterDelay(0)),
		request.WithWaiterMaxAttempts(1),
	)
	if err != nil {
		t.Fatalf("expect no error, but got %v", err)
	}
}
679