1package customizations
2
3import (
4	"bytes"
5	"context"
6	"encoding/xml"
7	"fmt"
8	"io"
9	"io/ioutil"
10	"strings"
11
12	"github.com/aws/smithy-go"
13	smithyxml "github.com/aws/smithy-go/encoding/xml"
14	"github.com/aws/smithy-go/middleware"
15	smithyhttp "github.com/aws/smithy-go/transport/http"
16)
17
18// HandleResponseErrorWith200Status check for S3 200 error response.
19// If an s3 200 error is found, status code for the response is modified temporarily to
20// 5xx response status code.
21func HandleResponseErrorWith200Status(stack *middleware.Stack) error {
22	return stack.Deserialize.Insert(&processResponseFor200ErrorMiddleware{}, "OperationDeserializer", middleware.After)
23}
24
// processResponseFor200ErrorMiddleware is a deserialize middleware that
// examines the raw response for an error payload carried by a 2xx status code.
// It holds no state.
type processResponseFor200ErrorMiddleware struct{}

// ID returns the identifier under which this middleware is registered.
func (m *processResponseFor200ErrorMiddleware) ID() string {
	const id = "S3:ProcessResponseFor200Error"
	return id
}
32
33func (m *processResponseFor200ErrorMiddleware) HandleDeserialize(
34	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
35	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
36) {
37	out, metadata, err = next.HandleDeserialize(ctx, in)
38	if err != nil {
39		return out, metadata, err
40	}
41
42	response, ok := out.RawResponse.(*smithyhttp.Response)
43	if !ok {
44		return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)}
45	}
46
47	// check if response status code is 2xx.
48	if response.StatusCode < 200 || response.StatusCode >= 300 {
49		return
50	}
51
52	var readBuff bytes.Buffer
53	body := io.TeeReader(response.Body, &readBuff)
54
55	rootDecoder := xml.NewDecoder(body)
56	t, err := smithyxml.FetchRootElement(rootDecoder)
57	if err == io.EOF {
58		return out, metadata, &smithy.DeserializationError{
59			Err: fmt.Errorf("received empty response payload"),
60		}
61	}
62
63	// rewind response body
64	response.Body = ioutil.NopCloser(io.MultiReader(&readBuff, response.Body))
65
66	// if start tag is "Error", the response is consider error response.
67	if strings.EqualFold(t.Name.Local, "Error") {
68		// according to https://aws.amazon.com/premiumsupport/knowledge-center/s3-resolve-200-internalerror/
69		// 200 error responses are similar to 5xx errors.
70		response.StatusCode = 500
71	}
72
73	return out, metadata, err
74}
75