1package retry
2
3import (
4	"context"
5	"fmt"
6	"net/http"
7	"reflect"
8	"strconv"
9	"strings"
10	"testing"
11	"time"
12
13	"github.com/aws/aws-sdk-go-v2/aws"
14	"github.com/aws/aws-sdk-go-v2/internal/sdk"
15	"github.com/aws/smithy-go/middleware"
16	smithyhttp "github.com/aws/smithy-go/transport/http"
17	"github.com/google/go-cmp/cmp"
18)
19
20func TestMetricsHeaderMiddleware(t *testing.T) {
21	cases := []struct {
22		input          middleware.FinalizeInput
23		ctx            context.Context
24		expectedHeader string
25		expectedErr    string
26	}{
27		{
28			input: middleware.FinalizeInput{Request: &smithyhttp.Request{Request: &http.Request{Header: make(http.Header)}}},
29			ctx: func() context.Context {
30				return setRetryMetadata(context.Background(), retryMetadata{
31					AttemptNum:       0,
32					AttemptTime:      time.Date(2020, 01, 02, 03, 04, 05, 0, time.UTC),
33					MaxAttempts:      5,
34					AttemptClockSkew: 0,
35				})
36			}(),
37			expectedHeader: "attempt=0; max=5",
38		},
39		{
40			input: middleware.FinalizeInput{Request: &smithyhttp.Request{Request: &http.Request{Header: make(http.Header)}}},
41			ctx: func() context.Context {
42				attemptTime := time.Date(2020, 01, 02, 03, 04, 05, 0, time.UTC)
43				ctx, cancel := context.WithDeadline(context.Background(), attemptTime.Add(time.Minute))
44				defer cancel()
45				return setRetryMetadata(ctx, retryMetadata{
46					AttemptNum:       1,
47					AttemptTime:      attemptTime,
48					MaxAttempts:      5,
49					AttemptClockSkew: time.Second * 1,
50				})
51			}(),
52			expectedHeader: "attempt=1; max=5; ttl=20200102T030506Z",
53		},
54		{
55			ctx: func() context.Context {
56				return setRetryMetadata(context.Background(), retryMetadata{})
57			}(),
58			expectedErr: "unknown transport type",
59		},
60	}
61
62	retryMiddleware := MetricsHeader{}
63	for i, tt := range cases {
64		t.Run(strconv.Itoa(i), func(t *testing.T) {
65			ctx := tt.ctx
66			_, _, err := retryMiddleware.HandleFinalize(ctx, tt.input, middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (
67				out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
68			) {
69				req := in.Request.(*smithyhttp.Request)
70
71				if e, a := tt.expectedHeader, req.Header.Get("amz-sdk-request"); e != a {
72					t.Errorf("expected %v, got %v", e, a)
73				}
74
75				return out, metadata, err
76			}))
77			if err != nil && len(tt.expectedErr) == 0 {
78				t.Fatalf("expected no error, got %q", err)
79			} else if err != nil && len(tt.expectedErr) != 0 {
80				if e, a := tt.expectedErr, err.Error(); !strings.Contains(a, e) {
81					t.Fatalf("expected %q, got %q", e, a)
82				}
83			} else if err == nil && len(tt.expectedErr) != 0 {
84				t.Fatalf("expected error, got nil")
85			}
86		})
87	}
88}
89
90type retryProvider struct {
91	Retryer aws.Retryer
92}
93
94func (t retryProvider) GetRetryer() aws.Retryer {
95	return t.Retryer
96}
97
98type mockHandler func(context.Context, interface{}) (interface{}, middleware.Metadata, error)
99
100func (m mockHandler) Handle(ctx context.Context, input interface{}) (output interface{}, metadata middleware.Metadata, err error) {
101	return m(ctx, input)
102}
103
104func (m mockHandler) ID() string {
105	return fmt.Sprintf("%T", m)
106}
107
// testRequest is a fake transport request whose stream rewind can be forced to
// fail via DisableRewind.
type testRequest struct {
	DisableRewind bool
}

// RewindStream succeeds unless DisableRewind is set, in which case it returns
// a "rewind disabled" error.
func (r testRequest) RewindStream() error {
	if !r.DisableRewind {
		return nil
	}
	return fmt.Errorf("rewind disabled")
}
118
// mockRetryableError is a test error whose RetryableError answer is fixed by b.
type mockRetryableError struct {
	b bool
}

// RetryableError reports the configured flag b.
func (e mockRetryableError) RetryableError() bool {
	retryable := e.b
	return retryable
}

