1package v4
2
3import (
4	"context"
5	"crypto/sha256"
6	"encoding/hex"
7	"fmt"
8	"io"
9
10	"github.com/aws/aws-sdk-go-v2/aws"
11	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
12	v4Internal "github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4"
13	"github.com/aws/aws-sdk-go-v2/internal/sdk"
14	"github.com/aws/smithy-go/middleware"
15	smithyHTTP "github.com/aws/smithy-go/transport/http"
16)
17
// computePayloadHashMiddlewareID is the Build step identifier shared by the
// computePayloadSHA256 and unsignedPayload middleware. Because both report
// the same ID, one can be swapped for the other in the stack, and other
// middleware can be positioned relative to whichever one is installed.
const (
	computePayloadHashMiddlewareID = "ComputePayloadHash"
)
19
20// HashComputationError indicates an error occurred while computing the signing hash
21type HashComputationError struct {
22	Err error
23}
24
25// Error is the error message
26func (e *HashComputationError) Error() string {
27	return fmt.Sprintf("failed to compute payload hash: %v", e.Err)
28}
29
30// Unwrap returns the underlying error if one is set
31func (e *HashComputationError) Unwrap() error {
32	return e.Err
33}
34
35// SigningError indicates an error condition occurred while performing SigV4 signing
36type SigningError struct {
37	Err error
38}
39
40func (e *SigningError) Error() string {
41	return fmt.Sprintf("failed to sign request: %v", e.Err)
42}
43
44// Unwrap returns the underlying error cause
45func (e *SigningError) Unwrap() error {
46	return e.Err
47}
48
49// unsignedPayload sets the SigV4 request payload hash to unsigned.
50//
51// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
52// stored in the context. (e.g. application pre-computed SHA256 before making
53// API call).
54//
55// This middleware does not check the X-Amz-Content-Sha256 header, if that
56// header is serialized a middleware must translate it into the context.
57type unsignedPayload struct{}
58
59// AddUnsignedPayloadMiddleware adds unsignedPayload to the operation
60// middleware stack
61func AddUnsignedPayloadMiddleware(stack *middleware.Stack) error {
62	return stack.Build.Add(&unsignedPayload{}, middleware.After)
63}
64
65// ID returns the unsignedPayload identifier
66func (m *unsignedPayload) ID() string {
67	return computePayloadHashMiddlewareID
68}
69
70// HandleBuild sets the payload hash to be an unsigned payload
71func (m *unsignedPayload) HandleBuild(
72	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
73) (
74	out middleware.BuildOutput, metadata middleware.Metadata, err error,
75) {
76	// This should not compute the content SHA256 if the value is already
77	// known. (e.g. application pre-computed SHA256 before making API call).
78	// Does not have any tight coupling to the X-Amz-Content-Sha256 header, if
79	// that header is provided a middleware must translate it into the context.
80	contentSHA := GetPayloadHash(ctx)
81	if len(contentSHA) == 0 {
82		contentSHA = v4Internal.UnsignedPayload
83	}
84
85	ctx = SetPayloadHash(ctx, contentSHA)
86	return next.HandleBuild(ctx, in)
87}
88
89// computePayloadSHA256 computes SHA256 payload hash to sign.
90//
91// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
92// stored in the context. (e.g. application pre-computed SHA256 before making
93// API call).
94//
95// This middleware does not check the X-Amz-Content-Sha256 header, if that
96// header is serialized a middleware must translate it into the context.
97type computePayloadSHA256 struct{}
98
99// AddComputePayloadSHA256Middleware adds computePayloadSHA256 to the
100// operation middleware stack
101func AddComputePayloadSHA256Middleware(stack *middleware.Stack) error {
102	return stack.Build.Add(&computePayloadSHA256{}, middleware.After)
103}
104
105// RemoveComputePayloadSHA256Middleware removes computePayloadSHA256 from the
106// operation middleware stack
107func RemoveComputePayloadSHA256Middleware(stack *middleware.Stack) error {
108	_, err := stack.Build.Remove(computePayloadHashMiddlewareID)
109	return err
110}
111
112// ID is the middleware name
113func (m *computePayloadSHA256) ID() string {
114	return computePayloadHashMiddlewareID
115}
116
117// HandleBuild compute the payload hash for the request payload
118func (m *computePayloadSHA256) HandleBuild(
119	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
120) (
121	out middleware.BuildOutput, metadata middleware.Metadata, err error,
122) {
123	req, ok := in.Request.(*smithyHTTP.Request)
124	if !ok {
125		return out, metadata, &HashComputationError{
126			Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
127		}
128	}
129
130	// This should not compute the content SHA256 if the value is already
131	// known. (e.g. application pre-computed SHA256 before making API call)
132	// Does not have any tight coupling to the X-Amz-Content-Sha256 header, if
133	// that header is provided a middleware must translate it into the context.
134	if contentSHA := GetPayloadHash(ctx); len(contentSHA) != 0 {
135		return next.HandleBuild(ctx, in)
136	}
137
138	hash := sha256.New()
139	if stream := req.GetStream(); stream != nil {
140		_, err = io.Copy(hash, stream)
141		if err != nil {
142			return out, metadata, &HashComputationError{
143				Err: fmt.Errorf("failed to compute payload hash, %w", err),
144			}
145		}
146
147		if err := req.RewindStream(); err != nil {
148			return out, metadata, &HashComputationError{
149				Err: fmt.Errorf("failed to seek body to start, %w", err),
150			}
151		}
152	}
153
154	ctx = SetPayloadHash(ctx, hex.EncodeToString(hash.Sum(nil)))
155
156	return next.HandleBuild(ctx, in)
157}
158
159// SwapComputePayloadSHA256ForUnsignedPayloadMiddleware replaces the
160// ComputePayloadSHA256 middleware with the UnsignedPayload middleware.
161//
162// Use this to disable computing the Payload SHA256 checksum and instead use
163// UNSIGNED-PAYLOAD for the SHA256 value.
164func SwapComputePayloadSHA256ForUnsignedPayloadMiddleware(stack *middleware.Stack) error {
165	_, err := stack.Build.Swap(computePayloadHashMiddlewareID, &unsignedPayload{})
166	return err
167}
168
169// contentSHA256Header sets the X-Amz-Content-Sha256 header value to
170// the Payload hash stored in the context.
171type contentSHA256Header struct{}
172
173// AddContentSHA256HeaderMiddleware adds ContentSHA256Header to the
174// operation middleware stack
175func AddContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
176	return stack.Build.Insert(&contentSHA256Header{}, computePayloadHashMiddlewareID, middleware.After)
177}
178
179// RemoveContentSHA256HeaderMiddleware removes contentSHA256Header middleware
180// from the operation middleware stack
181func RemoveContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
182	_, err := stack.Build.Remove((*contentSHA256Header)(nil).ID())
183	return err
184}
185
186// ID returns the ContentSHA256HeaderMiddleware identifier
187func (m *contentSHA256Header) ID() string {
188	return "SigV4ContentSHA256Header"
189}
190
191// HandleBuild sets the X-Amz-Content-Sha256 header value to the Payload hash
192// stored in the context.
193func (m *contentSHA256Header) HandleBuild(
194	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
195) (
196	out middleware.BuildOutput, metadata middleware.Metadata, err error,
197) {
198	req, ok := in.Request.(*smithyHTTP.Request)
199	if !ok {
200		return out, metadata, &HashComputationError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
201	}
202
203	req.Header.Set(v4Internal.ContentSHAKey, GetPayloadHash(ctx))
204
205	return next.HandleBuild(ctx, in)
206}
207
// SignHTTPRequestMiddlewareOptions is the configuration options for the SignHTTPRequestMiddleware middleware.
//
// CredentialsProvider supplies the credentials used to sign requests. When it
// is nil or an aws.AnonymousCredentials value (or pointer), the signing
// middleware passes requests through unsigned.
//
// Signer performs the SigV4 signing of the HTTP request.
//
// LogSigning is forwarded to SignerOptions.LogSigning on every signing call.
type SignHTTPRequestMiddlewareOptions struct {
	CredentialsProvider aws.CredentialsProvider
	Signer              HTTPSigner
	LogSigning          bool
}
214
215// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation for SigV4 HTTP Signing
216type SignHTTPRequestMiddleware struct {
217	credentialsProvider aws.CredentialsProvider
218	signer              HTTPSigner
219	logSigning          bool
220}
221
222// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given Signer for signing requests
223func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
224	return &SignHTTPRequestMiddleware{
225		credentialsProvider: options.CredentialsProvider,
226		signer:              options.Signer,
227		logSigning:          options.LogSigning,
228	}
229}
230
231// ID is the SignHTTPRequestMiddleware identifier
232func (s *SignHTTPRequestMiddleware) ID() string {
233	return "Signing"
234}
235
236// HandleFinalize will take the provided input and sign the request using the SigV4 authentication scheme
237func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
238	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
239) {
240	if !haveCredentialProvider(s.credentialsProvider) {
241		return next.HandleFinalize(ctx, in)
242	}
243
244	req, ok := in.Request.(*smithyHTTP.Request)
245	if !ok {
246		return out, metadata, &SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
247	}
248
249	signingName, signingRegion := awsmiddleware.GetSigningName(ctx), awsmiddleware.GetSigningRegion(ctx)
250	payloadHash := GetPayloadHash(ctx)
251	if len(payloadHash) == 0 {
252		return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
253	}
254
255	credentials, err := s.credentialsProvider.Retrieve(ctx)
256	if err != nil {
257		return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
258	}
259
260	err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(),
261		func(o *SignerOptions) {
262			o.Logger = middleware.GetLogger(ctx)
263			o.LogSigning = s.logSigning
264		})
265	if err != nil {
266		return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
267	}
268
269	return next.HandleFinalize(ctx, in)
270}
271
272func haveCredentialProvider(p aws.CredentialsProvider) bool {
273	if p == nil {
274		return false
275	}
276	switch p.(type) {
277	case aws.AnonymousCredentials,
278		*aws.AnonymousCredentials:
279		return false
280	}
281
282	return true
283}
284
285type payloadHashKey struct{}
286
287// GetPayloadHash retrieves the payload hash to use for signing
288//
289// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
290// to clear all stack values.
291func GetPayloadHash(ctx context.Context) (v string) {
292	v, _ = middleware.GetStackValue(ctx, payloadHashKey{}).(string)
293	return v
294}
295
296// SetPayloadHash sets the payload hash to be used for signing the request
297//
298// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
299// to clear all stack values.
300func SetPayloadHash(ctx context.Context, hash string) context.Context {
301	return middleware.WithStackValue(ctx, payloadHashKey{}, hash)
302}
303