1package retry
2
3import (
4	"context"
5	"fmt"
6	"strconv"
7	"strings"
8	"time"
9
10	"github.com/aws/aws-sdk-go-v2/aws"
11	awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware"
12	"github.com/aws/aws-sdk-go-v2/internal/sdk"
13	"github.com/aws/smithy-go/logging"
14	"github.com/aws/smithy-go/middleware"
15	smithymiddle "github.com/aws/smithy-go/middleware"
16	"github.com/aws/smithy-go/transport/http"
17)
18
// RequestCloner copies a transport request so that every retry attempt
// starts from its own copy of the original request rather than one mutated
// by an earlier attempt.
type RequestCloner func(request interface{}) (clone interface{})
22
// retryMetadata describes the attempt currently being made by the Attempt
// middleware. HandleFinalize stores a fresh value on the context for every
// attempt (see setRetryMetadata); handleAttempt and MetricsHeader read it back
// with getRetryMetadata.
type retryMetadata struct {
	AttemptNum       int           // 1-based number of the current attempt
	AttemptTime      time.Time     // when the attempt started, in UTC
	MaxAttempts      int           // Retryer.MaxAttempts(); values <= 0 disable the max-attempts check
	AttemptClockSkew time.Duration // skew reported by the previous attempt's response metadata; 0 if unavailable
}
29
// Attempt is a Smithy FinalizeMiddleware that handles retry attempts using the provided
// Retryer implementation. Construct it with NewAttemptMiddleware.
type Attempt struct {
	// Enable the logging of retry attempts performed by the SDK.
	// This will include logging retry attempts, unretryable errors, and when max attempts are reached.
	LogAttempts bool

	retryer       aws.Retryer   // decides retryability, retry delays, max attempts, and hands out tokens
	requestCloner RequestCloner // copies the incoming request for each attempt
}
40
41// NewAttemptMiddleware returns a new Attempt retry middleware.
42func NewAttemptMiddleware(retryer aws.Retryer, requestCloner RequestCloner, optFns ...func(*Attempt)) *Attempt {
43	m := &Attempt{retryer: retryer, requestCloner: requestCloner}
44	for _, fn := range optFns {
45		fn(m)
46	}
47	return m
48}
49
50// ID returns the middleware identifier
51func (r *Attempt) ID() string {
52	return "Retry"
53}
54
55func (r Attempt) logf(logger logging.Logger, classification logging.Classification, format string, v ...interface{}) {
56	if !r.LogAttempts {
57		return
58	}
59	logger.Logf(classification, format, v...)
60}
61
62// HandleFinalize utilizes the provider Retryer implementation to attempt retries over the next handler
63func (r Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
64	out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error,
65) {
66	var attemptNum int
67	var attemptClockSkew time.Duration
68	var attemptResults AttemptResults
69
70	maxAttempts := r.retryer.MaxAttempts()
71
72	for {
73		attemptNum++
74		attemptInput := in
75		attemptInput.Request = r.requestCloner(attemptInput.Request)
76
77		attemptCtx := setRetryMetadata(ctx, retryMetadata{
78			AttemptNum:       attemptNum,
79			AttemptTime:      sdk.NowTime().UTC(),
80			MaxAttempts:      maxAttempts,
81			AttemptClockSkew: attemptClockSkew,
82		})
83
84		var attemptResult AttemptResult
85
86		out, attemptResult, err = r.handleAttempt(attemptCtx, attemptInput, next)
87
88		var ok bool
89		attemptClockSkew, ok = awsmiddle.GetAttemptSkew(attemptResult.ResponseMetadata)
90		if !ok {
91			attemptClockSkew = 0
92		}
93
94		shouldRetry := attemptResult.Retried
95
96		// add attempt metadata to list of all attempt metadata
97		attemptResults.Results = append(attemptResults.Results, attemptResult)
98
99		if !shouldRetry {
100			// Ensure the last response's metadata is used as the bases for result
101			// metadata returned by the stack.
102			metadata = attemptResult.ResponseMetadata.Clone()
103
104			break
105		}
106	}
107
108	addAttemptResults(&metadata, attemptResults)
109	return out, metadata, err
110}
111
112// handleAttempt handles an individual request attempt.
113func (r Attempt) handleAttempt(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
114	out smithymiddle.FinalizeOutput, attemptResult AttemptResult, err error,
115) {
116	defer func() {
117		attemptResult.Err = err
118	}()
119
120	relRetryToken := r.retryer.GetInitialToken()
121	logger := smithymiddle.GetLogger(ctx)
122	service, operation := awsmiddle.GetServiceID(ctx), awsmiddle.GetOperationName(ctx)
123
124	retryMetadata, _ := getRetryMetadata(ctx)
125	attemptNum := retryMetadata.AttemptNum
126	maxAttempts := retryMetadata.MaxAttempts
127
128	if attemptNum > 1 {
129		if rewindable, ok := in.Request.(interface{ RewindStream() error }); ok {
130			if rewindErr := rewindable.RewindStream(); rewindErr != nil {
131				err = fmt.Errorf("failed to rewind transport stream for retry, %w", rewindErr)
132				return out, attemptResult, err
133			}
134		}
135
136		r.logf(logger, logging.Debug, "retrying request %s/%s, attempt %d", service, operation, attemptNum)
137	}
138
139	var metadata smithymiddle.Metadata
140	out, metadata, err = next.HandleFinalize(ctx, in)
141	attemptResult.ResponseMetadata = metadata
142
143	if releaseError := relRetryToken(err); releaseError != nil && err != nil {
144		err = fmt.Errorf("failed to release token after request error, %w", err)
145		return out, attemptResult, err
146	}
147
148	if err == nil {
149		return out, attemptResult, err
150	}
151
152	retryable := r.retryer.IsErrorRetryable(err)
153	if !retryable {
154		r.logf(logger, logging.Debug, "request failed with unretryable error %v", err)
155		return out, attemptResult, err
156	}
157
158	// set retryable to true
159	attemptResult.Retryable = true
160
161	if maxAttempts > 0 && attemptNum >= maxAttempts {
162		r.logf(logger, logging.Debug, "max retry attempts exhausted, max %d", maxAttempts)
163		err = &MaxAttemptsError{
164			Attempt: attemptNum,
165			Err:     err,
166		}
167		return out, attemptResult, err
168	}
169
170	relRetryToken, reqErr := r.retryer.GetRetryToken(ctx, err)
171	if reqErr != nil {
172		return out, attemptResult, reqErr
173	}
174
175	retryDelay, reqErr := r.retryer.RetryDelay(attemptNum, err)
176	if reqErr != nil {
177		return out, attemptResult, reqErr
178	}
179
180	if reqErr = sdk.SleepWithContext(ctx, retryDelay); reqErr != nil {
181		err = &aws.RequestCanceledError{Err: reqErr}
182		return out, attemptResult, err
183	}
184
185	attemptResult.Retried = true
186
187	return out, attemptResult, err
188}
189
// MetricsHeader attaches SDK request metric header for retries to the transport.
// It has no state of its own: the header value is built from the retryMetadata
// that the Attempt middleware stores on the context for each attempt.
type MetricsHeader struct{}
192
193// ID returns the middleware identifier
194func (r *MetricsHeader) ID() string {
195	return "RetryMetricsHeader"
196}
197
198// HandleFinalize attaches the sdk request metric header to the transport layer
199func (r MetricsHeader) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
200	out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error,
201) {
202	retryMetadata, _ := getRetryMetadata(ctx)
203
204	const retryMetricHeader = "Amz-Sdk-Request"
205	var parts []string
206
207	parts = append(parts, "attempt="+strconv.Itoa(retryMetadata.AttemptNum))
208	if retryMetadata.MaxAttempts != 0 {
209		parts = append(parts, "max="+strconv.Itoa(retryMetadata.MaxAttempts))
210	}
211
212	var ttl time.Time
213	if deadline, ok := ctx.Deadline(); ok {
214		ttl = deadline
215	}
216
217	// Only append the TTL if it can be determined.
218	if !ttl.IsZero() && retryMetadata.AttemptClockSkew > 0 {
219		const unixTimeFormat = "20060102T150405Z"
220		ttl = ttl.Add(retryMetadata.AttemptClockSkew)
221		parts = append(parts, "ttl="+ttl.Format(unixTimeFormat))
222	}
223
224	switch req := in.Request.(type) {
225	case *http.Request:
226		req.Header[retryMetricHeader] = append(req.Header[retryMetricHeader][:0], strings.Join(parts, "; "))
227	default:
228		return out, metadata, fmt.Errorf("unknown transport type %T", req)
229	}
230
231	return next.HandleFinalize(ctx, in)
232}
233
// retryMetadataKey is the stack-value key under which retryMetadata is stored.
// Being an unexported type, it cannot collide with keys defined in other packages.
type retryMetadataKey struct{}
235
236// getRetryMetadata retrieves retryMetadata from the context and a bool
237// indicating if it was set.
238//
239// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
240// to clear all stack values.
241func getRetryMetadata(ctx context.Context) (metadata retryMetadata, ok bool) {
242	metadata, ok = middleware.GetStackValue(ctx, retryMetadataKey{}).(retryMetadata)
243	return metadata, ok
244}
245
246// setRetryMetadata sets the retryMetadata on the context.
247//
248// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
249// to clear all stack values.
250func setRetryMetadata(ctx context.Context, metadata retryMetadata) context.Context {
251	return middleware.WithStackValue(ctx, retryMetadataKey{}, metadata)
252}
253
// AddRetryMiddlewaresOptions is the set of options that can be passed to AddRetryMiddlewares for configuring
// retry-associated middleware.
type AddRetryMiddlewaresOptions struct {
	// Retryer is passed to the Attempt middleware. It decides which errors are
	// retried, the delay between attempts, and the maximum number of attempts.
	Retryer aws.Retryer

	// Enable the logging of retry attempts performed by the SDK.
	// This will include logging retry attempts, unretryable errors, and when max attempts are reached.
	LogRetryAttempts bool
}
263
264// AddRetryMiddlewares adds retry middleware to operation middleware stack
265func AddRetryMiddlewares(stack *smithymiddle.Stack, options AddRetryMiddlewaresOptions) error {
266	attempt := NewAttemptMiddleware(options.Retryer, http.RequestCloner, func(middleware *Attempt) {
267		middleware.LogAttempts = options.LogRetryAttempts
268	})
269
270	if err := stack.Finalize.Add(attempt, smithymiddle.After); err != nil {
271		return err
272	}
273	if err := stack.Finalize.Add(&MetricsHeader{}, smithymiddle.After); err != nil {
274		return err
275	}
276	return nil
277}
278