1package manager_test
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"net/http"
10	"net/http/httptest"
11	"os"
12	"reflect"
13	"regexp"
14	"sort"
15	"strconv"
16	"strings"
17	"testing"
18
19	"github.com/aws/aws-sdk-go-v2/aws"
20	"github.com/aws/aws-sdk-go-v2/aws/retry"
21	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
22	s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
23	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
24	"github.com/aws/aws-sdk-go-v2/internal/sdk"
25	"github.com/aws/aws-sdk-go-v2/service/s3"
26	"github.com/aws/aws-sdk-go-v2/service/s3/types"
27	"github.com/google/go-cmp/cmp"
28)
29
// getReaderLength drains r completely and reports how many bytes it yielded.
// Any read error is deliberately ignored; the count of bytes consumed before
// the error is still returned.
func getReaderLength(r io.Reader) (length int64) {
	length, _ = io.Copy(ioutil.Discard, r)
	return length
}
35
36func TestUploadOrderMulti(t *testing.T) {
37	c, invocations, args := s3testing.NewUploadLoggingClient(nil)
38	u := manager.NewUploader(c)
39
40	resp, err := u.Upload(context.Background(), &s3.PutObjectInput{
41		Bucket:               aws.String("Bucket"),
42		Key:                  aws.String("Key - value"),
43		Body:                 bytes.NewReader(buf12MB),
44		ServerSideEncryption: "aws:kms",
45		SSEKMSKeyId:          aws.String("KmsId"),
46		ContentType:          aws.String("content/type"),
47	})
48
49	if err != nil {
50		t.Errorf("Expected no error but received %v", err)
51	}
52
53	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
54		"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
55		t.Error(err)
56	}
57
58	if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
59		t.Errorf("expect %q, got %q", e, a)
60	}
61
62	if "UPLOAD-ID" != resp.UploadID {
63		t.Errorf("expect %q, got %q", "UPLOAD-ID", resp.UploadID)
64	}
65
66	if "VERSION-ID" != *resp.VersionID {
67		t.Errorf("expect %q, got %q", "VERSION-ID", *resp.VersionID)
68	}
69
70	// Validate input values
71
72	// UploadPart
73	for i := 1; i < 4; i++ {
74		v := aws.ToString((*args)[i].(*s3.UploadPartInput).UploadId)
75		if "UPLOAD-ID" != v {
76			t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
77		}
78	}
79
80	// CompleteMultipartUpload
81	v := aws.ToString((*args)[4].(*s3.CompleteMultipartUploadInput).UploadId)
82	if "UPLOAD-ID" != v {
83		t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
84	}
85
86	parts := (*args)[4].(*s3.CompleteMultipartUploadInput).MultipartUpload.Parts
87
88	for i := 0; i < 3; i++ {
89		num := parts[i].PartNumber
90		etag := aws.ToString(parts[i].ETag)
91
92		if int32(i+1) != num {
93			t.Errorf("expect %d, got %d", i+1, num)
94		}
95
96		if matched, err := regexp.MatchString(`^ETAG\d+$`, etag); !matched || err != nil {
97			t.Errorf("Failed regexp expression `^ETAG\\d+$`")
98		}
99	}
100
101	// Custom headers
102	cmu := (*args)[0].(*s3.CreateMultipartUploadInput)
103
104	if e, a := types.ServerSideEncryption("aws:kms"), cmu.ServerSideEncryption; e != a {
105		t.Errorf("expect %q, got %q", e, a)
106	}
107
108	if e, a := "KmsId", aws.ToString(cmu.SSEKMSKeyId); e != a {
109		t.Errorf("expect %q, got %q", e, a)
110	}
111
112	if e, a := "content/type", aws.ToString(cmu.ContentType); e != a {
113		t.Errorf("expect %q, got %q", e, a)
114	}
115}
116
117func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
118	s, ops, args := s3testing.NewUploadLoggingClient(nil)
119	mgr := manager.NewUploader(s, func(u *manager.Uploader) {
120		u.PartSize = 1024 * 1024 * 7
121		u.Concurrency = 1
122	})
123	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
124		Bucket: aws.String("Bucket"),
125		Key:    aws.String("Key"),
126		Body:   bytes.NewReader(buf12MB),
127	})
128
129	if err != nil {
130		t.Errorf("expect no error, got %v", err)
131	}
132
133	vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}
134	if !reflect.DeepEqual(vals, *ops) {
135		t.Errorf("expect %v, got %v", vals, *ops)
136	}
137
138	// Part lengths
139	if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); 1024*1024*7 != len {
140		t.Errorf("expect %d, got %d", 1024*1024*7, len)
141	}
142	if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); 1024*1024*5 != len {
143		t.Errorf("expect %d, got %d", 1024*1024*5, len)
144	}
145}
146
147func TestUploadIncreasePartSize(t *testing.T) {
148	s, invocations, args := s3testing.NewUploadLoggingClient(nil)
149	mgr := manager.NewUploader(s, func(u *manager.Uploader) {
150		u.Concurrency = 1
151		u.MaxUploadParts = 2
152	})
153	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
154		Bucket: aws.String("Bucket"),
155		Key:    aws.String("Key"),
156		Body:   bytes.NewReader(buf12MB),
157	})
158
159	if err != nil {
160		t.Errorf("expect no error, got %v", err)
161	}
162
163	if int64(manager.DefaultDownloadPartSize) != mgr.PartSize {
164		t.Errorf("expect %d, got %d", manager.DefaultDownloadPartSize, mgr.PartSize)
165	}
166
167	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
168		t.Error(diff)
169	}
170
171	// Part lengths
172	if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); (1024*1024*6)+1 != len {
173		t.Errorf("expect %d, got %d", (1024*1024*6)+1, len)
174	}
175
176	if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); (1024*1024*6)-1 != len {
177		t.Errorf("expect %d, got %d", (1024*1024*6)-1, len)
178	}
179}
180
181func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
182	mgr := manager.NewUploader(s3.New(s3.Options{}), func(u *manager.Uploader) {
183		u.PartSize = 5
184	})
185	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
186		Bucket: aws.String("Bucket"),
187		Key:    aws.String("Key"),
188		Body:   bytes.NewReader(buf12MB),
189	})
190
191	if resp != nil {
192		t.Errorf("Expected response to be nil, but received %v", resp)
193	}
194
195	if err == nil {
196		t.Errorf("Expected error, but received nil")
197	}
198
199	if e, a := "part size must be at least", err.Error(); !strings.Contains(a, e) {
200		t.Errorf("expect %v to be in %v", e, a)
201	}
202}
203
204func TestUploadOrderSingle(t *testing.T) {
205	client, invocations, params := s3testing.NewUploadLoggingClient(nil)
206	mgr := manager.NewUploader(client)
207	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
208		Bucket:               aws.String("Bucket"),
209		Key:                  aws.String("Key - value"),
210		Body:                 bytes.NewReader(buf2MB),
211		ServerSideEncryption: "aws:kms",
212		SSEKMSKeyId:          aws.String("KmsId"),
213		ContentType:          aws.String("content/type"),
214	})
215
216	if err != nil {
217		t.Errorf("expect no error but received %v", err)
218	}
219
220	if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
221		t.Error(diff)
222	}
223
224	if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
225		t.Errorf("expect %q, got %q", e, a)
226	}
227
228	if e := "VERSION-ID"; e != *resp.VersionID {
229		t.Errorf("expect %q, got %q", e, *resp.VersionID)
230	}
231
232	if len(resp.UploadID) > 0 {
233		t.Errorf("expect empty string, got %q", resp.UploadID)
234	}
235
236	putObjectInput := (*params)[0].(*s3.PutObjectInput)
237
238	if e, a := types.ServerSideEncryption("aws:kms"), putObjectInput.ServerSideEncryption; e != a {
239		t.Errorf("expect %q, got %q", e, a)
240	}
241
242	if e, a := "KmsId", aws.ToString(putObjectInput.SSEKMSKeyId); e != a {
243		t.Errorf("expect %q, got %q", e, a)
244	}
245
246	if e, a := "content/type", aws.ToString(putObjectInput.ContentType); e != a {
247		t.Errorf("Expected %q, but received %q", e, a)
248	}
249}
250
251func TestUploadOrderSingleFailure(t *testing.T) {
252	client, ops, _ := s3testing.NewUploadLoggingClient(nil)
253
254	client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
255		return nil, fmt.Errorf("put object failure")
256	}
257
258	mgr := manager.NewUploader(client)
259	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
260		Bucket: aws.String("Bucket"),
261		Key:    aws.String("Key"),
262		Body:   bytes.NewReader(buf2MB),
263	})
264
265	if err == nil {
266		t.Error("expect error, got nil")
267	}
268
269	if diff := cmp.Diff([]string{"PutObject"}, *ops); len(diff) > 0 {
270		t.Error(diff)
271	}
272
273	if resp != nil {
274		t.Errorf("expect response to be nil, got %v", resp)
275	}
276}
277
278func TestUploadOrderZero(t *testing.T) {
279	c, invocations, params := s3testing.NewUploadLoggingClient(nil)
280	mgr := manager.NewUploader(c)
281	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
282		Bucket: aws.String("Bucket"),
283		Key:    aws.String("Key"),
284		Body:   bytes.NewReader(make([]byte, 0)),
285	})
286
287	if err != nil {
288		t.Errorf("expect no error, got %v", err)
289	}
290
291	if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
292		t.Error(diff)
293	}
294
295	if len(resp.Location) == 0 {
296		t.Error("expect Location to not be empty")
297	}
298
299	if len(resp.UploadID) > 0 {
300		t.Errorf("expect empty string, got %q", resp.UploadID)
301	}
302
303	if e, a := int64(0), getReaderLength((*params)[0].(*s3.PutObjectInput).Body); e != a {
304		t.Errorf("Expected %d, but received %d", e, a)
305	}
306}
307
308func TestUploadOrderMultiFailure(t *testing.T) {
309	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
310
311	c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
312		if u.PartNum == 2 {
313			return nil, fmt.Errorf("an unexpected error")
314		}
315		return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
316	}
317
318	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
319		u.Concurrency = 1
320	})
321	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
322		Bucket: aws.String("Bucket"),
323		Key:    aws.String("Key"),
324		Body:   bytes.NewReader(buf12MB),
325	})
326
327	if err == nil {
328		t.Error("expect error, got nil")
329	}
330
331	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
332		t.Error(diff)
333	}
334}
335
336func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
337	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
338
339	c.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
340		return nil, fmt.Errorf("complete multipart error")
341	}
342
343	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
344		u.Concurrency = 1
345	})
346	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
347		Bucket: aws.String("Bucket"),
348		Key:    aws.String("Key"),
349		Body:   bytes.NewReader(buf12MB),
350	})
351
352	if err == nil {
353		t.Error("expect error, got nil")
354	}
355
356	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart",
357		"CompleteMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
358		t.Error(diff)
359	}
360}
361
362func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
363	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
364
365	c.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
366		return nil, fmt.Errorf("create multipart upload failure")
367	}
368
369	mgr := manager.NewUploader(c)
370	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
371		Bucket: aws.String("Bucket"),
372		Key:    aws.String("Key"),
373		Body:   bytes.NewReader(make([]byte, 1024*1024*12)),
374	})
375
376	if err == nil {
377		t.Error("expect error, got nil")
378	}
379
380	if diff := cmp.Diff([]string{"CreateMultipartUpload"}, *invocations); len(diff) > 0 {
381		t.Error(diff)
382	}
383}
384
385func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
386	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
387
388	c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
389		if u.PartNum == 2 {
390			return nil, fmt.Errorf("upload part failure")
391		}
392		return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
393	}
394
395	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
396		u.Concurrency = 1
397		u.LeavePartsOnError = true
398	})
399	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
400		Bucket: aws.String("Bucket"),
401		Key:    aws.String("Key"),
402		Body:   bytes.NewReader(make([]byte, 1024*1024*12)),
403	})
404
405	if err == nil {
406		t.Error("expect error, got nil")
407	}
408
409	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *invocations); len(diff) > 0 {
410		t.Error(err)
411	}
412}
413
// failreader is an io.Reader test double. Each call before the (times)th
// reports a full read of b without writing to it. From the (times)th call on,
// every call returns a "random failure" error.
type failreader struct {
	times     int
	failCount int
}

