1package s3
2
3import (
4	"bytes"
5	"crypto/md5"
6	"crypto/sha256"
7	"encoding/base64"
8	"encoding/hex"
9	"fmt"
10	"hash"
11	"io"
12
13	"github.com/aws/aws-sdk-go/aws"
14	"github.com/aws/aws-sdk-go/aws/awserr"
15	"github.com/aws/aws-sdk-go/aws/request"
16)
17
// HTTP header names read or written by the body hashing and MD5 validation
// handlers in this file, in their canonical MIME header form.
const (
	contentMD5Header    = "Content-Md5"
	contentSha256Header = "X-Amz-Content-Sha256"
	amzTeHeader         = "X-Amz-Te"
	amzTxEncodingHeader = "X-Amz-Transfer-Encoding"
)

// appendMD5TxEncoding is the X-Amz-Te / X-Amz-Transfer-Encoding value that
// requests, and signals, a trailing MD5 checksum on the response body.
const appendMD5TxEncoding = "append-md5"
26
27// computeBodyHashes will add Content MD5 and Content Sha256 hashes to the
28// request. If the body is not seekable or S3DisableContentMD5Validation set
29// this handler will be ignored.
30func computeBodyHashes(r *request.Request) {
31	if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
32		return
33	}
34	if r.IsPresigned() {
35		return
36	}
37	if r.Error != nil || !aws.IsReaderSeekable(r.Body) {
38		return
39	}
40
41	var md5Hash, sha256Hash hash.Hash
42	hashers := make([]io.Writer, 0, 2)
43
44	// Determine upfront which hashes can be set without overriding user
45	// provide header data.
46	if v := r.HTTPRequest.Header.Get(contentMD5Header); len(v) == 0 {
47		md5Hash = md5.New()
48		hashers = append(hashers, md5Hash)
49	}
50
51	if v := r.HTTPRequest.Header.Get(contentSha256Header); len(v) == 0 {
52		sha256Hash = sha256.New()
53		hashers = append(hashers, sha256Hash)
54	}
55
56	// Create the destination writer based on the hashes that are not already
57	// provided by the user.
58	var dst io.Writer
59	switch len(hashers) {
60	case 0:
61		return
62	case 1:
63		dst = hashers[0]
64	default:
65		dst = io.MultiWriter(hashers...)
66	}
67
68	if _, err := aws.CopySeekableBody(dst, r.Body); err != nil {
69		r.Error = awserr.New("BodyHashError", "failed to compute body hashes", err)
70		return
71	}
72
73	// For the hashes created, set the associated headers that the user did not
74	// already provide.
75	if md5Hash != nil {
76		sum := make([]byte, md5.Size)
77		encoded := make([]byte, md5Base64EncLen)
78
79		base64.StdEncoding.Encode(encoded, md5Hash.Sum(sum[0:0]))
80		r.HTTPRequest.Header[contentMD5Header] = []string{string(encoded)}
81	}
82
83	if sha256Hash != nil {
84		encoded := make([]byte, sha256HexEncLen)
85		sum := make([]byte, sha256.Size)
86
87		hex.Encode(encoded, sha256Hash.Sum(sum[0:0]))
88		r.HTTPRequest.Header[contentSha256Header] = []string{string(encoded)}
89	}
90}
91
// Encoded lengths of the digests: padded standard base64 of an MD5 sum
// (24 bytes) and hex of a SHA-256 sum (64 bytes).
const (
	md5Base64EncLen = 4 * ((md5.Size + 2) / 3) // same as base64.StdEncoding.EncodedLen(md5.Size)
	sha256HexEncLen = 2 * sha256.Size          // same as hex.EncodedLen(sha256.Size)
)
96
97// Adds the x-amz-te: append_md5 header to the request. This requests the service
98// responds with a trailing MD5 checksum.
99//
100// Will not ask for append MD5 if disabled, the request is presigned or,
101// or the API operation does not support content MD5 validation.
102func askForTxEncodingAppendMD5(r *request.Request) {
103	if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
104		return
105	}
106	if r.IsPresigned() {
107		return
108	}
109	r.HTTPRequest.Header.Set(amzTeHeader, appendMD5TxEncoding)
110}
111
112func useMD5ValidationReader(r *request.Request) {
113	if r.Error != nil {
114		return
115	}
116
117	if v := r.HTTPResponse.Header.Get(amzTxEncodingHeader); v != appendMD5TxEncoding {
118		return
119	}
120
121	var bodyReader *io.ReadCloser
122	var contentLen int64
123	switch tv := r.Data.(type) {
124	case *GetObjectOutput:
125		bodyReader = &tv.Body
126		contentLen = aws.Int64Value(tv.ContentLength)
127		// Update ContentLength hiden the trailing MD5 checksum.
128		tv.ContentLength = aws.Int64(contentLen - md5.Size)
129		tv.ContentRange = aws.String(r.HTTPResponse.Header.Get("X-Amz-Content-Range"))
130	default:
131		r.Error = awserr.New("ChecksumValidationError",
132			fmt.Sprintf("%s: %s header received on unsupported API, %s",
133				amzTxEncodingHeader, appendMD5TxEncoding, r.Operation.Name,
134			), nil)
135		return
136	}
137
138	if contentLen < md5.Size {
139		r.Error = awserr.New("ChecksumValidationError",
140			fmt.Sprintf("invalid Content-Length %d for %s %s",
141				contentLen, appendMD5TxEncoding, amzTxEncodingHeader,
142			), nil)
143		return
144	}
145
146	// Wrap and swap the response body reader with the validation reader.
147	*bodyReader = newMD5ValidationReader(*bodyReader, contentLen-md5.Size)
148}
149
150type md5ValidationReader struct {
151	rawReader io.ReadCloser
152	payload   io.Reader
153	hash      hash.Hash
154
155	payloadLen int64
156	read       int64
157}
158
159func newMD5ValidationReader(reader io.ReadCloser, payloadLen int64) *md5ValidationReader {
160	h := md5.New()
161	return &md5ValidationReader{
162		rawReader:  reader,
163		payload:    io.TeeReader(&io.LimitedReader{R: reader, N: payloadLen}, h),
164		hash:       h,
165		payloadLen: payloadLen,
166	}
167}
168
169func (v *md5ValidationReader) Read(p []byte) (n int, err error) {
170	n, err = v.payload.Read(p)
171	if err != nil && err != io.EOF {
172		return n, err
173	}
174
175	v.read += int64(n)
176
177	if err == io.EOF {
178		if v.read != v.payloadLen {
179			return n, io.ErrUnexpectedEOF
180		}
181		expectSum := make([]byte, md5.Size)
182		actualSum := make([]byte, md5.Size)
183		if _, sumReadErr := io.ReadFull(v.rawReader, expectSum); sumReadErr != nil {
184			return n, sumReadErr
185		}
186		actualSum = v.hash.Sum(actualSum[0:0])
187		if !bytes.Equal(expectSum, actualSum) {
188			return n, awserr.New("InvalidChecksum",
189				fmt.Sprintf("expected MD5 checksum %s, got %s",
190					hex.EncodeToString(expectSum),
191					hex.EncodeToString(actualSum),
192				),
193				nil)
194		}
195	}
196
197	return n, err
198}
199
200func (v *md5ValidationReader) Close() error {
201	return v.rawReader.Close()
202}
203