1package s3shared
2
3import (
4	"context"
5
6	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
7	awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
8	"github.com/aws/smithy-go/middleware"
9	smithyhttp "github.com/aws/smithy-go/transport/http"
10)
11
12// AddResponseErrorMiddleware adds response error wrapper middleware
13func AddResponseErrorMiddleware(stack *middleware.Stack) error {
14	// add error wrapper middleware before request id retriever middleware so that it can wrap the error response
15	// returned by operation deserializers
16	return stack.Deserialize.Insert(&errorWrapper{}, metadataRetrieverID, middleware.Before)
17}
18
// errorWrapper is a stateless deserialize middleware that wraps operation
// errors in an S3 ResponseError.
type errorWrapper struct{}

// ID returns the identifier under which this middleware is registered.
func (m *errorWrapper) ID() string {
	const wrapperID = "ResponseErrorWrapper"
	return wrapperID
}
26
27func (m *errorWrapper) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
28	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
29) {
30	out, metadata, err = next.HandleDeserialize(ctx, in)
31	if err == nil {
32		// Nothing to do when there is no error.
33		return out, metadata, err
34	}
35
36	resp, ok := out.RawResponse.(*smithyhttp.Response)
37	if !ok {
38		// No raw response to wrap with.
39		return out, metadata, err
40	}
41
42	// look for request id in metadata
43	reqID, _ := awsmiddleware.GetRequestIDMetadata(metadata)
44	// look for host id in metadata
45	hostID, _ := GetHostIDMetadata(metadata)
46
47	// Wrap the returned smithy error with the request id retrieved from the metadata
48	err = &ResponseError{
49		ResponseError: &awshttp.ResponseError{
50			ResponseError: &smithyhttp.ResponseError{
51				Response: resp,
52				Err:      err,
53			},
54			RequestID: reqID,
55		},
56		HostID: hostID,
57	}
58
59	return out, metadata, err
60}
61