1package v4
2
3import (
4	"context"
5	"crypto/sha256"
6	"encoding/hex"
7	"fmt"
8	"io"
9
10	"github.com/aws/aws-sdk-go-v2/aws"
11	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
12	v4Internal "github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4"
13	"github.com/aws/aws-sdk-go-v2/internal/sdk"
14	"github.com/aws/smithy-go/middleware"
15	smithyHTTP "github.com/aws/smithy-go/transport/http"
16)
17
const (
	// computePayloadHashMiddlewareID is the build-step identifier shared by
	// both unsignedPayload and computePayloadSHA256. Other middleware can be
	// positioned relative to it, and either one can be removed by it,
	// regardless of which of the two is registered.
	computePayloadHashMiddlewareID = "ComputePayloadHash"
)
19
20// HashComputationError indicates an error occurred while computing the signing hash
21type HashComputationError struct {
22	Err error
23}
24
25// Error is the error message
26func (e *HashComputationError) Error() string {
27	return fmt.Sprintf("failed to compute payload hash: %v", e.Err)
28}
29
30// Unwrap returns the underlying error if one is set
31func (e *HashComputationError) Unwrap() error {
32	return e.Err
33}
34
35// SigningError indicates an error condition occurred while performing SigV4 signing
36type SigningError struct {
37	Err error
38}
39
40func (e *SigningError) Error() string {
41	return fmt.Sprintf("failed to sign request: %v", e.Err)
42}
43
44// Unwrap returns the underlying error cause
45func (e *SigningError) Unwrap() error {
46	return e.Err
47}
48
49// unsignedPayload sets the SigV4 request payload hash to unsigned.
50//
51// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
52// stored in the context. (e.g. application pre-computed SHA256 before making
53// API call).
54//
55// This middleware does not check the X-Amz-Content-Sha256 header, if that
56// header is serialized a middleware must translate it into the context.
57type unsignedPayload struct{}
58
59// AddUnsignedPayloadMiddleware adds unsignedPayload to the operation
60// middleware stack
61func AddUnsignedPayloadMiddleware(stack *middleware.Stack) error {
62	return stack.Build.Add(&unsignedPayload{}, middleware.After)
63}
64
65// ID returns the unsignedPayload identifier
66func (m *unsignedPayload) ID() string {
67	return computePayloadHashMiddlewareID
68}
69
70// HandleBuild sets the payload hash to be an unsigned payload
71func (m *unsignedPayload) HandleBuild(
72	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
73) (
74	out middleware.BuildOutput, metadata middleware.Metadata, err error,
75) {
76	// This should not compute the content SHA256 if the value is already
77	// known. (e.g. application pre-computed SHA256 before making API call).
78	// Does not have any tight coupling to the X-Amz-Content-Sha256 header, if
79	// that header is provided a middleware must translate it into the context.
80	contentSHA := GetPayloadHash(ctx)
81	if len(contentSHA) == 0 {
82		contentSHA = v4Internal.UnsignedPayload
83	}
84
85	ctx = SetPayloadHash(ctx, contentSHA)
86	return next.HandleBuild(ctx, in)
87}
88
89// computePayloadSHA256 computes SHA256 payload hash to sign.
90//
91// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
92// stored in the context. (e.g. application pre-computed SHA256 before making
93// API call).
94//
95// This middleware does not check the X-Amz-Content-Sha256 header, if that
96// header is serialized a middleware must translate it into the context.
97type computePayloadSHA256 struct{}
98
99// AddComputePayloadSHA256Middleware adds computePayloadSHA256 to the
100// operation middleware stack
101func AddComputePayloadSHA256Middleware(stack *middleware.Stack) error {
102	return stack.Build.Add(&computePayloadSHA256{}, middleware.After)
103}
104
105// RemoveComputePayloadSHA256Middleware removes computePayloadSHA256 from the
106// operation middleware stack
107func RemoveComputePayloadSHA256Middleware(stack *middleware.Stack) error {
108	_, err := stack.Build.Remove(computePayloadHashMiddlewareID)
109	return err
110}
111
112// ID is the middleware name
113func (m *computePayloadSHA256) ID() string {
114	return computePayloadHashMiddlewareID
115}
116
117// HandleBuild compute the payload hash for the request payload
118func (m *computePayloadSHA256) HandleBuild(
119	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
120) (
121	out middleware.BuildOutput, metadata middleware.Metadata, err error,
122) {
123	req, ok := in.Request.(*smithyHTTP.Request)
124	if !ok {
125		return out, metadata, &HashComputationError{
126			Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
127		}
128	}
129
130	// This should not compute the content SHA256 if the value is already
131	// known. (e.g. application pre-computed SHA256 before making API call)
132	// Does not have any tight coupling to the X-Amz-Content-Sha256 header, if
133	// that header is provided a middleware must translate it into the context.
134	if contentSHA := GetPayloadHash(ctx); len(contentSHA) != 0 {
135		return next.HandleBuild(ctx, in)
136	}
137
138	hash := sha256.New()
139	if stream := req.GetStream(); stream != nil {
140		_, err = io.Copy(hash, stream)
141		if err != nil {
142			return out, metadata, &HashComputationError{
143				Err: fmt.Errorf("failed to compute payload hash, %w", err),
144			}
145		}
146
147		if err := req.RewindStream(); err != nil {
148			return out, metadata, &HashComputationError{
149				Err: fmt.Errorf("failed to seek body to start, %w", err),
150			}
151		}
152	}
153
154	ctx = SetPayloadHash(ctx, hex.EncodeToString(hash.Sum(nil)))
155
156	return next.HandleBuild(ctx, in)
157}
158
159// contentSHA256Header sets the X-Amz-Content-Sha256 header value to
160// the Payload hash stored in the context.
161type contentSHA256Header struct{}
162
163// AddContentSHA256HeaderMiddleware adds ContentSHA256Header to the
164// operation middleware stack
165func AddContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
166	return stack.Build.Insert(&contentSHA256Header{}, computePayloadHashMiddlewareID, middleware.After)
167}
168
169// RemoveContentSHA256HeaderMiddleware removes contentSHA256Header middleware
170// from the operation middleware stack
171func RemoveContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
172	_, err := stack.Build.Remove((*contentSHA256Header)(nil).ID())
173	return err
174}
175
176// ID returns the ContentSHA256HeaderMiddleware identifier
177func (m *contentSHA256Header) ID() string {
178	return "SigV4ContentSHA256Header"
179}
180
181// HandleBuild sets the X-Amz-Content-Sha256 header value to the Payload hash
182// stored in the context.
183func (m *contentSHA256Header) HandleBuild(
184	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
185) (
186	out middleware.BuildOutput, metadata middleware.Metadata, err error,
187) {
188	req, ok := in.Request.(*smithyHTTP.Request)
189	if !ok {
190		return out, metadata, &HashComputationError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
191	}
192
193	req.Header.Set(v4Internal.ContentSHAKey, GetPayloadHash(ctx))
194
195	return next.HandleBuild(ctx, in)
196}
197
198// SignHTTPRequestMiddlewareOptions is the configuration options for the SignHTTPRequestMiddleware middleware.
199type SignHTTPRequestMiddlewareOptions struct {
200	CredentialsProvider aws.CredentialsProvider
201	Signer              HTTPSigner
202	LogSigning          bool
203}
204
205// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation for SigV4 HTTP Signing
206type SignHTTPRequestMiddleware struct {
207	credentialsProvider aws.CredentialsProvider
208	signer              HTTPSigner
209	logSigning          bool
210}
211
212// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given Signer for signing requests
213func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
214	return &SignHTTPRequestMiddleware{
215		credentialsProvider: options.CredentialsProvider,
216		signer:              options.Signer,
217		logSigning:          options.LogSigning,
218	}
219}
220
221// ID is the SignHTTPRequestMiddleware identifier
222func (s *SignHTTPRequestMiddleware) ID() string {
223	return "Signing"
224}
225
// HandleFinalize signs the outgoing request with SigV4 and then calls the
// next finalize handler.
//
// If no usable credentials provider is configured (nil or anonymous, as
// decided by haveCredentialProvider), the request is forwarded unsigned.
// Otherwise the signing name, signing region and payload hash are read from
// the context, credentials are retrieved, and the signer is invoked. Any
// failure along the way is returned wrapped in a *SigningError.
func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
	if !haveCredentialProvider(s.credentialsProvider) {
		// Anonymous requests are passed through without a signature.
		return next.HandleFinalize(ctx, in)
	}

	req, ok := in.Request.(*smithyHTTP.Request)
	if !ok {
		return out, metadata, &SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
	}

	signingName := awsmiddleware.GetSigningName(ctx)
	signingRegion := awsmiddleware.GetSigningRegion(ctx)

	// The payload hash is expected to have been stored earlier, e.g. by one
	// of the payload hash build middleware.
	payloadHash := GetPayloadHash(ctx)
	if payloadHash == "" {
		return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
	}

	creds, retrieveErr := s.credentialsProvider.Retrieve(ctx)
	if retrieveErr != nil {
		return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", retrieveErr)}
	}

	withSignerLogging := func(o *SignerOptions) {
		o.Logger = middleware.GetLogger(ctx)
		o.LogSigning = s.logSigning
	}

	signErr := s.signer.SignHTTP(ctx, creds, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(), withSignerLogging)
	if signErr != nil {
		return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", signErr)}
	}

	return next.HandleFinalize(ctx, in)
}
261
262func haveCredentialProvider(p aws.CredentialsProvider) bool {
263	if p == nil {
264		return false
265	}
266	switch p.(type) {
267	case aws.AnonymousCredentials,
268		*aws.AnonymousCredentials:
269		return false
270	}
271
272	return true
273}
274
275type payloadHashKey struct{}
276
277// GetPayloadHash retrieves the payload hash to use for signing
278//
279// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
280// to clear all stack values.
281func GetPayloadHash(ctx context.Context) (v string) {
282	v, _ = middleware.GetStackValue(ctx, payloadHashKey{}).(string)
283	return v
284}
285
286// SetPayloadHash sets the payload hash to be used for signing the request
287//
288// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
289// to clear all stack values.
290func SetPayloadHash(ctx context.Context, hash string) context.Context {
291	return middleware.WithStackValue(ctx, payloadHashKey{}, hash)
292}
293