1//go:build go1.7 2// +build go1.7 3 4package s3manager 5 6import ( 7 "bytes" 8 "fmt" 9 "io" 10 "io/ioutil" 11 random "math/rand" 12 "net/http" 13 "strconv" 14 "sync" 15 "sync/atomic" 16 "testing" 17 18 "github.com/aws/aws-sdk-go/aws" 19 "github.com/aws/aws-sdk-go/aws/request" 20 "github.com/aws/aws-sdk-go/awstesting/unit" 21 "github.com/aws/aws-sdk-go/internal/sdkio" 22 "github.com/aws/aws-sdk-go/service/s3" 23 "github.com/aws/aws-sdk-go/service/s3/internal/s3testing" 24) 25 26const respBody = `<?xml version="1.0" encoding="UTF-8"?> 27<CompleteMultipartUploadOutput> 28 <Location>mockValue</Location> 29 <Bucket>mockValue</Bucket> 30 <Key>mockValue</Key> 31 <ETag>mockValue</ETag> 32</CompleteMultipartUploadOutput>` 33 34type testReader struct { 35 br *bytes.Reader 36 m sync.Mutex 37} 38 39func (r *testReader) Read(p []byte) (n int, err error) { 40 r.m.Lock() 41 defer r.m.Unlock() 42 return r.br.Read(p) 43} 44 45func TestUploadByteSlicePool(t *testing.T) { 46 cases := map[string]struct { 47 PartSize int64 48 FileSize int64 49 Concurrency int 50 ExAllocations uint64 51 }{ 52 "single part, single concurrency": { 53 PartSize: sdkio.MebiByte * 5, 54 FileSize: sdkio.MebiByte * 5, 55 ExAllocations: 2, 56 Concurrency: 1, 57 }, 58 "multi-part, single concurrency": { 59 PartSize: sdkio.MebiByte * 5, 60 FileSize: sdkio.MebiByte * 10, 61 ExAllocations: 2, 62 Concurrency: 1, 63 }, 64 "multi-part, multiple concurrency": { 65 PartSize: sdkio.MebiByte * 5, 66 FileSize: sdkio.MebiByte * 20, 67 ExAllocations: 3, 68 Concurrency: 2, 69 }, 70 } 71 72 for name, tt := range cases { 73 t.Run(name, func(t *testing.T) { 74 var p *recordedPartPool 75 76 unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { 77 p = newRecordedPartPool(sliceSize) 78 return p 79 }) 80 defer unswap() 81 82 sess := unit.Session.Copy() 83 svc := s3.New(sess) 84 svc.Handlers.Unmarshal.Clear() 85 svc.Handlers.UnmarshalMeta.Clear() 86 svc.Handlers.UnmarshalError.Clear() 87 svc.Handlers.Send.Clear() 88 svc.Handlers.Send.PushFront(func(r *request.Request) { 89 if r.Body != nil { 90 io.Copy(ioutil.Discard, r.Body) 91 } 92 93 r.HTTPResponse = &http.Response{ 94 StatusCode: 200, 95 Body: ioutil.NopCloser(bytes.NewReader([]byte(respBody))), 96 } 97 98 switch data := r.Data.(type) { 99 case *s3.CreateMultipartUploadOutput: 100 data.UploadId = aws.String("UPLOAD-ID") 101 case *s3.UploadPartOutput: 102 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 103 case *s3.CompleteMultipartUploadOutput: 104 data.Location = aws.String("https://location") 105 data.VersionId = aws.String("VERSION-ID") 106 case *s3.PutObjectOutput: 107 data.VersionId = aws.String("VERSION-ID") 108 } 109 }) 110 111 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 112 u.PartSize = tt.PartSize 113 u.Concurrency = tt.Concurrency 114 }) 115 116 expected := s3testing.GetTestBytes(int(tt.FileSize)) 117 _, err := uploader.Upload(&UploadInput{ 118 Bucket: aws.String("bucket"), 119 Key: aws.String("key"), 120 Body: &testReader{br: bytes.NewReader(expected)}, 121 }) 122 if err != nil { 123 t.Errorf("expected no error, but got %v", err) 124 } 125 126 if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 { 127 t.Fatalf("expected zero outsnatding pool parts, got %d", v) 128 } 129 130 gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs) 131 132 t.Logf("total gets %v, total allocations %v", gets, allocs) 133 if e, a := tt.ExAllocations, allocs; a > e { 134 t.Errorf("expected %v allocations, got %v", e, a) 135 } 136 }) 137 } 138} 139 140func TestUploadByteSlicePool_Failures(t *testing.T) { 141 cases := map[string]struct { 142 PartSize int64 143 FileSize int64 144 Operations []string 145 }{ 146 "single part": { 147 PartSize: sdkio.MebiByte * 5, 148 FileSize: sdkio.MebiByte * 4, 149 Operations: []string{ 150 "PutObject", 151 }, 152 }, 153 "multi-part": { 154 PartSize: sdkio.MebiByte * 5, 155 FileSize: sdkio.MebiByte * 10, 156 Operations: []string{ 157 "CreateMultipartUpload", 158 "UploadPart", 159 "CompleteMultipartUpload", 160 }, 161 }, 162 } 163 164 for name, tt := range cases { 165 t.Run(name, func(t *testing.T) { 166 for _, operation := range tt.Operations { 167 t.Run(operation, func(t *testing.T) { 168 var p *recordedPartPool 169 170 unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { 171 p = newRecordedPartPool(sliceSize) 172 return p 173 }) 174 defer unswap() 175 176 sess := unit.Session.Copy() 177 svc := s3.New(sess) 178 svc.Handlers.Unmarshal.Clear() 179 svc.Handlers.UnmarshalMeta.Clear() 180 svc.Handlers.UnmarshalError.Clear() 181 svc.Handlers.Send.Clear() 182 svc.Handlers.Send.PushFront(func(r *request.Request) { 183 if r.Body != nil { 184 io.Copy(ioutil.Discard, r.Body) 185 } 186 187 if r.Operation.Name == operation { 188 r.Retryable = aws.Bool(false) 189 r.Error = fmt.Errorf("request error") 190 r.HTTPResponse = &http.Response{ 191 StatusCode: 500, 192 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 193 } 194 return 195 } 196 197 r.HTTPResponse = &http.Response{ 198 StatusCode: 200, 199 Body: ioutil.NopCloser(bytes.NewReader([]byte(respBody))), 200 } 201 202 switch data := r.Data.(type) { 203 case *s3.CreateMultipartUploadOutput: 204 data.UploadId = aws.String("UPLOAD-ID") 205 case *s3.UploadPartOutput: 206 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 207 case *s3.CompleteMultipartUploadOutput: 208 data.Location = aws.String("https://location") 209 data.VersionId = aws.String("VERSION-ID") 210 case *s3.PutObjectOutput: 211 data.VersionId = aws.String("VERSION-ID") 212 } 213 }) 214 215 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 216 u.Concurrency = 1 217 u.PartSize = tt.PartSize 218 }) 219 220 expected := s3testing.GetTestBytes(int(tt.FileSize)) 221 _, err := uploader.Upload(&UploadInput{ 222 Bucket: aws.String("bucket"), 223 Key: aws.String("key"), 224 Body: &testReader{br: bytes.NewReader(expected)}, 225 }) 226 if err == nil { 227 t.Fatalf("expected error but got none") 228 } 229 230 if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 { 231 t.Fatalf("expected zero outsnatding pool parts, got %d", v) 232 } 233 }) 234 } 235 }) 236 } 237} 238 239func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) { 240 var ( 241 pools []*recordedPartPool 242 mtx sync.Mutex 243 ) 244 245 unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { 246 mtx.Lock() 247 defer mtx.Unlock() 248 b := newRecordedPartPool(sliceSize) 249 pools = append(pools, b) 250 return b 251 }) 252 defer unswap() 253 254 sess := unit.Session.Copy() 255 svc := s3.New(sess) 256 svc.Handlers.Unmarshal.Clear() 257 svc.Handlers.UnmarshalMeta.Clear() 258 svc.Handlers.UnmarshalError.Clear() 259 svc.Handlers.Send.Clear() 260 svc.Handlers.Send.PushFront(func(r *request.Request) { 261 if r.Body != nil { 262 io.Copy(ioutil.Discard, r.Body) 263 } 264 265 r.HTTPResponse = &http.Response{ 266 StatusCode: 200, 267 Body: ioutil.NopCloser(bytes.NewReader([]byte(respBody))), 268 } 269 270 switch data := r.Data.(type) { 271 case *s3.CreateMultipartUploadOutput: 272 data.UploadId = aws.String("UPLOAD-ID") 273 case *s3.UploadPartOutput: 274 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 275 case *s3.CompleteMultipartUploadOutput: 276 data.Location = aws.String("https://location") 277 data.VersionId = aws.String("VERSION-ID") 278 case *s3.PutObjectOutput: 279 data.VersionId = aws.String("VERSION-ID") 280 } 281 }) 282 283 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 284 u.PartSize = 5 * sdkio.MebiByte 285 u.Concurrency = 2 286 }) 287 288 var wg sync.WaitGroup 289 for i := 0; i < 2; i++ { 290 wg.Add(2) 291 go func() { 292 defer wg.Done() 293 expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte)) 294 _, err := uploader.Upload(&UploadInput{ 295 Bucket: aws.String("bucket"), 296 Key: aws.String("key"), 297 Body: &testReader{br: bytes.NewReader(expected)}, 298 }) 299 if err != nil { 300 t.Errorf("expected no error, but got %v", err) 301 } 302 }() 303 go func() { 304 defer wg.Done() 305 expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte)) 306 _, err := uploader.Upload(&UploadInput{ 307 Bucket: aws.String("bucket"), 308 Key: aws.String("key"), 309 Body: &testReader{br: bytes.NewReader(expected)}, 310 }, func(u *Uploader) { 311 u.PartSize = 6 * sdkio.MebiByte 312 }) 313 if err != nil { 314 t.Errorf("expected no error, but got %v", err) 315 } 316 }() 317 } 318 319 wg.Wait() 320 321 if e, a := 3, len(pools); e != a { 322 t.Errorf("expected %v, got %v", e, a) 323 } 324 325 for _, p := range pools { 326 if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 { 327 t.Fatalf("expected zero outsnatding pool parts, got %d", v) 328 } 329 330 t.Logf("total gets %v, total allocations %v", 331 atomic.LoadUint64(&p.recordedGets), 332 atomic.LoadUint64(&p.recordedAllocs)) 333 } 334} 335 336func BenchmarkPools(b *testing.B) { 337 cases := []struct { 338 PartSize int64 339 FileSize int64 340 Concurrency int 341 ExAllocations uint64 342 }{ 343 0: { 344 PartSize: sdkio.MebiByte * 5, 345 FileSize: sdkio.MebiByte * 5, 346 Concurrency: 1, 347 }, 348 1: { 349 PartSize: sdkio.MebiByte * 5, 350 FileSize: sdkio.MebiByte * 10, 351 Concurrency: 1, 352 }, 353 2: { 354 PartSize: sdkio.MebiByte * 5, 355 FileSize: sdkio.MebiByte * 20, 356 Concurrency: 2, 357 }, 358 3: { 359 PartSize: sdkio.MebiByte * 5, 360 FileSize: sdkio.MebiByte * 250, 361 Concurrency: 10, 362 }, 363 } 364 365 sess := unit.Session.Copy() 366 svc := s3.New(sess) 367 svc.Handlers.Unmarshal.Clear() 368 svc.Handlers.UnmarshalMeta.Clear() 369 svc.Handlers.UnmarshalError.Clear() 370 svc.Handlers.Send.Clear() 371 svc.Handlers.Send.PushFront(func(r *request.Request) { 372 if r.Body != nil { 373 io.Copy(ioutil.Discard, r.Body) 374 } 375 376 r.HTTPResponse = &http.Response{ 377 StatusCode: 200, 378 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 379 } 380 381 switch data := r.Data.(type) { 382 case *s3.CreateMultipartUploadOutput: 383 data.UploadId = aws.String("UPLOAD-ID") 384 case *s3.UploadPartOutput: 385 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 386 case *s3.CompleteMultipartUploadOutput: 387 data.Location = aws.String("https://location") 388 data.VersionId = aws.String("VERSION-ID") 389 case *s3.PutObjectOutput: 390 data.VersionId = aws.String("VERSION-ID") 391 } 392 }) 393 394 pools := map[string]func(sliceSize int64) byteSlicePool{ 395 "sync.Pool": func(sliceSize int64) byteSlicePool { 396 return newSyncSlicePool(sliceSize) 397 }, 398 "custom": func(sliceSize int64) byteSlicePool { 399 return newMaxSlicePool(sliceSize) 400 }, 401 } 402 403 for name, poolFunc := range pools { 404 b.Run(name, func(b *testing.B) { 405 unswap := swapByteSlicePool(poolFunc) 406 defer unswap() 407 for i, c := range cases { 408 b.Run(strconv.Itoa(i), func(b *testing.B) { 409 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 410 u.PartSize = c.PartSize 411 u.Concurrency = c.Concurrency 412 }) 413 414 expected := s3testing.GetTestBytes(int(c.FileSize)) 415 b.ResetTimer() 416 _, err := uploader.Upload(&UploadInput{ 417 Bucket: aws.String("bucket"), 418 Key: aws.String("key"), 419 Body: &testReader{br: bytes.NewReader(expected)}, 420 }) 421 if err != nil { 422 b.Fatalf("expected no error, but got %v", err) 423 } 424 }) 425 } 426 }) 427 } 428} 429