// Read counts the call, then either claims len(b) bytes or fails once the
// call count has reached times.
func (f *failreader) Read(b []byte) (int, error) {
	f.failCount++
	if f.failCount < f.times {
		return len(b), nil
	}
	return 0, fmt.Errorf("random failure")
}
426
427func TestUploadOrderReadFail1(t *testing.T) {
428	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
429	mgr := manager.NewUploader(c)
430	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
431		Bucket: aws.String("Bucket"),
432		Key:    aws.String("Key"),
433		Body:   &failreader{times: 1},
434	})
435	if err == nil {
436		t.Fatalf("expect error to not be nil")
437	}
438
439	if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
440		t.Errorf("expect %v, got %v", e, a)
441	}
442
443	if diff := cmp.Diff([]string(nil), *invocations); len(diff) > 0 {
444		t.Error(diff)
445	}
446}
447
448func TestUploadOrderReadFail2(t *testing.T) {
449	c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
450	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
451		u.Concurrency = 1
452	})
453	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
454		Bucket: aws.String("Bucket"),
455		Key:    aws.String("Key"),
456		Body:   &failreader{times: 2},
457	})
458	if err == nil {
459		t.Fatalf("expect error to not be nil")
460	}
461
462	if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
463		t.Errorf("expect %v, got %q", e, a)
464	}
465
466	if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
467		t.Error(diff)
468	}
469}
470
// sizedReader is an io.Reader test double that reports size bytes in total
// without writing to the caller's buffer. Once exhausted it returns err on
// every call, defaulting err to io.EOF when it was left nil.
type sizedReader struct {
	size int
	cur  int
	err  error
}

