1package request_test
2
3import (
4	"bytes"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"testing"
9	"time"
10
11	"github.com/stretchr/testify/assert"
12
13	"github.com/aws/aws-sdk-go/aws"
14	"github.com/aws/aws-sdk-go/aws/awserr"
15	"github.com/aws/aws-sdk-go/aws/client"
16	"github.com/aws/aws-sdk-go/aws/request"
17	"github.com/aws/aws-sdk-go/awstesting"
18	"github.com/aws/aws-sdk-go/awstesting/unit"
19	"github.com/aws/aws-sdk-go/service/s3"
20)
21
22type mockClient struct {
23	*client.Client
24}
25type MockInput struct{}
26type MockOutput struct {
27	States []*MockState
28}
29type MockState struct {
30	State *string
31}
32
33func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutput) {
34	op := &request.Operation{
35		Name:       "Mock",
36		HTTPMethod: "POST",
37		HTTPPath:   "/",
38	}
39
40	if input == nil {
41		input = &MockInput{}
42	}
43
44	output := &MockOutput{}
45	req := c.NewRequest(op, input, output)
46	req.Data = output
47	return req, output
48}
49
50func BuildNewMockRequest(c *mockClient, in *MockInput) func([]request.Option) (*request.Request, error) {
51	return func(opts []request.Option) (*request.Request, error) {
52		req, _ := c.MockRequest(in)
53		req.ApplyOptions(opts...)
54		return req, nil
55	}
56}
57
58func TestWaiterPathAll(t *testing.T) {
59	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
60		Region: aws.String("mock-region"),
61	})}
62	svc.Handlers.Send.Clear() // mock sending
63	svc.Handlers.Unmarshal.Clear()
64	svc.Handlers.UnmarshalMeta.Clear()
65	svc.Handlers.ValidateResponse.Clear()
66
67	reqNum := 0
68	resps := []*MockOutput{
69		{ // Request 1
70			States: []*MockState{
71				{State: aws.String("pending")},
72				{State: aws.String("pending")},
73			},
74		},
75		{ // Request 2
76			States: []*MockState{
77				{State: aws.String("running")},
78				{State: aws.String("pending")},
79			},
80		},
81		{ // Request 3
82			States: []*MockState{
83				{State: aws.String("running")},
84				{State: aws.String("running")},
85			},
86		},
87	}
88
89	numBuiltReq := 0
90	svc.Handlers.Build.PushBack(func(r *request.Request) {
91		numBuiltReq++
92	})
93	svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
94		if reqNum >= len(resps) {
95			assert.Fail(t, "too many polling requests made")
96			return
97		}
98		r.Data = resps[reqNum]
99		reqNum++
100	})
101
102	w := request.Waiter{
103		MaxAttempts:      10,
104		Delay:            request.ConstantWaiterDelay(0),
105		SleepWithContext: aws.SleepWithContext,
106		Acceptors: []request.WaiterAcceptor{
107			{
108				State:    request.SuccessWaiterState,
109				Matcher:  request.PathAllWaiterMatch,
110				Argument: "States[].State",
111				Expected: "running",
112			},
113		},
114		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
115	}
116
117	err := w.WaitWithContext(aws.BackgroundContext())
118	assert.NoError(t, err)
119	assert.Equal(t, 3, numBuiltReq)
120	assert.Equal(t, 3, reqNum)
121}
122
123func TestWaiterPath(t *testing.T) {
124	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
125		Region: aws.String("mock-region"),
126	})}
127	svc.Handlers.Send.Clear() // mock sending
128	svc.Handlers.Unmarshal.Clear()
129	svc.Handlers.UnmarshalMeta.Clear()
130	svc.Handlers.ValidateResponse.Clear()
131
132	reqNum := 0
133	resps := []*MockOutput{
134		{ // Request 1
135			States: []*MockState{
136				{State: aws.String("pending")},
137				{State: aws.String("pending")},
138			},
139		},
140		{ // Request 2
141			States: []*MockState{
142				{State: aws.String("running")},
143				{State: aws.String("pending")},
144			},
145		},
146		{ // Request 3
147			States: []*MockState{
148				{State: aws.String("running")},
149				{State: aws.String("running")},
150			},
151		},
152	}
153
154	numBuiltReq := 0
155	svc.Handlers.Build.PushBack(func(r *request.Request) {
156		numBuiltReq++
157	})
158	svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
159		if reqNum >= len(resps) {
160			assert.Fail(t, "too many polling requests made")
161			return
162		}
163		r.Data = resps[reqNum]
164		reqNum++
165	})
166
167	w := request.Waiter{
168		MaxAttempts:      10,
169		Delay:            request.ConstantWaiterDelay(0),
170		SleepWithContext: aws.SleepWithContext,
171		Acceptors: []request.WaiterAcceptor{
172			{
173				State:    request.SuccessWaiterState,
174				Matcher:  request.PathWaiterMatch,
175				Argument: "States[].State",
176				Expected: "running",
177			},
178		},
179		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
180	}
181
182	err := w.WaitWithContext(aws.BackgroundContext())
183	assert.NoError(t, err)
184	assert.Equal(t, 3, numBuiltReq)
185	assert.Equal(t, 3, reqNum)
186}
187
188func TestWaiterFailure(t *testing.T) {
189	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
190		Region: aws.String("mock-region"),
191	})}
192	svc.Handlers.Send.Clear() // mock sending
193	svc.Handlers.Unmarshal.Clear()
194	svc.Handlers.UnmarshalMeta.Clear()
195	svc.Handlers.ValidateResponse.Clear()
196
197	reqNum := 0
198	resps := []*MockOutput{
199		{ // Request 1
200			States: []*MockState{
201				{State: aws.String("pending")},
202				{State: aws.String("pending")},
203			},
204		},
205		{ // Request 2
206			States: []*MockState{
207				{State: aws.String("running")},
208				{State: aws.String("pending")},
209			},
210		},
211		{ // Request 3
212			States: []*MockState{
213				{State: aws.String("running")},
214				{State: aws.String("stopping")},
215			},
216		},
217	}
218
219	numBuiltReq := 0
220	svc.Handlers.Build.PushBack(func(r *request.Request) {
221		numBuiltReq++
222	})
223	svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
224		if reqNum >= len(resps) {
225			assert.Fail(t, "too many polling requests made")
226			return
227		}
228		r.Data = resps[reqNum]
229		reqNum++
230	})
231
232	w := request.Waiter{
233		MaxAttempts:      10,
234		Delay:            request.ConstantWaiterDelay(0),
235		SleepWithContext: aws.SleepWithContext,
236		Acceptors: []request.WaiterAcceptor{
237			{
238				State:    request.SuccessWaiterState,
239				Matcher:  request.PathAllWaiterMatch,
240				Argument: "States[].State",
241				Expected: "running",
242			},
243			{
244				State:    request.FailureWaiterState,
245				Matcher:  request.PathAnyWaiterMatch,
246				Argument: "States[].State",
247				Expected: "stopping",
248			},
249		},
250		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
251	}
252
253	err := w.WaitWithContext(aws.BackgroundContext()).(awserr.Error)
254	assert.Error(t, err)
255	assert.Equal(t, request.WaiterResourceNotReadyErrorCode, err.Code())
256	assert.Equal(t, "failed waiting for successful resource state", err.Message())
257	assert.Equal(t, 3, numBuiltReq)
258	assert.Equal(t, 3, reqNum)
259}
260
261func TestWaiterError(t *testing.T) {
262	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
263		Region: aws.String("mock-region"),
264	})}
265	svc.Handlers.Send.Clear() // mock sending
266	svc.Handlers.Unmarshal.Clear()
267	svc.Handlers.UnmarshalMeta.Clear()
268	svc.Handlers.UnmarshalError.Clear()
269	svc.Handlers.ValidateResponse.Clear()
270
271	reqNum := 0
272	resps := []*MockOutput{
273		{ // Request 1
274			States: []*MockState{
275				{State: aws.String("pending")},
276				{State: aws.String("pending")},
277			},
278		},
279		{ // Request 1, error case retry
280		},
281		{ // Request 2, error case failure
282		},
283		{ // Request 3
284			States: []*MockState{
285				{State: aws.String("running")},
286				{State: aws.String("running")},
287			},
288		},
289	}
290	reqErrs := make([]error, len(resps))
291	reqErrs[1] = awserr.New("MockException", "mock exception message", nil)
292	reqErrs[2] = awserr.New("FailureException", "mock failure exception message", nil)
293
294	numBuiltReq := 0
295	svc.Handlers.Build.PushBack(func(r *request.Request) {
296		numBuiltReq++
297	})
298	svc.Handlers.Send.PushBack(func(r *request.Request) {
299		code := 200
300		if reqNum == 1 {
301			code = 400
302		}
303		r.HTTPResponse = &http.Response{
304			StatusCode: code,
305			Status:     http.StatusText(code),
306			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
307		}
308	})
309	svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
310		if reqNum >= len(resps) {
311			assert.Fail(t, "too many polling requests made")
312			return
313		}
314		r.Data = resps[reqNum]
315		reqNum++
316	})
317	svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) {
318		// If there was an error unmarshal error will be called instead of unmarshal
319		// need to increment count here also
320		if err := reqErrs[reqNum]; err != nil {
321			r.Error = err
322			reqNum++
323		}
324	})
325
326	w := request.Waiter{
327		MaxAttempts:      10,
328		Delay:            request.ConstantWaiterDelay(0),
329		SleepWithContext: aws.SleepWithContext,
330		Acceptors: []request.WaiterAcceptor{
331			{
332				State:    request.SuccessWaiterState,
333				Matcher:  request.PathAllWaiterMatch,
334				Argument: "States[].State",
335				Expected: "running",
336			},
337			{
338				State:    request.RetryWaiterState,
339				Matcher:  request.ErrorWaiterMatch,
340				Argument: "",
341				Expected: "MockException",
342			},
343			{
344				State:    request.FailureWaiterState,
345				Matcher:  request.ErrorWaiterMatch,
346				Argument: "",
347				Expected: "FailureException",
348			},
349		},
350		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
351	}
352
353	err := w.WaitWithContext(aws.BackgroundContext())
354	if err == nil {
355		t.Fatalf("expected error, but did not get one")
356	}
357	aerr := err.(awserr.Error)
358	if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
359		t.Errorf("expect %q error code, got %q", e, a)
360	}
361	if e, a := 3, numBuiltReq; e != a {
362		t.Errorf("expect %d built requests got %d", e, a)
363	}
364	if e, a := 3, reqNum; e != a {
365		t.Errorf("expect %d reqNum got %d", e, a)
366	}
367}
368
369func TestWaiterStatus(t *testing.T) {
370	svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
371		Region: aws.String("mock-region"),
372	})}
373	svc.Handlers.Send.Clear() // mock sending
374	svc.Handlers.Unmarshal.Clear()
375	svc.Handlers.UnmarshalMeta.Clear()
376	svc.Handlers.ValidateResponse.Clear()
377
378	reqNum := 0
379	svc.Handlers.Build.PushBack(func(r *request.Request) {
380		reqNum++
381	})
382	svc.Handlers.Send.PushBack(func(r *request.Request) {
383		code := 200
384		if reqNum == 3 {
385			code = 404
386			r.Error = awserr.New("NotFound", "Not Found", nil)
387		}
388		r.HTTPResponse = &http.Response{
389			StatusCode: code,
390			Status:     http.StatusText(code),
391			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
392		}
393	})
394
395	w := request.Waiter{
396		MaxAttempts:      10,
397		Delay:            request.ConstantWaiterDelay(0),
398		SleepWithContext: aws.SleepWithContext,
399		Acceptors: []request.WaiterAcceptor{
400			{
401				State:    request.SuccessWaiterState,
402				Matcher:  request.StatusWaiterMatch,
403				Argument: "",
404				Expected: 404,
405			},
406		},
407		NewRequest: BuildNewMockRequest(svc, &MockInput{}),
408	}
409
410	err := w.WaitWithContext(aws.BackgroundContext())
411	assert.NoError(t, err)
412	assert.Equal(t, 3, reqNum)
413}
414
415func TestWaiter_ApplyOptions(t *testing.T) {
416	w := request.Waiter{}
417
418	logger := aws.NewDefaultLogger()
419
420	w.ApplyOptions(
421		request.WithWaiterLogger(logger),
422		request.WithWaiterRequestOptions(request.WithLogLevel(aws.LogDebug)),
423		request.WithWaiterMaxAttempts(2),
424		request.WithWaiterDelay(request.ConstantWaiterDelay(5*time.Second)),
425	)
426
427	if e, a := logger, w.Logger; e != a {
428		t.Errorf("expect logger to be set, and match, was not, %v, %v", e, a)
429	}
430
431	if len(w.RequestOptions) != 1 {
432		t.Fatalf("expect request options to be set to only a single option, %v", w.RequestOptions)
433	}
434	r := request.Request{}
435	r.ApplyOptions(w.RequestOptions...)
436	if e, a := aws.LogDebug, r.Config.LogLevel.Value(); e != a {
437		t.Errorf("expect %v loglevel got %v", e, a)
438	}
439
440	if e, a := 2, w.MaxAttempts; e != a {
441		t.Errorf("expect %d retryer max attempts, got %d", e, a)
442	}
443	if e, a := 5*time.Second, w.Delay(0); e != a {
444		t.Errorf("expect %d retryer delay, got %d", e, a)
445	}
446}
447
448func TestWaiter_WithContextCanceled(t *testing.T) {
449	c := awstesting.NewClient()
450
451	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
452	reqCount := 0
453
454	w := request.Waiter{
455		Name:             "TestWaiter",
456		MaxAttempts:      10,
457		Delay:            request.ConstantWaiterDelay(1 * time.Millisecond),
458		SleepWithContext: aws.SleepWithContext,
459		Acceptors: []request.WaiterAcceptor{
460			{
461				State:    request.SuccessWaiterState,
462				Matcher:  request.StatusWaiterMatch,
463				Expected: 200,
464			},
465		},
466		Logger: aws.NewDefaultLogger(),
467		NewRequest: func(opts []request.Option) (*request.Request, error) {
468			req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
469			req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
470			req.Handlers.Clear()
471			req.Data = struct{}{}
472			req.Handlers.Send.PushBack(func(r *request.Request) {
473				if reqCount == 1 {
474					ctx.Error = fmt.Errorf("context canceled")
475					close(ctx.DoneCh)
476				}
477				reqCount++
478			})
479
480			return req, nil
481		},
482	}
483
484	w.SleepWithContext = func(c aws.Context, delay time.Duration) error {
485		context := c.(*awstesting.FakeContext)
486		select {
487		case <-context.DoneCh:
488			return context.Err()
489		default:
490			return nil
491		}
492	}
493
494	err := w.WaitWithContext(ctx)
495
496	if err == nil {
497		t.Fatalf("expect waiter to be canceled.")
498	}
499	aerr := err.(awserr.Error)
500	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
501		t.Errorf("expect %q error code, got %q", e, a)
502	}
503	if e, a := 2, reqCount; e != a {
504		t.Errorf("expect %d requests, got %d", e, a)
505	}
506}
507
508func TestWaiter_WithContext(t *testing.T) {
509	c := awstesting.NewClient()
510
511	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
512	reqCount := 0
513
514	statuses := []int{http.StatusNotFound, http.StatusOK}
515
516	w := request.Waiter{
517		Name:             "TestWaiter",
518		MaxAttempts:      10,
519		Delay:            request.ConstantWaiterDelay(1 * time.Millisecond),
520		SleepWithContext: aws.SleepWithContext,
521		Acceptors: []request.WaiterAcceptor{
522			{
523				State:    request.SuccessWaiterState,
524				Matcher:  request.StatusWaiterMatch,
525				Expected: 200,
526			},
527		},
528		Logger: aws.NewDefaultLogger(),
529		NewRequest: func(opts []request.Option) (*request.Request, error) {
530			req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
531			req.HTTPResponse = &http.Response{StatusCode: statuses[reqCount]}
532			req.Handlers.Clear()
533			req.Data = struct{}{}
534			req.Handlers.Send.PushBack(func(r *request.Request) {
535				if reqCount == 1 {
536					ctx.Error = fmt.Errorf("context canceled")
537					close(ctx.DoneCh)
538				}
539				reqCount++
540			})
541
542			return req, nil
543		},
544	}
545
546	err := w.WaitWithContext(ctx)
547
548	if err != nil {
549		t.Fatalf("expect no error, got %v", err)
550	}
551	if e, a := 2, reqCount; e != a {
552		t.Errorf("expect %d requests, got %d", e, a)
553	}
554}
555
556func TestWaiter_AttemptsExpires(t *testing.T) {
557	c := awstesting.NewClient()
558
559	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
560	reqCount := 0
561
562	w := request.Waiter{
563		Name:             "TestWaiter",
564		MaxAttempts:      2,
565		Delay:            request.ConstantWaiterDelay(1 * time.Millisecond),
566		SleepWithContext: aws.SleepWithContext,
567		Acceptors: []request.WaiterAcceptor{
568			{
569				State:    request.SuccessWaiterState,
570				Matcher:  request.StatusWaiterMatch,
571				Expected: 200,
572			},
573		},
574		Logger: aws.NewDefaultLogger(),
575		NewRequest: func(opts []request.Option) (*request.Request, error) {
576			req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
577			req.HTTPResponse = &http.Response{StatusCode: http.StatusNotFound}
578			req.Handlers.Clear()
579			req.Data = struct{}{}
580			req.Handlers.Send.PushBack(func(r *request.Request) {
581				reqCount++
582			})
583
584			return req, nil
585		},
586	}
587
588	err := w.WaitWithContext(ctx)
589
590	if err == nil {
591		t.Fatalf("expect error did not get one")
592	}
593	aerr := err.(awserr.Error)
594	if e, a := request.WaiterResourceNotReadyErrorCode, aerr.Code(); e != a {
595		t.Errorf("expect %q error code, got %q", e, a)
596	}
597	if e, a := 2, reqCount; e != a {
598		t.Errorf("expect %d requests, got %d", e, a)
599	}
600}
601
602func TestWaiterNilInput(t *testing.T) {
603	// Code generation doesn't have a great way to verify the code is correct
604	// other than being run via unit tests in the SDK. This should be fixed
605	// So code generation can be validated independently.
606
607	client := s3.New(unit.Session)
608	client.Handlers.Validate.Clear()
609	client.Handlers.Send.Clear() // mock sending
610	client.Handlers.Send.PushBack(func(r *request.Request) {
611		r.HTTPResponse = &http.Response{
612			StatusCode: http.StatusOK,
613		}
614	})
615	client.Handlers.Unmarshal.Clear()
616	client.Handlers.UnmarshalMeta.Clear()
617	client.Handlers.ValidateResponse.Clear()
618	client.Config.SleepDelay = func(dur time.Duration) {}
619
620	// Ensure waiters do not panic on nil input. It doesn't make sense to
621	// call a waiter without an input, Validation will
622	err := client.WaitUntilBucketExists(nil)
623	if err != nil {
624		t.Fatalf("expect no error, but got %v", err)
625	}
626}
627
628func TestWaiterWithContextNilInput(t *testing.T) {
629	// Code generation doesn't have a great way to verify the code is correct
630	// other than being run via unit tests in the SDK. This should be fixed
631	// So code generation can be validated independently.
632
633	client := s3.New(unit.Session)
634	client.Handlers.Validate.Clear()
635	client.Handlers.Send.Clear() // mock sending
636	client.Handlers.Send.PushBack(func(r *request.Request) {
637		r.HTTPResponse = &http.Response{
638			StatusCode: http.StatusOK,
639		}
640	})
641	client.Handlers.Unmarshal.Clear()
642	client.Handlers.UnmarshalMeta.Clear()
643	client.Handlers.ValidateResponse.Clear()
644
645	// Ensure waiters do not panic on nil input
646	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
647	err := client.WaitUntilBucketExistsWithContext(ctx, nil,
648		request.WithWaiterDelay(request.ConstantWaiterDelay(0)),
649		request.WithWaiterMaxAttempts(1),
650	)
651	if err != nil {
652		t.Fatalf("expect no error, but got %v", err)
653	}
654}
655