1package manager_test 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/http" 10 "net/http/httptest" 11 "os" 12 "reflect" 13 "regexp" 14 "sort" 15 "strconv" 16 "strings" 17 "testing" 18 19 "github.com/aws/aws-sdk-go-v2/aws" 20 "github.com/aws/aws-sdk-go-v2/aws/retry" 21 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" 22 s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing" 23 "github.com/aws/aws-sdk-go-v2/internal/awstesting" 24 "github.com/aws/aws-sdk-go-v2/internal/sdk" 25 "github.com/aws/aws-sdk-go-v2/service/s3" 26 "github.com/aws/aws-sdk-go-v2/service/s3/types" 27 "github.com/google/go-cmp/cmp" 28) 29 30// getReaderLength discards the bytes from reader and returns the length 31func getReaderLength(r io.Reader) int64 { 32 n, _ := io.Copy(ioutil.Discard, r) 33 return n 34} 35 36func TestUploadOrderMulti(t *testing.T) { 37 c, invocations, args := s3testing.NewUploadLoggingClient(nil) 38 u := manager.NewUploader(c) 39 40 resp, err := u.Upload(context.Background(), &s3.PutObjectInput{ 41 Bucket: aws.String("Bucket"), 42 Key: aws.String("Key - value"), 43 Body: bytes.NewReader(buf12MB), 44 ServerSideEncryption: "aws:kms", 45 SSEKMSKeyId: aws.String("KmsId"), 46 ContentType: aws.String("content/type"), 47 }) 48 49 if err != nil { 50 t.Errorf("Expected no error but received %v", err) 51 } 52 53 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", 54 "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { 55 t.Error(err) 56 } 57 58 if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a { 59 t.Errorf("expect %q, got %q", e, a) 60 } 61 62 if "UPLOAD-ID" != resp.UploadID { 63 t.Errorf("expect %q, got %q", "UPLOAD-ID", resp.UploadID) 64 } 65 66 if "VERSION-ID" != *resp.VersionID { 67 t.Errorf("expect %q, got %q", "VERSION-ID", *resp.VersionID) 68 } 69 70 // Validate input values 71 72 // UploadPart 73 for i := 1; i < 4; i++ { 74 v := aws.ToString((*args)[i].(*s3.UploadPartInput).UploadId) 75 if "UPLOAD-ID" != v { 76 t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v) 77 } 78 } 79 80 // CompleteMultipartUpload 81 v := aws.ToString((*args)[4].(*s3.CompleteMultipartUploadInput).UploadId) 82 if "UPLOAD-ID" != v { 83 t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v) 84 } 85 86 parts := (*args)[4].(*s3.CompleteMultipartUploadInput).MultipartUpload.Parts 87 88 for i := 0; i < 3; i++ { 89 num := parts[i].PartNumber 90 etag := aws.ToString(parts[i].ETag) 91 92 if int32(i+1) != num { 93 t.Errorf("expect %d, got %d", i+1, num) 94 } 95 96 if matched, err := regexp.MatchString(`^ETAG\d+$`, etag); !matched || err != nil { 97 t.Errorf("Failed regexp expression `^ETAG\\d+$`") 98 } 99 } 100 101 // Custom headers 102 cmu := (*args)[0].(*s3.CreateMultipartUploadInput) 103 104 if e, a := types.ServerSideEncryption("aws:kms"), cmu.ServerSideEncryption; e != a { 105 t.Errorf("expect %q, got %q", e, a) 106 } 107 108 if e, a := "KmsId", aws.ToString(cmu.SSEKMSKeyId); e != a { 109 t.Errorf("expect %q, got %q", e, a) 110 } 111 112 if e, a := "content/type", aws.ToString(cmu.ContentType); e != a { 113 t.Errorf("expect %q, got %q", e, a) 114 } 115} 116 117func TestUploadOrderMultiDifferentPartSize(t *testing.T) { 118 s, ops, args := s3testing.NewUploadLoggingClient(nil) 119 mgr := manager.NewUploader(s, func(u *manager.Uploader) { 120 u.PartSize = 1024 * 1024 * 7 121 u.Concurrency = 1 122 }) 123 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 124 Bucket: aws.String("Bucket"), 125 Key: aws.String("Key"), 126 Body: bytes.NewReader(buf12MB), 127 }) 128 129 if err != nil { 130 t.Errorf("expect no error, got %v", err) 131 } 132 133 vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"} 134 if !reflect.DeepEqual(vals, *ops) { 135 t.Errorf("expect %v, got %v", vals, *ops) 136 } 137 138 // Part lengths 139 if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); 1024*1024*7 != len { 140 t.Errorf("expect %d, got %d", 1024*1024*7, len) 141 } 142 if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); 1024*1024*5 != len { 143 t.Errorf("expect %d, got %d", 1024*1024*5, len) 144 } 145} 146 147func TestUploadIncreasePartSize(t *testing.T) { 148 s, invocations, args := s3testing.NewUploadLoggingClient(nil) 149 mgr := manager.NewUploader(s, func(u *manager.Uploader) { 150 u.Concurrency = 1 151 u.MaxUploadParts = 2 152 }) 153 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 154 Bucket: aws.String("Bucket"), 155 Key: aws.String("Key"), 156 Body: bytes.NewReader(buf12MB), 157 }) 158 159 if err != nil { 160 t.Errorf("expect no error, got %v", err) 161 } 162 163 if int64(manager.DefaultDownloadPartSize) != mgr.PartSize { 164 t.Errorf("expect %d, got %d", manager.DefaultDownloadPartSize, mgr.PartSize) 165 } 166 167 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { 168 t.Error(diff) 169 } 170 171 // Part lengths 172 if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); (1024*1024*6)+1 != len { 173 t.Errorf("expect %d, got %d", (1024*1024*6)+1, len) 174 } 175 176 if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); (1024*1024*6)-1 != len { 177 t.Errorf("expect %d, got %d", (1024*1024*6)-1, len) 178 } 179} 180 181func TestUploadFailIfPartSizeTooSmall(t *testing.T) { 182 mgr := manager.NewUploader(s3.New(s3.Options{}), func(u *manager.Uploader) { 183 u.PartSize = 5 184 }) 185 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 186 Bucket: aws.String("Bucket"), 187 Key: aws.String("Key"), 188 Body: bytes.NewReader(buf12MB), 189 }) 190 191 if resp != nil { 192 t.Errorf("Expected response to be nil, but received %v", resp) 193 } 194 195 if err == nil { 196 t.Errorf("Expected error, but received nil") 197 } 198 199 if e, a := "part size must be at least", err.Error(); !strings.Contains(a, e) { 200 t.Errorf("expect %v to be in %v", e, a) 201 } 202} 203 204func TestUploadOrderSingle(t *testing.T) { 205 client, invocations, params := s3testing.NewUploadLoggingClient(nil) 206 mgr := manager.NewUploader(client) 207 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 208 Bucket: aws.String("Bucket"), 209 Key: aws.String("Key - value"), 210 Body: bytes.NewReader(buf2MB), 211 ServerSideEncryption: "aws:kms", 212 SSEKMSKeyId: aws.String("KmsId"), 213 ContentType: aws.String("content/type"), 214 }) 215 216 if err != nil { 217 t.Errorf("expect no error but received %v", err) 218 } 219 220 if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { 221 t.Error(diff) 222 } 223 224 if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a { 225 t.Errorf("expect %q, got %q", e, a) 226 } 227 228 if e := "VERSION-ID"; e != *resp.VersionID { 229 t.Errorf("expect %q, got %q", e, *resp.VersionID) 230 } 231 232 if len(resp.UploadID) > 0 { 233 t.Errorf("expect empty string, got %q", resp.UploadID) 234 } 235 236 putObjectInput := (*params)[0].(*s3.PutObjectInput) 237 238 if e, a := types.ServerSideEncryption("aws:kms"), putObjectInput.ServerSideEncryption; e != a { 239 t.Errorf("expect %q, got %q", e, a) 240 } 241 242 if e, a := "KmsId", aws.ToString(putObjectInput.SSEKMSKeyId); e != a { 243 t.Errorf("expect %q, got %q", e, a) 244 } 245 246 if e, a := "content/type", aws.ToString(putObjectInput.ContentType); e != a { 247 t.Errorf("Expected %q, but received %q", e, a) 248 } 249} 250 251func TestUploadOrderSingleFailure(t *testing.T) { 252 client, ops, _ := s3testing.NewUploadLoggingClient(nil) 253 254 client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) { 255 return nil, fmt.Errorf("put object failure") 256 } 257 258 mgr := manager.NewUploader(client) 259 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 260 Bucket: aws.String("Bucket"), 261 Key: aws.String("Key"), 262 Body: bytes.NewReader(buf2MB), 263 }) 264 265 if err == nil { 266 t.Error("expect error, got nil") 267 } 268 269 if diff := cmp.Diff([]string{"PutObject"}, *ops); len(diff) > 0 { 270 t.Error(diff) 271 } 272 273 if resp != nil { 274 t.Errorf("expect response to be nil, got %v", resp) 275 } 276} 277 278func TestUploadOrderZero(t *testing.T) { 279 c, invocations, params := s3testing.NewUploadLoggingClient(nil) 280 mgr := manager.NewUploader(c) 281 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 282 Bucket: aws.String("Bucket"), 283 Key: aws.String("Key"), 284 Body: bytes.NewReader(make([]byte, 0)), 285 }) 286 287 if err != nil { 288 t.Errorf("expect no error, got %v", err) 289 } 290 291 if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { 292 t.Error(diff) 293 } 294 295 if len(resp.Location) == 0 { 296 t.Error("expect Location to not be empty") 297 } 298 299 if len(resp.UploadID) > 0 { 300 t.Errorf("expect empty string, got %q", resp.UploadID) 301 } 302 303 if e, a := int64(0), getReaderLength((*params)[0].(*s3.PutObjectInput).Body); e != a { 304 t.Errorf("Expected %d, but received %d", e, a) 305 } 306} 307 308func TestUploadOrderMultiFailure(t *testing.T) { 309 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 310 311 c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) { 312 if u.PartNum == 2 { 313 return nil, fmt.Errorf("an unexpected error") 314 } 315 return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil 316 } 317 318 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 319 u.Concurrency = 1 320 }) 321 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 322 Bucket: aws.String("Bucket"), 323 Key: aws.String("Key"), 324 Body: bytes.NewReader(buf12MB), 325 }) 326 327 if err == nil { 328 t.Error("expect error, got nil") 329 } 330 331 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { 332 t.Error(diff) 333 } 334} 335 336func TestUploadOrderMultiFailureOnComplete(t *testing.T) { 337 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 338 339 c.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) { 340 return nil, fmt.Errorf("complete multipart error") 341 } 342 343 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 344 u.Concurrency = 1 345 }) 346 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 347 Bucket: aws.String("Bucket"), 348 Key: aws.String("Key"), 349 Body: bytes.NewReader(buf12MB), 350 }) 351 352 if err == nil { 353 t.Error("expect error, got nil") 354 } 355 356 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", 357 "CompleteMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { 358 t.Error(diff) 359 } 360} 361 362func TestUploadOrderMultiFailureOnCreate(t *testing.T) { 363 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 364 365 c.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) { 366 return nil, fmt.Errorf("create multipart upload failure") 367 } 368 369 mgr := manager.NewUploader(c) 370 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 371 Bucket: aws.String("Bucket"), 372 Key: aws.String("Key"), 373 Body: bytes.NewReader(make([]byte, 1024*1024*12)), 374 }) 375 376 if err == nil { 377 t.Error("expect error, got nil") 378 } 379 380 if diff := cmp.Diff([]string{"CreateMultipartUpload"}, *invocations); len(diff) > 0 { 381 t.Error(diff) 382 } 383} 384 385func TestUploadOrderMultiFailureLeaveParts(t *testing.T) { 386 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 387 388 c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) { 389 if u.PartNum == 2 { 390 return nil, fmt.Errorf("upload part failure") 391 } 392 return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil 393 } 394 395 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 396 u.Concurrency = 1 397 u.LeavePartsOnError = true 398 }) 399 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 400 Bucket: aws.String("Bucket"), 401 Key: aws.String("Key"), 402 Body: bytes.NewReader(make([]byte, 1024*1024*12)), 403 }) 404 405 if err == nil { 406 t.Error("expect error, got nil") 407 } 408 409 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *invocations); len(diff) > 0 { 410 t.Error(err) 411 } 412} 413 414type failreader struct { 415 times int 416 failCount int 417} 418 419func (f *failreader) Read(b []byte) (int, error) { 420 f.failCount++ 421 if f.failCount >= f.times { 422 return 0, fmt.Errorf("random failure") 423 } 424 return len(b), nil 425} 426 427func TestUploadOrderReadFail1(t *testing.T) { 428 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 429 mgr := manager.NewUploader(c) 430 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 431 Bucket: aws.String("Bucket"), 432 Key: aws.String("Key"), 433 Body: &failreader{times: 1}, 434 }) 435 if err == nil { 436 t.Fatalf("expect error to not be nil") 437 } 438 439 if e, a := "random failure", err.Error(); !strings.Contains(a, e) { 440 t.Errorf("expect %v, got %v", e, a) 441 } 442 443 if diff := cmp.Diff([]string(nil), *invocations); len(diff) > 0 { 444 t.Error(diff) 445 } 446} 447 448func TestUploadOrderReadFail2(t *testing.T) { 449 c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"}) 450 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 451 u.Concurrency = 1 452 }) 453 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 454 Bucket: aws.String("Bucket"), 455 Key: aws.String("Key"), 456 Body: &failreader{times: 2}, 457 }) 458 if err == nil { 459 t.Fatalf("expect error to not be nil") 460 } 461 462 if e, a := "random failure", err.Error(); !strings.Contains(a, e) { 463 t.Errorf("expect %v, got %q", e, a) 464 } 465 466 if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { 467 t.Error(diff) 468 } 469} 470 471type sizedReader struct { 472 size int 473 cur int 474 err error 475} 476 477func (s *sizedReader) Read(p []byte) (n int, err error) { 478 if s.cur >= s.size { 479 if s.err == nil { 480 s.err = io.EOF 481 } 482 return 0, s.err 483 } 484 485 n = len(p) 486 s.cur += len(p) 487 if s.cur > s.size { 488 n -= s.cur - s.size 489 } 490 491 return n, err 492} 493 494func TestUploadOrderMultiBufferedReader(t *testing.T) { 495 c, invocations, params := s3testing.NewUploadLoggingClient(nil) 496 mgr := manager.NewUploader(c) 497 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 498 Bucket: aws.String("Bucket"), 499 Key: aws.String("Key"), 500 Body: &sizedReader{size: 1024 * 1024 * 12}, 501 }) 502 if err != nil { 503 t.Errorf("expect no error, got %v", err) 504 } 505 506 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", 507 "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { 508 t.Error(diff) 509 } 510 511 // Part lengths 512 var parts []int64 513 for i := 1; i <= 3; i++ { 514 parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body)) 515 } 516 sort.Slice(parts, func(i, j int) bool { 517 return parts[i] < parts[j] 518 }) 519 520 if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 { 521 t.Error(diff) 522 } 523} 524 525func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) { 526 c, invocations, params := s3testing.NewUploadLoggingClient(nil) 527 mgr := manager.NewUploader(c) 528 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 529 Bucket: aws.String("Bucket"), 530 Key: aws.String("Key"), 531 Body: &sizedReader{size: 1024 * 1024 * 12, err: io.EOF}, 532 }) 533 if err != nil { 534 t.Errorf("expect no error, got %v", err) 535 } 536 537 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", 538 "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { 539 t.Error(diff) 540 } 541 542 // Part lengths 543 var parts []int64 544 for i := 1; i <= 3; i++ { 545 parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body)) 546 } 547 sort.Slice(parts, func(i, j int) bool { 548 return parts[i] < parts[j] 549 }) 550 551 if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 { 552 t.Error(diff) 553 } 554} 555 556// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the 557// file size is the same as part size. 558func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) { 559 c, invocations, params := s3testing.NewUploadLoggingClient(nil) 560 mgr := manager.NewUploader(c) 561 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 562 Bucket: aws.String("Bucket"), 563 Key: aws.String("Key"), 564 Body: &sizedReader{size: 1024 * 1024 * 10, err: io.EOF}, 565 }) 566 567 if err != nil { 568 t.Errorf("expect no error, got %v", err) 569 } 570 571 if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 { 572 t.Error(diff) 573 } 574 575 // Part lengths 576 var parts []int64 577 for i := 1; i <= 2; i++ { 578 parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body)) 579 } 580 sort.Slice(parts, func(i, j int) bool { 581 return parts[i] < parts[j] 582 }) 583 584 if diff := cmp.Diff([]int64{1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 { 585 t.Error(diff) 586 } 587} 588 589func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) { 590 c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"}) 591 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 592 u.Concurrency = 1 593 u.MaxUploadParts = 2 594 }) 595 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 596 Bucket: aws.String("Bucket"), 597 Key: aws.String("Key"), 598 Body: &sizedReader{size: 1024 * 1024 * 12}, 599 }) 600 if err == nil { 601 t.Fatal("expect error, got nil") 602 } 603 604 if resp != nil { 605 t.Errorf("expect nil, got %v", resp) 606 } 607 608 if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 { 609 t.Error(diff) 610 } 611 612 if !strings.Contains(err.Error(), "configured MaxUploadParts (2)") { 613 t.Errorf("expect 'configured MaxUploadParts (2)', got %q", err.Error()) 614 } 615} 616 617func TestUploadOrderSingleBufferedReader(t *testing.T) { 618 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 619 mgr := manager.NewUploader(c) 620 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 621 Bucket: aws.String("Bucket"), 622 Key: aws.String("Key"), 623 Body: &sizedReader{size: 1024 * 1024 * 2}, 624 }) 625 626 if err != nil { 627 t.Errorf("expect no error, got %v", err) 628 } 629 630 if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { 631 t.Error(diff) 632 } 633 634 if len(resp.Location) == 0 { 635 t.Error("expect a value in Location") 636 } 637 638 if len(resp.UploadID) > 0 { 639 t.Errorf("expect no value, got %q", resp.UploadID) 640 } 641} 642 643func TestUploadZeroLenObject(t *testing.T) { 644 client, invocations, _ := s3testing.NewUploadLoggingClient(nil) 645 646 mgr := manager.NewUploader(client) 647 resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 648 Bucket: aws.String("Bucket"), 649 Key: aws.String("Key"), 650 Body: strings.NewReader(""), 651 }) 652 653 if err != nil { 654 t.Errorf("expect no error but received %v", err) 655 } 656 if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 { 657 t.Errorf("expect request to have been made, but was not, %v", diff) 658 } 659 660 // TODO: not needed? 661 if len(resp.Location) == 0 { 662 t.Error("expect a non-empty string value for Location") 663 } 664 665 if len(resp.UploadID) > 0 { 666 t.Errorf("expect empty string, but received %q", resp.UploadID) 667 } 668} 669 670type testIncompleteReader struct { 671 Size int64 672 read int64 673} 674 675func (r *testIncompleteReader) Read(p []byte) (n int, err error) { 676 r.read += int64(len(p)) 677 if r.read >= r.Size { 678 return int(r.read - r.Size), io.ErrUnexpectedEOF 679 } 680 return len(p), nil 681} 682 683func TestUploadUnexpectedEOF(t *testing.T) { 684 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 685 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 686 u.Concurrency = 1 687 u.PartSize = manager.MinUploadPartSize 688 }) 689 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 690 Bucket: aws.String("Bucket"), 691 Key: aws.String("Key"), 692 Body: &testIncompleteReader{ 693 Size: manager.MinUploadPartSize + 1, 694 }, 695 }) 696 if err == nil { 697 t.Error("expect error, got nil") 698 } 699 700 // Ensure upload started. 701 if e, a := "CreateMultipartUpload", (*invocations)[0]; e != a { 702 t.Errorf("expect %q, got %q", e, a) 703 } 704 705 // Part may or may not be sent because of timing of sending parts and 706 // reading next part in upload manager. Just check for the last abort. 707 if e, a := "AbortMultipartUpload", (*invocations)[len(*invocations)-1]; e != a { 708 t.Errorf("expect %q, got %q", e, a) 709 } 710} 711 712func TestSSE(t *testing.T) { 713 client, _, _ := s3testing.NewUploadLoggingClient(nil) 714 client.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) { 715 if params.SSECustomerAlgorithm == nil { 716 t.Fatal("SSECustomerAlgoritm should not be nil") 717 } 718 if params.SSECustomerKey == nil { 719 t.Fatal("SSECustomerKey should not be nil") 720 } 721 return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil 722 } 723 724 mgr := manager.NewUploader(client, func(u *manager.Uploader) { 725 u.Concurrency = 5 726 }) 727 728 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 729 Bucket: aws.String("Bucket"), 730 Key: aws.String("Key"), 731 SSECustomerAlgorithm: aws.String("AES256"), 732 SSECustomerKey: aws.String("foo"), 733 Body: bytes.NewBuffer(make([]byte, 1024*1024*10)), 734 }) 735 736 if err != nil { 737 t.Fatal("Expected no error, but received" + err.Error()) 738 } 739} 740 741func TestUploadWithContextCanceled(t *testing.T) { 742 u := manager.NewUploader(s3.New(s3.Options{ 743 UsePathStyle: true, 744 Region: "mock-region", 745 })) 746 747 params := s3.PutObjectInput{ 748 Bucket: aws.String("Bucket"), 749 Key: aws.String("Key"), 750 Body: bytes.NewReader(make([]byte, 0)), 751 } 752 753 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 754 ctx.Error = fmt.Errorf("context canceled") 755 close(ctx.DoneCh) 756 757 _, err := u.Upload(ctx, ¶ms) 758 if err == nil { 759 t.Fatalf("expect error, got nil") 760 } 761 762 if e, a := "canceled", err.Error(); !strings.Contains(a, e) { 763 t.Errorf("expected error message to contain %q, but did not %q", e, a) 764 } 765} 766 767// S3 Uploader incorrectly fails an upload if the content being uploaded 768// has a size of MinPartSize * MaxUploadParts. 769// Github: aws/aws-sdk-go#2557 770func TestUploadMaxPartsEOF(t *testing.T) { 771 c, invocations, _ := s3testing.NewUploadLoggingClient(nil) 772 mgr := manager.NewUploader(c, func(u *manager.Uploader) { 773 u.Concurrency = 1 774 u.PartSize = manager.DefaultUploadPartSize 775 u.MaxUploadParts = 2 776 }) 777 f := bytes.NewReader(make([]byte, int(mgr.PartSize)*int(mgr.MaxUploadParts))) 778 779 r1 := io.NewSectionReader(f, 0, manager.DefaultUploadPartSize) 780 r2 := io.NewSectionReader(f, manager.DefaultUploadPartSize, 2*manager.DefaultUploadPartSize) 781 body := io.MultiReader(r1, r2) 782 783 _, err := mgr.Upload(context.Background(), &s3.PutObjectInput{ 784 Bucket: aws.String("Bucket"), 785 Key: aws.String("Key"), 786 Body: body, 787 }) 788 789 if err != nil { 790 t.Fatalf("expect no error, got %v", err) 791 } 792 793 expectOps := []string{ 794 "CreateMultipartUpload", 795 "UploadPart", 796 "UploadPart", 797 "CompleteMultipartUpload", 798 } 799 if diff := cmp.Diff(expectOps, *invocations); len(diff) > 0 { 800 t.Error(diff) 801 } 802} 803 804func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) { 805 file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name()) 806 if err != nil { 807 return nil, nil, err 808 } 809 filename := file.Name() 810 if err := file.Truncate(size); err != nil { 811 return nil, nil, err 812 } 813 814 return file, 815 func(t *testing.T) { 816 if err := file.Close(); err != nil { 817 t.Errorf("failed to close temp file, %s, %v", filename, err) 818 } 819 if err := os.Remove(filename); err != nil { 820 t.Errorf("failed to remove temp file, %s, %v", filename, err) 821 } 822 }, 823 nil 824} 825 826func buildFailHandlers(tb testing.TB, parts, retry int) []http.Handler { 827 handlers := make([]http.Handler, parts) 828 for i := 0; i < len(handlers); i++ { 829 handlers[i] = &failPartHandler{ 830 tb: tb, 831 failsRemaining: retry, 832 successHandler: successPartHandler{tb: tb}, 833 } 834 } 835 836 return handlers 837} 838 839func TestUploadRetry(t *testing.T) { 840 const numParts, retries = 3, 10 841 842 testFile, testFileCleanup, err := createTempFile(t, manager.DefaultUploadPartSize*numParts) 843 if err != nil { 844 t.Fatalf("failed to create test file, %v", err) 845 } 846 defer testFileCleanup(t) 847 848 cases := map[string]struct { 849 Body io.Reader 850 PartHandlers func(testing.TB) []http.Handler 851 }{ 852 "bytes.Buffer": { 853 Body: bytes.NewBuffer(make([]byte, manager.DefaultUploadPartSize*numParts)), 854 PartHandlers: func(tb testing.TB) []http.Handler { 855 return buildFailHandlers(tb, numParts, retries) 856 }, 857 }, 858 "bytes.Reader": { 859 Body: bytes.NewReader(make([]byte, manager.DefaultUploadPartSize*numParts)), 860 PartHandlers: func(tb testing.TB) []http.Handler { 861 return buildFailHandlers(tb, numParts, retries) 862 }, 863 }, 864 "os.File": { 865 Body: testFile, 866 PartHandlers: func(tb testing.TB) []http.Handler { 867 return buildFailHandlers(tb, numParts, retries) 868 }, 869 }, 870 } 871 872 for name, c := range cases { 873 t.Run(name, func(t *testing.T) { 874 restoreSleep := sdk.TestingUseNopSleep() 875 defer restoreSleep() 876 877 mux := newMockS3UploadServer(t, c.PartHandlers(t)) 878 server := httptest.NewServer(mux) 879 defer server.Close() 880 881 client := s3.New(s3.Options{ 882 EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) { 883 return aws.Endpoint{ 884 URL: server.URL, 885 }, nil 886 }), 887 UsePathStyle: true, 888 Retryer: retry.NewStandard(func(o *retry.StandardOptions) { 889 o.MaxAttempts = retries + 1 890 }), 891 }) 892 893 uploader := manager.NewUploader(client) 894 _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{ 895 Bucket: aws.String("bucket"), 896 Key: aws.String("key"), 897 Body: c.Body, 898 }) 899 900 if err != nil { 901 t.Fatalf("expect no error, got %v", err) 902 } 903 }) 904 } 905} 906 907func TestUploadBufferStrategy(t *testing.T) { 908 cases := map[string]struct { 909 PartSize int64 910 Size int64 911 Strategy manager.ReadSeekerWriteToProvider 912 callbacks int 913 }{ 914 "NoBuffer": { 915 PartSize: manager.DefaultUploadPartSize, 916 Strategy: nil, 917 }, 918 "SinglePart": { 919 PartSize: manager.DefaultUploadPartSize, 920 Size: manager.DefaultUploadPartSize, 921 Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)}, 922 callbacks: 1, 923 }, 924 "MultiPart": { 925 PartSize: manager.DefaultUploadPartSize, 926 Size: manager.DefaultUploadPartSize * 2, 927 Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)}, 928 callbacks: 2, 929 }, 930 } 931 932 for name, tCase := range cases { 933 t.Run(name, func(t *testing.T) { 934 client, _, _ := s3testing.NewUploadLoggingClient(nil) 935 client.ConsumeBody = true 936 937 uploader := manager.NewUploader(client, func(u *manager.Uploader) { 938 u.PartSize = tCase.PartSize 939 u.BufferProvider = tCase.Strategy 940 u.Concurrency = 1 941 }) 942 943 expected := s3testing.GetTestBytes(int(tCase.Size)) 944 _, err := uploader.Upload(context.Background(), &s3.PutObjectInput{ 945 Bucket: aws.String("bucket"), 946 Key: aws.String("key"), 947 Body: bytes.NewReader(expected), 948 }) 949 if err != nil { 950 t.Fatalf("failed to upload file: %v", err) 951 } 952 953 switch strat := tCase.Strategy.(type) { 954 case *recordedBufferProvider: 955 if !bytes.Equal(expected, strat.content) { 956 t.Errorf("content buffered did not match expected") 957 } 958 if tCase.callbacks != strat.callbackCount { 959 t.Errorf("expected %v, got %v callbacks", tCase.callbacks, strat.callbackCount) 960 } 961 } 962 }) 963 } 964} 965 966func TestUploaderValidARN(t *testing.T) { 967 cases := map[string]struct { 968 input s3.PutObjectInput 969 wantErr bool 970 }{ 971 "standard bucket": { 972 input: s3.PutObjectInput{ 973 Bucket: aws.String("test-bucket"), 974 Key: aws.String("test-key"), 975 Body: bytes.NewReader([]byte("test body content")), 976 }, 977 }, 978 "accesspoint": { 979 input: s3.PutObjectInput{ 980 Bucket: aws.String("arn:aws:s3:us-west-2:123456789012:accesspoint/myap"), 981 Key: aws.String("test-key"), 982 Body: bytes.NewReader([]byte("test body content")), 983 }, 984 }, 985 "outpost accesspoint": { 986 input: s3.PutObjectInput{ 987 Bucket: aws.String("arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint"), 988 Key: aws.String("test-key"), 989 Body: bytes.NewReader([]byte("test body content")), 990 }, 991 }, 992 "s3-object-lambda accesspoint": { 993 input: s3.PutObjectInput{ 994 Bucket: aws.String("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint/myap"), 995 Key: aws.String("test-key"), 996 Body: bytes.NewReader([]byte("test body content")), 997 }, 998 wantErr: true, 999 }, 1000 } 1001 1002 for name, tt := range cases { 1003 t.Run(name, func(t *testing.T) { 1004 client, _, _ := s3testing.NewUploadLoggingClient(nil) 1005 client.ConsumeBody = true 1006 1007 uploader := manager.NewUploader(client) 1008 1009 _, err := uploader.Upload(context.Background(), &tt.input) 1010 if (err != nil) != tt.wantErr { 1011 t.Errorf("err: %v, wantErr: %v", err, tt.wantErr) 1012 } 1013 }) 1014 } 1015} 1016 1017type mockS3UploadServer struct { 1018 *http.ServeMux 1019 1020 tb testing.TB 1021 partHandler []http.Handler 1022} 1023 1024func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer { 1025 s := &mockS3UploadServer{ 1026 ServeMux: http.NewServeMux(), 1027 partHandler: partHandler, 1028 tb: tb, 1029 } 1030 1031 s.HandleFunc("/", s.handleRequest) 1032 1033 return s 1034} 1035 1036func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) { 1037 defer r.Body.Close() 1038 1039 _, hasUploads := r.URL.Query()["uploads"] 1040 1041 switch { 1042 case r.Method == "POST" && hasUploads: 1043 // CreateMultipartUpload 1044 w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp))) 1045 w.Write([]byte(createUploadResp)) 1046 1047 case r.Method == "PUT": 1048 // UploadPart 1049 partNumStr := r.URL.Query().Get("partNumber") 1050 id, err := strconv.Atoi(partNumStr) 1051 if err != nil { 1052 failRequest(w, 400, "BadRequest", 1053 fmt.Sprintf("unable to parse partNumber, %q, %v", 1054 partNumStr, err)) 1055 return 1056 } 1057 id-- 1058 if id < 0 || id >= len(s.partHandler) { 1059 failRequest(w, 400, "BadRequest", 1060 fmt.Sprintf("invalid partNumber %v", id)) 1061 return 1062 } 1063 s.partHandler[id].ServeHTTP(w, r) 1064 1065 case r.Method == "POST": 1066 // CompleteMultipartUpload 1067 w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp))) 1068 w.Write([]byte(completeUploadResp)) 1069 1070 case r.Method == "DELETE": 1071 // AbortMultipartUpload 1072 w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp))) 1073 w.WriteHeader(200) 1074 w.Write([]byte(abortUploadResp)) 1075 1076 default: 1077 failRequest(w, 400, "BadRequest", 1078 fmt.Sprintf("invalid request %v %v", r.Method, r.URL)) 1079 } 1080} 1081 1082func failRequest(w http.ResponseWriter, status int, code, msg string) { 1083 msg = fmt.Sprintf(baseRequestErrorResp, code, msg) 1084 w.Header().Set("Content-Length", strconv.Itoa(len(msg))) 1085 w.WriteHeader(status) 1086 w.Write([]byte(msg)) 1087} 1088 1089type successPartHandler struct { 1090 tb testing.TB 1091} 1092 1093func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 1094 defer r.Body.Close() 1095 1096 n, err := io.Copy(ioutil.Discard, r.Body) 1097 if err != nil { 1098 failRequest(w, 400, "BadRequest", 1099 fmt.Sprintf("failed to read body, %v", err)) 1100 return 1101 } 1102 1103 contLenStr := r.Header.Get("Content-Length") 1104 expectLen, err := strconv.ParseInt(contLenStr, 10, 64) 1105 if err != nil { 1106 h.tb.Logf("expect content-length, got %q, %v", contLenStr, err) 1107 failRequest(w, 400, "BadRequest", 1108 fmt.Sprintf("unable to get content-length %v", err)) 1109 return 1110 } 1111 if e, a := expectLen, n; e != a { 1112 h.tb.Logf("expect %v read, got %v", e, a) 1113 failRequest(w, 400, "BadRequest", 1114 fmt.Sprintf( 1115 "content-length and body do not match, %v, %v", e, a)) 1116 return 1117 } 1118 1119 w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp))) 1120 w.Write([]byte(uploadPartResp)) 1121} 1122 1123type failPartHandler struct { 1124 tb testing.TB 1125 1126 failsRemaining int 1127 successHandler http.Handler 1128} 1129 1130func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 1131 defer r.Body.Close() 1132 1133 if h.failsRemaining == 0 && h.successHandler != nil { 1134 h.successHandler.ServeHTTP(w, r) 1135 return 1136 } 1137 1138 io.Copy(ioutil.Discard, r.Body) 1139 1140 failRequest(w, 500, "InternalException", 1141 fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber"))) 1142 1143 h.failsRemaining-- 1144} 1145 1146type recordedBufferProvider struct { 1147 content []byte 1148 size int 1149 callbackCount int 1150} 1151 1152func (r *recordedBufferProvider) GetWriteTo(seeker io.ReadSeeker) (manager.ReadSeekerWriteTo, func()) { 1153 b := make([]byte, r.size) 1154 w := &manager.BufferedReadSeekerWriteTo{BufferedReadSeeker: manager.NewBufferedReadSeeker(seeker, b)} 1155 1156 return w, func() { 1157 r.content = append(r.content, b...) 1158 r.callbackCount++ 1159 } 1160} 1161 1162const createUploadResp = `<CreateMultipartUploadResponse> 1163 <Bucket>bucket</Bucket> 1164 <Key>key</Key> 1165 <UploadId>abc123</UploadId> 1166</CreateMultipartUploadResponse>` 1167 1168const uploadPartResp = `<UploadPartResponse> 1169 <ETag>key</ETag> 1170</UploadPartResponse>` 1171const baseRequestErrorResp = `<batchItemError> 1172 <Code>%s</Code> 1173 <Message>%s</Message> 1174 <RequestId>request-id</RequestId> 1175 <HostId>host-id</HostId> 1176</batchItemError>` 1177 1178const completeUploadResp = `<CompleteMultipartUploadResponse> 1179 <Bucket>bucket</Bucket> 1180 <Key>key</Key> 1181 <ETag>key</ETag> 1182 <Location>https://bucket.us-west-2.amazonaws.com/key</Location> 1183 <UploadId>abc123</UploadId> 1184</CompleteMultipartUploadResponse>` 1185 1186const abortUploadResp = `<AbortMultipartUploadResponse></AbortMultipartUploadResponse>` 1187