// Read claims min(len(p), size-cur) bytes and advances cur by the full len(p).
// After cur reaches size it returns (0, s.err).
func (s *sizedReader) Read(p []byte) (n int, err error) {
	remaining := s.size - s.cur
	if remaining <= 0 {
		if s.err == nil {
			s.err = io.EOF
		}
		return 0, s.err
	}

	s.cur += len(p)
	if len(p) > remaining {
		return remaining, nil
	}
	return len(p), nil
}
493
494func TestUploadOrderMultiBufferedReader(t *testing.T) {
495	c, invocations, params := s3testing.NewUploadLoggingClient(nil)
496	mgr := manager.NewUploader(c)
497	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
498		Bucket: aws.String("Bucket"),
499		Key:    aws.String("Key"),
500		Body:   &sizedReader{size: 1024 * 1024 * 12},
501	})
502	if err != nil {
503		t.Errorf("expect no error, got %v", err)
504	}
505
506	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
507		"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
508		t.Error(diff)
509	}
510
511	// Part lengths
512	var parts []int64
513	for i := 1; i <= 3; i++ {
514		parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
515	}
516	sort.Slice(parts, func(i, j int) bool {
517		return parts[i] < parts[j]
518	})
519
520	if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
521		t.Error(diff)
522	}
523}
524
525func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
526	c, invocations, params := s3testing.NewUploadLoggingClient(nil)
527	mgr := manager.NewUploader(c)
528	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
529		Bucket: aws.String("Bucket"),
530		Key:    aws.String("Key"),
531		Body:   &sizedReader{size: 1024 * 1024 * 12, err: io.EOF},
532	})
533	if err != nil {
534		t.Errorf("expect no error, got %v", err)
535	}
536
537	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
538		"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
539		t.Error(diff)
540	}
541
542	// Part lengths
543	var parts []int64
544	for i := 1; i <= 3; i++ {
545		parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
546	}
547	sort.Slice(parts, func(i, j int) bool {
548		return parts[i] < parts[j]
549	})
550
551	if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
552		t.Error(diff)
553	}
554}
555
556// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the
557// file size is the same as part size.
558func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) {
559	c, invocations, params := s3testing.NewUploadLoggingClient(nil)
560	mgr := manager.NewUploader(c)
561	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
562		Bucket: aws.String("Bucket"),
563		Key:    aws.String("Key"),
564		Body:   &sizedReader{size: 1024 * 1024 * 10, err: io.EOF},
565	})
566
567	if err != nil {
568		t.Errorf("expect no error, got %v", err)
569	}
570
571	if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
572		t.Error(diff)
573	}
574
575	// Part lengths
576	var parts []int64
577	for i := 1; i <= 2; i++ {
578		parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
579	}
580	sort.Slice(parts, func(i, j int) bool {
581		return parts[i] < parts[j]
582	})
583
584	if diff := cmp.Diff([]int64{1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
585		t.Error(diff)
586	}
587}
588
589func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
590	c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
591	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
592		u.Concurrency = 1
593		u.MaxUploadParts = 2
594	})
595	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
596		Bucket: aws.String("Bucket"),
597		Key:    aws.String("Key"),
598		Body:   &sizedReader{size: 1024 * 1024 * 12},
599	})
600	if err == nil {
601		t.Fatal("expect error, got nil")
602	}
603
604	if resp != nil {
605		t.Errorf("expect nil, got %v", resp)
606	}
607
608	if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
609		t.Error(diff)
610	}
611
612	if !strings.Contains(err.Error(), "configured MaxUploadParts (2)") {
613		t.Errorf("expect 'configured MaxUploadParts (2)', got %q", err.Error())
614	}
615}
616
617func TestUploadOrderSingleBufferedReader(t *testing.T) {
618	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
619	mgr := manager.NewUploader(c)
620	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
621		Bucket: aws.String("Bucket"),
622		Key:    aws.String("Key"),
623		Body:   &sizedReader{size: 1024 * 1024 * 2},
624	})
625
626	if err != nil {
627		t.Errorf("expect no error, got %v", err)
628	}
629
630	if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
631		t.Error(diff)
632	}
633
634	if len(resp.Location) == 0 {
635		t.Error("expect a value in Location")
636	}
637
638	if len(resp.UploadID) > 0 {
639		t.Errorf("expect no value, got %q", resp.UploadID)
640	}
641}
642
643func TestUploadZeroLenObject(t *testing.T) {
644	client, invocations, _ := s3testing.NewUploadLoggingClient(nil)
645
646	mgr := manager.NewUploader(client)
647	resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
648		Bucket: aws.String("Bucket"),
649		Key:    aws.String("Key"),
650		Body:   strings.NewReader(""),
651	})
652
653	if err != nil {
654		t.Errorf("expect no error but received %v", err)
655	}
656	if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
657		t.Errorf("expect request to have been made, but was not, %v", diff)
658	}
659
660	// TODO: not needed?
661	if len(resp.Location) == 0 {
662		t.Error("expect a non-empty string value for Location")
663	}
664
665	if len(resp.UploadID) > 0 {
666		t.Errorf("expect empty string, but received %q", resp.UploadID)
667	}
668}
669
// testIncompleteReader is an io.Reader test double that reports Size bytes in
// total (without writing to the caller's buffer) and then fails with
// io.ErrUnexpectedEOF. This simulates a stream that is cut off before a clean
// io.EOF.
type testIncompleteReader struct {
	Size int64
	read int64
}

