1package retry
2
3import (
4	"context"
5	"fmt"
6	"strconv"
7	"strings"
8	"time"
9
10	"github.com/aws/aws-sdk-go-v2/aws"
11	awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware"
12	"github.com/aws/aws-sdk-go-v2/internal/sdk"
13	"github.com/aws/smithy-go/logging"
14	"github.com/aws/smithy-go/middleware"
15	smithymiddle "github.com/aws/smithy-go/middleware"
16	"github.com/aws/smithy-go/transport/http"
17)
18
type (
	// RequestCloner is a function that can take an input request type and clone the request
	// for use in a subsequent retry attempt.
	RequestCloner func(interface{}) interface{}

	// retryMetadata is the per-attempt retry state stored on the context as a
	// stack value by the Attempt middleware and read by later middleware.
	retryMetadata struct {
		// AttemptNum is the 1-based number of the current attempt.
		AttemptNum int
		// AttemptTime is the (UTC) time at which the attempt was started.
		AttemptTime time.Time
		// MaxAttempts is the retryer's configured maximum attempt count.
		MaxAttempts int
		// AttemptClockSkew is the clock skew reported in the previous
		// attempt's response metadata, or zero if none was reported.
		AttemptClockSkew time.Duration
	}
)
29
30// Attempt is a Smithy FinalizeMiddleware that handles retry attempts using the provided
31// Retryer implementation
32type Attempt struct {
33	// Enable the logging of retry attempts performed by the SDK.
34	// This will include logging retry attempts, unretryable errors, and when max attempts are reached.
35	LogAttempts bool
36
37	retryer       aws.Retryer
38	requestCloner RequestCloner
39}
40
41// NewAttemptMiddleware returns a new Attempt retry middleware.
42func NewAttemptMiddleware(retryer aws.Retryer, requestCloner RequestCloner, optFns ...func(*Attempt)) *Attempt {
43	m := &Attempt{retryer: retryer, requestCloner: requestCloner}
44	for _, fn := range optFns {
45		fn(m)
46	}
47	return m
48}
49
50// ID returns the middleware identifier
51func (r *Attempt) ID() string {
52	return "Retry"
53}
54
55func (r Attempt) logf(logger logging.Logger, classification logging.Classification, format string, v ...interface{}) {
56	if !r.LogAttempts {
57		return
58	}
59	logger.Logf(classification, format, v...)
60}
61
62// HandleFinalize utilizes the provider Retryer implementation to attempt retries over the next handler
63func (r Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
64	out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error,
65) {
66	var attemptNum int
67	var attemptClockSkew time.Duration
68	var attemptResults AttemptResults
69
70	maxAttempts := r.retryer.MaxAttempts()
71
72	for {
73		attemptNum++
74		attemptInput := in
75		attemptInput.Request = r.requestCloner(attemptInput.Request)
76
77		attemptCtx := setRetryMetadata(ctx, retryMetadata{
78			AttemptNum:       attemptNum,
79			AttemptTime:      sdk.NowTime().UTC(),
80			MaxAttempts:      maxAttempts,
81			AttemptClockSkew: attemptClockSkew,
82		})
83
84		var attemptResult AttemptResult
85
86		out, attemptResult, err = r.handleAttempt(attemptCtx, attemptInput, next)
87
88		var ok bool
89		attemptClockSkew, ok = awsmiddle.GetAttemptSkew(attemptResult.ResponseMetadata)
90		if !ok {
91			attemptClockSkew = 0
92		}
93
94		shouldRetry := attemptResult.Retried
95
96		// add attempt metadata to list of all attempt metadata
97		attemptResults.Results = append(attemptResults.Results, attemptResult)
98
99		if !shouldRetry {
100			break
101		}
102	}
103
104	addAttemptResults(&metadata, attemptResults)
105	return out, metadata, err
106}
107
108// handleAttempt handles an individual request attempt.
109func (r Attempt) handleAttempt(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
110	out smithymiddle.FinalizeOutput, attemptResult AttemptResult, err error,
111) {
112	defer func() {
113		attemptResult.Err = err
114	}()
115
116	relRetryToken := r.retryer.GetInitialToken()
117	logger := smithymiddle.GetLogger(ctx)
118	service, operation := awsmiddle.GetServiceID(ctx), awsmiddle.GetOperationName(ctx)
119
120	retryMetadata, _ := getRetryMetadata(ctx)
121	attemptNum := retryMetadata.AttemptNum
122	maxAttempts := retryMetadata.MaxAttempts
123
124	if attemptNum > 1 {
125		if rewindable, ok := in.Request.(interface{ RewindStream() error }); ok {
126			if rewindErr := rewindable.RewindStream(); rewindErr != nil {
127				err = fmt.Errorf("failed to rewind transport stream for retry, %w", rewindErr)
128				return out, attemptResult, err
129			}
130		}
131
132		r.logf(logger, logging.Debug, "retrying request %s/%s, attempt %d", service, operation, attemptNum)
133	}
134
135	var metadata smithymiddle.Metadata
136	out, metadata, err = next.HandleFinalize(ctx, in)
137	attemptResult.ResponseMetadata = metadata
138
139	if releaseError := relRetryToken(err); releaseError != nil && err != nil {
140		err = fmt.Errorf("failed to release token after request error, %w", err)
141		return out, attemptResult, err
142	}
143
144	if err == nil {
145		return out, attemptResult, err
146	}
147
148	retryable := r.retryer.IsErrorRetryable(err)
149	if !retryable {
150		r.logf(logger, logging.Debug, "request failed with unretryable error %v", err)
151		return out, attemptResult, err
152	}
153
154	// set retryable to true
155	attemptResult.Retryable = true
156
157	if maxAttempts > 0 && attemptNum >= maxAttempts {
158		r.logf(logger, logging.Debug, "max retry attempts exhausted, max %d", maxAttempts)
159		err = &MaxAttemptsError{
160			Attempt: attemptNum,
161			Err:     err,
162		}
163		return out, attemptResult, err
164	}
165
166	relRetryToken, reqErr := r.retryer.GetRetryToken(ctx, err)
167	if reqErr != nil {
168		return out, attemptResult, reqErr
169	}
170
171	retryDelay, reqErr := r.retryer.RetryDelay(attemptNum, err)
172	if reqErr != nil {
173		return out, attemptResult, reqErr
174	}
175
176	if reqErr = sdk.SleepWithContext(ctx, retryDelay); reqErr != nil {
177		err = &aws.RequestCanceledError{Err: reqErr}
178		return out, attemptResult, err
179	}
180
181	attemptResult.Retried = true
182
183	return out, attemptResult, err
184}
185
186// MetricsHeader attaches SDK request metric header for retries to the transport
187type MetricsHeader struct{}
188
189// ID returns the middleware identifier
190func (r *MetricsHeader) ID() string {
191	return "RetryMetricsHeader"
192}
193
194// HandleFinalize attaches the sdk request metric header to the transport layer
195func (r MetricsHeader) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
196	out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error,
197) {
198	retryMetadata, _ := getRetryMetadata(ctx)
199
200	const retryMetricHeader = "Amz-Sdk-Request"
201	var parts []string
202
203	parts = append(parts, "attempt="+strconv.Itoa(retryMetadata.AttemptNum))
204	if retryMetadata.MaxAttempts != 0 {
205		parts = append(parts, "max="+strconv.Itoa(retryMetadata.MaxAttempts))
206	}
207
208	var ttl time.Time
209	if deadline, ok := ctx.Deadline(); ok {
210		ttl = deadline
211	}
212
213	// Only append the TTL if it can be determined.
214	if !ttl.IsZero() && retryMetadata.AttemptClockSkew > 0 {
215		const unixTimeFormat = "20060102T150405Z"
216		ttl = ttl.Add(retryMetadata.AttemptClockSkew)
217		parts = append(parts, "ttl="+ttl.Format(unixTimeFormat))
218	}
219
220	switch req := in.Request.(type) {
221	case *http.Request:
222		req.Header[retryMetricHeader] = append(req.Header[retryMetricHeader][:0], strings.Join(parts, "; "))
223	default:
224		return out, metadata, fmt.Errorf("unknown transport type %T", req)
225	}
226
227	return next.HandleFinalize(ctx, in)
228}
229
230type retryMetadataKey struct{}
231
232// getRetryMetadata retrieves retryMetadata from the context and a bool
233// indicating if it was set.
234//
235// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
236// to clear all stack values.
237func getRetryMetadata(ctx context.Context) (metadata retryMetadata, ok bool) {
238	metadata, ok = middleware.GetStackValue(ctx, retryMetadataKey{}).(retryMetadata)
239	return metadata, ok
240}
241
242// setRetryMetadata sets the retryMetadata on the context.
243//
244// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
245// to clear all stack values.
246func setRetryMetadata(ctx context.Context, metadata retryMetadata) context.Context {
247	return middleware.WithStackValue(ctx, retryMetadataKey{}, metadata)
248}
249
250// AddRetryMiddlewaresOptions is the set of options that can be passed to AddRetryMiddlewares for configuring retry
251// associated middleware.
252type AddRetryMiddlewaresOptions struct {
253	Retryer aws.Retryer
254
255	// Enable the logging of retry attempts performed by the SDK.
256	// This will include logging retry attempts, unretryable errors, and when max attempts are reached.
257	LogRetryAttempts bool
258}
259
260// AddRetryMiddlewares adds retry middleware to operation middleware stack
261func AddRetryMiddlewares(stack *smithymiddle.Stack, options AddRetryMiddlewaresOptions) error {
262	attempt := NewAttemptMiddleware(options.Retryer, http.RequestCloner, func(middleware *Attempt) {
263		middleware.LogAttempts = options.LogRetryAttempts
264	})
265
266	if err := stack.Finalize.Add(attempt, smithymiddle.After); err != nil {
267		return err
268	}
269	if err := stack.Finalize.Add(&MetricsHeader{}, smithymiddle.After); err != nil {
270		return err
271	}
272	return nil
273}
274