// Error renders the error message, including the value of b.
func (e mockRetryableError) Error() string {
	return fmt.Sprintf("mock retryable %v", e.b)
}
125
126func TestAttemptMiddleware(t *testing.T) {
127	restoreSleep := sdk.TestingUseNopSleep()
128	defer restoreSleep()
129
130	sdkTime := sdk.NowTime
131	defer func() {
132		sdk.NowTime = sdkTime
133	}()
134
135	cases := map[string]struct {
136		Request       testRequest
137		Next          func(retries *[]retryMetadata) middleware.FinalizeHandler
138		Expect        []retryMetadata
139		Err           error
140		ExpectResults AttemptResults
141	}{
142		"no error, no response in a single attempt": {
143			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
144				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
145					m, ok := getRetryMetadata(ctx)
146					if ok {
147						*retries = append(*retries, m)
148					}
149					return out, metadata, err
150				})
151			},
152			Expect: []retryMetadata{
153				{
154					AttemptNum:  1,
155					AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC),
156					MaxAttempts: 3,
157				},
158			},
159			ExpectResults: AttemptResults{Results: []AttemptResult{
160				{},
161			}},
162		},
163		"no error in a single attempt": {
164			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
165				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
166					m, ok := getRetryMetadata(ctx)
167					if ok {
168						*retries = append(*retries, m)
169					}
170					setMockRawResponse(&metadata, "mockResponse")
171					return out, metadata, err
172				})
173			},
174			Expect: []retryMetadata{
175				{
176					AttemptNum:  1,
177					AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC),
178					MaxAttempts: 3,
179				},
180			},
181			ExpectResults: AttemptResults{Results: []AttemptResult{
182				{
183					ResponseMetadata: func() middleware.Metadata {
184						m := middleware.Metadata{}
185						setMockRawResponse(&m, "mockResponse")
186						return m
187					}(),
188				},
189			}},
190		},
191		"retries errors": {
192			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
193				num := 0
194				reqsErrs := []error{
195					mockRetryableError{b: true},
196					mockRetryableError{b: true},
197					nil,
198				}
199				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
200					m, ok := getRetryMetadata(ctx)
201					if ok {
202						*retries = append(*retries, m)
203					}
204					if num >= len(reqsErrs) {
205						err = fmt.Errorf("more requests then expected")
206					} else {
207						err = reqsErrs[num]
208						num++
209					}
210					return out, metadata, err
211				})
212			},
213			Expect: []retryMetadata{
214				{
215					AttemptNum:  1,
216					AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC),
217					MaxAttempts: 3,
218				},
219				{
220					AttemptNum:  2,
221					AttemptTime: time.Date(2020, 8, 19, 10, 21, 30, 0, time.UTC),
222					MaxAttempts: 3,
223				},
224				{
225					AttemptNum:  3,
226					AttemptTime: time.Date(2020, 8, 19, 10, 22, 30, 0, time.UTC),
227					MaxAttempts: 3,
228				},
229			},
230			ExpectResults: AttemptResults{Results: []AttemptResult{
231				{
232					Err:       mockRetryableError{b: true},
233					Retryable: true,
234					Retried:   true,
235				},
236				{
237					Err:       mockRetryableError{b: true},
238					Retryable: true,
239					Retried:   true,
240				},
241				{},
242			}},
243		},
244		"stops after max attempts": {
245			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
246				num := 0
247				reqsErrs := []error{
248					mockRetryableError{b: true},
249					mockRetryableError{b: true},
250					mockRetryableError{b: true},
251				}
252				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
253					if num >= len(reqsErrs) {
254						err = fmt.Errorf("more requests then expected")
255					} else {
256						err = reqsErrs[num]
257						num++
258					}
259					return out, metadata, err
260				})
261			},
262			Err: fmt.Errorf("exceeded maximum number of attempts"),
263			ExpectResults: AttemptResults{Results: []AttemptResult{
264				{
265					Err:       mockRetryableError{b: true},
266					Retryable: true,
267					Retried:   true,
268				},
269				{
270					Err:       mockRetryableError{b: true},
271					Retryable: true,
272					Retried:   true,
273				},
274				{
275					Err:       &MaxAttemptsError{Attempt: 3, Err: mockRetryableError{b: true}},
276					Retryable: true,
277				},
278			}},
279		},
280		"stops on rewind error": {
281			Request: testRequest{DisableRewind: true},
282			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
283				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
284					m, ok := getRetryMetadata(ctx)
285					if ok {
286						*retries = append(*retries, m)
287					}
288					return out, metadata, mockRetryableError{b: true}
289				})
290			},
291			Expect: []retryMetadata{
292				{
293					AttemptNum:  1,
294					AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC),
295					MaxAttempts: 3,
296				},
297			},
298			Err: fmt.Errorf("failed to rewind transport stream for retry"),
299			ExpectResults: AttemptResults{Results: []AttemptResult{
300				{
301					Err:       mockRetryableError{b: true},
302					Retryable: true,
303					Retried:   true,
304				},
305				{
306					Err: fmt.Errorf(
307						"failed to rewind transport stream for retry, %w",
308						fmt.Errorf("rewind disabled"),
309					),
310				},
311			}},
312		},
313		"stops on non-retryable errors": {
314			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
315				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
316					m, ok := getRetryMetadata(ctx)
317					if ok {
318						*retries = append(*retries, m)
319					}
320					return out, metadata, fmt.Errorf("some error")
321				})
322			},
323			Expect: []retryMetadata{
324				{
325					AttemptNum:  1,
326					AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC),
327					MaxAttempts: 3,
328				},
329			},
330			Err: fmt.Errorf("some error"),
331			ExpectResults: AttemptResults{Results: []AttemptResult{
332				{
333					Err: fmt.Errorf("some error"),
334				},
335			}},
336		},
337		"nested metadata and valid response": {
338			Next: func(retries *[]retryMetadata) middleware.FinalizeHandler {
339				num := 0
340				reqsErrs := []error{
341					mockRetryableError{b: true},
342					nil,
343				}
344				return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
345					m, ok := getRetryMetadata(ctx)
346					if ok {
347						*retries = append(*retries, m)
348					}
349					if num >= len(reqsErrs) {
350						err = fmt.Errorf("more requests then expected")
351					} else {
352						err = reqsErrs[num]
353						num++
354					}
355
356					if err != nil {
357						metadata.Set("testKey", "testValue")
358					} else {
359						setMockRawResponse(&metadata, "mockResponse")
360					}
361					return out, metadata, err
362				})
363			},
364			Expect: []retryMetadata{
365				{
366					AttemptNum:  1,
367					AttemptTime: time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC),
368					MaxAttempts: 3,
369				},
370				{
371					AttemptNum:  2,
372					AttemptTime: time.Date(2020, 8, 19, 10, 21, 30, 0, time.UTC),
373					MaxAttempts: 3,
374				},
375			},
376			ExpectResults: AttemptResults{Results: []AttemptResult{
377				{
378					Err:       mockRetryableError{b: true},
379					Retryable: true,
380					Retried:   true,
381					ResponseMetadata: func() middleware.Metadata {
382						m := middleware.Metadata{}
383						m.Set("testKey", "testValue")
384						return m
385					}(),
386				},
387				{
388					ResponseMetadata: func() middleware.Metadata {
389						m := middleware.Metadata{}
390						setMockRawResponse(&m, "mockResponse")
391						return m
392					}(),
393				},
394			}},
395		},
396	}
397
398	for name, tt := range cases {
399		t.Run(name, func(t *testing.T) {
400			sdk.NowTime = func() func() time.Time {
401				base := time.Date(2020, 8, 19, 10, 20, 30, 0, time.UTC)
402				num := 0
403				return func() time.Time {
404					t := base.Add(time.Minute * time.Duration(num))
405					num++
406					return t
407				}
408			}()
409
410			am := NewAttemptMiddleware(NewStandard(func(s *StandardOptions) {
411				s.MaxAttempts = 3
412			}), func(i interface{}) interface{} {
413				return i
414			})
415
416			var recorded []retryMetadata
417			_, metadata, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: tt.Request}, tt.Next(&recorded))
418			if err != nil && tt.Err == nil {
419				t.Errorf("expect no error, got %v", err)
420			} else if err == nil && tt.Err != nil {
421				t.Errorf("expect error, got none")
422			} else if err != nil && tt.Err != nil {
423				if !strings.Contains(err.Error(), tt.Err.Error()) {
424					t.Errorf("expect %v, got %v", tt.Err, err)
425				}
426			}
427			if diff := cmp.Diff(recorded, tt.Expect); len(diff) > 0 {
428				t.Error(diff)
429			}
430
431			attemptResults, ok := GetAttemptResults(metadata)
432			if !ok {
433				t.Fatalf("expected metadata to contain attempt results, got none")
434			}
435			if e, a := tt.ExpectResults, attemptResults; !reflect.DeepEqual(e, a) {
436				t.Fatalf("expected %v, got %v", e, a)
437			}
438		})
439	}
440}
441
442// mockRawResponseKey is used to test the behavior when response metadata is
443// nested within the attempt request.
444type mockRawResponseKey struct{}
445
446func setMockRawResponse(m *middleware.Metadata, v interface{}) {
447	m.Set(mockRawResponseKey{}, v)
448}
449