// Read reports min(len(p), Size-read) bytes. When that consumes the rest of
// Size, the final (possibly zero) count is returned together with
// io.ErrUnexpectedEOF. The returned count never exceeds len(p) and never
// counts bytes past Size, as the io.Reader contract requires.
func (r *testIncompleteReader) Read(p []byte) (n int, err error) {
	remaining := r.Size - r.read
	if remaining < 0 {
		remaining = 0
	}
	if int64(len(p)) < remaining {
		r.read += int64(len(p))
		return len(p), nil
	}
	r.read += remaining
	return int(remaining), io.ErrUnexpectedEOF
}
682
683func TestUploadUnexpectedEOF(t *testing.T) {
684	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
685	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
686		u.Concurrency = 1
687		u.PartSize = manager.MinUploadPartSize
688	})
689	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
690		Bucket: aws.String("Bucket"),
691		Key:    aws.String("Key"),
692		Body: &testIncompleteReader{
693			Size: manager.MinUploadPartSize + 1,
694		},
695	})
696	if err == nil {
697		t.Error("expect error, got nil")
698	}
699
700	// Ensure upload started.
701	if e, a := "CreateMultipartUpload", (*invocations)[0]; e != a {
702		t.Errorf("expect %q, got %q", e, a)
703	}
704
705	// Part may or may not be sent because of timing of sending parts and
706	// reading next part in upload manager. Just check for the last abort.
707	if e, a := "AbortMultipartUpload", (*invocations)[len(*invocations)-1]; e != a {
708		t.Errorf("expect %q, got %q", e, a)
709	}
710}
711
712func TestSSE(t *testing.T) {
713	client, _, _ := s3testing.NewUploadLoggingClient(nil)
714	client.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
715		if params.SSECustomerAlgorithm == nil {
716			t.Fatal("SSECustomerAlgoritm should not be nil")
717		}
718		if params.SSECustomerKey == nil {
719			t.Fatal("SSECustomerKey should not be nil")
720		}
721		return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
722	}
723
724	mgr := manager.NewUploader(client, func(u *manager.Uploader) {
725		u.Concurrency = 5
726	})
727
728	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
729		Bucket:               aws.String("Bucket"),
730		Key:                  aws.String("Key"),
731		SSECustomerAlgorithm: aws.String("AES256"),
732		SSECustomerKey:       aws.String("foo"),
733		Body:                 bytes.NewBuffer(make([]byte, 1024*1024*10)),
734	})
735
736	if err != nil {
737		t.Fatal("Expected no error, but received" + err.Error())
738	}
739}
740
741func TestUploadWithContextCanceled(t *testing.T) {
742	u := manager.NewUploader(s3.New(s3.Options{
743		UsePathStyle: true,
744		Region:       "mock-region",
745	}))
746
747	params := s3.PutObjectInput{
748		Bucket: aws.String("Bucket"),
749		Key:    aws.String("Key"),
750		Body:   bytes.NewReader(make([]byte, 0)),
751	}
752
753	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
754	ctx.Error = fmt.Errorf("context canceled")
755	close(ctx.DoneCh)
756
757	_, err := u.Upload(ctx, &params)
758	if err == nil {
759		t.Fatalf("expect error, got nil")
760	}
761
762	if e, a := "canceled", err.Error(); !strings.Contains(a, e) {
763		t.Errorf("expected error message to contain %q, but did not %q", e, a)
764	}
765}
766
767// S3 Uploader incorrectly fails an upload if the content being uploaded
768// has a size of MinPartSize * MaxUploadParts.
769// Github:  aws/aws-sdk-go#2557
770func TestUploadMaxPartsEOF(t *testing.T) {
771	c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
772	mgr := manager.NewUploader(c, func(u *manager.Uploader) {
773		u.Concurrency = 1
774		u.PartSize = manager.DefaultUploadPartSize
775		u.MaxUploadParts = 2
776	})
777	f := bytes.NewReader(make([]byte, int(mgr.PartSize)*int(mgr.MaxUploadParts)))
778
779	r1 := io.NewSectionReader(f, 0, manager.DefaultUploadPartSize)
780	r2 := io.NewSectionReader(f, manager.DefaultUploadPartSize, 2*manager.DefaultUploadPartSize)
781	body := io.MultiReader(r1, r2)
782
783	_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
784		Bucket: aws.String("Bucket"),
785		Key:    aws.String("Key"),
786		Body:   body,
787	})
788
789	if err != nil {
790		t.Fatalf("expect no error, got %v", err)
791	}
792
793	expectOps := []string{
794		"CreateMultipartUpload",
795		"UploadPart",
796		"UploadPart",
797		"CompleteMultipartUpload",
798	}
799	if diff := cmp.Diff(expectOps, *invocations); len(diff) > 0 {
800		t.Error(diff)
801	}
802}
803
804func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) {
805	file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name())
806	if err != nil {
807		return nil, nil, err
808	}
809	filename := file.Name()
810	if err := file.Truncate(size); err != nil {
811		return nil, nil, err
812	}
813
814	return file,
815		func(t *testing.T) {
816			if err := file.Close(); err != nil {
817				t.Errorf("failed to close temp file, %s, %v", filename, err)
818			}
819			if err := os.Remove(filename); err != nil {
820				t.Errorf("failed to remove temp file, %s, %v", filename, err)
821			}
822		},
823		nil
824}
825
826func buildFailHandlers(tb testing.TB, parts, retry int) []http.Handler {
827	handlers := make([]http.Handler, parts)
828	for i := 0; i < len(handlers); i++ {
829		handlers[i] = &failPartHandler{
830			tb:             tb,
831			failsRemaining: retry,
832			successHandler: successPartHandler{tb: tb},
833		}
834	}
835
836	return handlers
837}
838
839func TestUploadRetry(t *testing.T) {
840	const numParts, retries = 3, 10
841
842	testFile, testFileCleanup, err := createTempFile(t, manager.DefaultUploadPartSize*numParts)
843	if err != nil {
844		t.Fatalf("failed to create test file, %v", err)
845	}
846	defer testFileCleanup(t)
847
848	cases := map[string]struct {
849		Body         io.Reader
850		PartHandlers func(testing.TB) []http.Handler
851	}{
852		"bytes.Buffer": {
853			Body: bytes.NewBuffer(make([]byte, manager.DefaultUploadPartSize*numParts)),
854			PartHandlers: func(tb testing.TB) []http.Handler {
855				return buildFailHandlers(tb, numParts, retries)
856			},
857		},
858		"bytes.Reader": {
859			Body: bytes.NewReader(make([]byte, manager.DefaultUploadPartSize*numParts)),
860			PartHandlers: func(tb testing.TB) []http.Handler {
861				return buildFailHandlers(tb, numParts, retries)
862			},
863		},
864		"os.File": {
865			Body: testFile,
866			PartHandlers: func(tb testing.TB) []http.Handler {
867				return buildFailHandlers(tb, numParts, retries)
868			},
869		},
870	}
871
872	for name, c := range cases {
873		t.Run(name, func(t *testing.T) {
874			restoreSleep := sdk.TestingUseNopSleep()
875			defer restoreSleep()
876
877			mux := newMockS3UploadServer(t, c.PartHandlers(t))
878			server := httptest.NewServer(mux)
879			defer server.Close()
880
881			client := s3.New(s3.Options{
882				EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) {
883					return aws.Endpoint{
884						URL: server.URL,
885					}, nil
886				}),
887				UsePathStyle: true,
888				Retryer: retry.NewStandard(func(o *retry.StandardOptions) {
889					o.MaxAttempts = retries + 1
890				}),
891			})
892
893			uploader := manager.NewUploader(client)
894			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
895				Bucket: aws.String("bucket"),
896				Key:    aws.String("key"),
897				Body:   c.Body,
898			})
899
900			if err != nil {
901				t.Fatalf("expect no error, got %v", err)
902			}
903		})
904	}
905}
906
907func TestUploadBufferStrategy(t *testing.T) {
908	cases := map[string]struct {
909		PartSize  int64
910		Size      int64
911		Strategy  manager.ReadSeekerWriteToProvider
912		callbacks int
913	}{
914		"NoBuffer": {
915			PartSize: manager.DefaultUploadPartSize,
916			Strategy: nil,
917		},
918		"SinglePart": {
919			PartSize:  manager.DefaultUploadPartSize,
920			Size:      manager.DefaultUploadPartSize,
921			Strategy:  &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)},
922			callbacks: 1,
923		},
924		"MultiPart": {
925			PartSize:  manager.DefaultUploadPartSize,
926			Size:      manager.DefaultUploadPartSize * 2,
927			Strategy:  &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)},
928			callbacks: 2,
929		},
930	}
931
932	for name, tCase := range cases {
933		t.Run(name, func(t *testing.T) {
934			client, _, _ := s3testing.NewUploadLoggingClient(nil)
935			client.ConsumeBody = true
936
937			uploader := manager.NewUploader(client, func(u *manager.Uploader) {
938				u.PartSize = tCase.PartSize
939				u.BufferProvider = tCase.Strategy
940				u.Concurrency = 1
941			})
942
943			expected := s3testing.GetTestBytes(int(tCase.Size))
944			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
945				Bucket: aws.String("bucket"),
946				Key:    aws.String("key"),
947				Body:   bytes.NewReader(expected),
948			})
949			if err != nil {
950				t.Fatalf("failed to upload file: %v", err)
951			}
952
953			switch strat := tCase.Strategy.(type) {
954			case *recordedBufferProvider:
955				if !bytes.Equal(expected, strat.content) {
956					t.Errorf("content buffered did not match expected")
957				}
958				if tCase.callbacks != strat.callbackCount {
959					t.Errorf("expected %v, got %v callbacks", tCase.callbacks, strat.callbackCount)
960				}
961			}
962		})
963	}
964}
965
966func TestUploaderValidARN(t *testing.T) {
967	cases := map[string]struct {
968		input   s3.PutObjectInput
969		wantErr bool
970	}{
971		"standard bucket": {
972			input: s3.PutObjectInput{
973				Bucket: aws.String("test-bucket"),
974				Key:    aws.String("test-key"),
975				Body:   bytes.NewReader([]byte("test body content")),
976			},
977		},
978		"accesspoint": {
979			input: s3.PutObjectInput{
980				Bucket: aws.String("arn:aws:s3:us-west-2:123456789012:accesspoint/myap"),
981				Key:    aws.String("test-key"),
982				Body:   bytes.NewReader([]byte("test body content")),
983			},
984		},
985		"outpost accesspoint": {
986			input: s3.PutObjectInput{
987				Bucket: aws.String("arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint"),
988				Key:    aws.String("test-key"),
989				Body:   bytes.NewReader([]byte("test body content")),
990			},
991		},
992		"s3-object-lambda accesspoint": {
993			input: s3.PutObjectInput{
994				Bucket: aws.String("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint/myap"),
995				Key:    aws.String("test-key"),
996				Body:   bytes.NewReader([]byte("test body content")),
997			},
998			wantErr: true,
999		},
1000	}
1001
1002	for name, tt := range cases {
1003		t.Run(name, func(t *testing.T) {
1004			client, _, _ := s3testing.NewUploadLoggingClient(nil)
1005			client.ConsumeBody = true
1006
1007			uploader := manager.NewUploader(client)
1008
1009			_, err := uploader.Upload(context.Background(), &tt.input)
1010			if (err != nil) != tt.wantErr {
1011				t.Errorf("err: %v, wantErr: %v", err, tt.wantErr)
1012			}
1013		})
1014	}
1015}
1016
// mockS3UploadServer is an http.Handler that fakes the S3 multipart upload
// API (used by TestUploadRetry in this file). All paths are routed to
// handleRequest by newMockS3UploadServer.
type mockS3UploadServer struct {
	*http.ServeMux

	// tb is the owning test; it is not referenced by the handlers visible here.
	tb testing.TB
	// partHandler serves UploadPart requests, indexed by partNumber-1.
	partHandler []http.Handler
}
1023
1024func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer {
1025	s := &mockS3UploadServer{
1026		ServeMux:    http.NewServeMux(),
1027		partHandler: partHandler,
1028		tb:          tb,
1029	}
1030
1031	s.HandleFunc("/", s.handleRequest)
1032
1033	return s
1034}
1035
1036func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) {
1037	defer r.Body.Close()
1038
1039	_, hasUploads := r.URL.Query()["uploads"]
1040
1041	switch {
1042	case r.Method == "POST" && hasUploads:
1043		// CreateMultipartUpload
1044		w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp)))
1045		w.Write([]byte(createUploadResp))
1046
1047	case r.Method == "PUT":
1048		// UploadPart
1049		partNumStr := r.URL.Query().Get("partNumber")
1050		id, err := strconv.Atoi(partNumStr)
1051		if err != nil {
1052			failRequest(w, 400, "BadRequest",
1053				fmt.Sprintf("unable to parse partNumber, %q, %v",
1054					partNumStr, err))
1055			return
1056		}
1057		id--
1058		if id < 0 || id >= len(s.partHandler) {
1059			failRequest(w, 400, "BadRequest",
1060				fmt.Sprintf("invalid partNumber %v", id))
1061			return
1062		}
1063		s.partHandler[id].ServeHTTP(w, r)
1064
1065	case r.Method == "POST":
1066		// CompleteMultipartUpload
1067		w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp)))
1068		w.Write([]byte(completeUploadResp))
1069
1070	case r.Method == "DELETE":
1071		// AbortMultipartUpload
1072		w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp)))
1073		w.WriteHeader(200)
1074		w.Write([]byte(abortUploadResp))
1075
1076	default:
1077		failRequest(w, 400, "BadRequest",
1078			fmt.Sprintf("invalid request %v %v", r.Method, r.URL))
1079	}
1080}
1081
1082func failRequest(w http.ResponseWriter, status int, code, msg string) {
1083	msg = fmt.Sprintf(baseRequestErrorResp, code, msg)
1084	w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
1085	w.WriteHeader(status)
1086	w.Write([]byte(msg))
1087}
1088
1089type successPartHandler struct {
1090	tb testing.TB
1091}
1092
1093func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
1094	defer r.Body.Close()
1095
1096	n, err := io.Copy(ioutil.Discard, r.Body)
1097	if err != nil {
1098		failRequest(w, 400, "BadRequest",
1099			fmt.Sprintf("failed to read body, %v", err))
1100		return
1101	}
1102
1103	contLenStr := r.Header.Get("Content-Length")
1104	expectLen, err := strconv.ParseInt(contLenStr, 10, 64)
1105	if err != nil {
1106		h.tb.Logf("expect content-length, got %q, %v", contLenStr, err)
1107		failRequest(w, 400, "BadRequest",
1108			fmt.Sprintf("unable to get content-length %v", err))
1109		return
1110	}
1111	if e, a := expectLen, n; e != a {
1112		h.tb.Logf("expect %v read, got %v", e, a)
1113		failRequest(w, 400, "BadRequest",
1114			fmt.Sprintf(
1115				"content-length and body do not match, %v, %v", e, a))
1116		return
1117	}
1118
1119	w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp)))
1120	w.Write([]byte(uploadPartResp))
1121}
1122
1123type failPartHandler struct {
1124	tb testing.TB
1125
1126	failsRemaining int
1127	successHandler http.Handler
1128}
1129
1130func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
1131	defer r.Body.Close()
1132
1133	if h.failsRemaining == 0 && h.successHandler != nil {
1134		h.successHandler.ServeHTTP(w, r)
1135		return
1136	}
1137
1138	io.Copy(ioutil.Discard, r.Body)
1139
1140	failRequest(w, 500, "InternalException",
1141		fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber")))
1142
1143	h.failsRemaining--
1144}
1145
1146type recordedBufferProvider struct {
1147	content       []byte
1148	size          int
1149	callbackCount int
1150}
1151
1152func (r *recordedBufferProvider) GetWriteTo(seeker io.ReadSeeker) (manager.ReadSeekerWriteTo, func()) {
1153	b := make([]byte, r.size)
1154	w := &manager.BufferedReadSeekerWriteTo{BufferedReadSeeker: manager.NewBufferedReadSeeker(seeker, b)}
1155
1156	return w, func() {
1157		r.content = append(r.content, b...)
1158		r.callbackCount++
1159	}
1160}
1161
// Canned XML response bodies served by the mock upload server.
const (
	// createUploadResp is the CreateMultipartUpload response; it reports
	// UploadId "abc123".
	createUploadResp = `<CreateMultipartUploadResponse>
  <Bucket>bucket</Bucket>
  <Key>key</Key>
  <UploadId>abc123</UploadId>
</CreateMultipartUploadResponse>`

	// uploadPartResp is the UploadPart response body.
	uploadPartResp = `<UploadPartResponse>
  <ETag>key</ETag>
</UploadPartResponse>`

	// baseRequestErrorResp is a format string for error responses. Its two
	// %s verbs take the error code and then the error message.
	baseRequestErrorResp = `<batchItemError>
  <Code>%s</Code>
  <Message>%s</Message>
  <RequestId>request-id</RequestId>
  <HostId>host-id</HostId>
</batchItemError>`

	// completeUploadResp is the CompleteMultipartUpload response body.
	completeUploadResp = `<CompleteMultipartUploadResponse>
  <Bucket>bucket</Bucket>
  <Key>key</Key>
  <ETag>key</ETag>
  <Location>https://bucket.us-west-2.amazonaws.com/key</Location>
  <UploadId>abc123</UploadId>
</CompleteMultipartUploadResponse>`

	// abortUploadResp is the (empty) AbortMultipartUpload response body.
	abortUploadResp = `<AbortMultipartUploadResponse></AbortMultipartUploadResponse>`
)
1187