1// +build go1.7 2 3package s3manager 4 5import ( 6 "bytes" 7 "fmt" 8 "io" 9 "io/ioutil" 10 random "math/rand" 11 "net/http" 12 "strconv" 13 "sync" 14 "sync/atomic" 15 "testing" 16 17 "github.com/aws/aws-sdk-go/aws" 18 "github.com/aws/aws-sdk-go/aws/request" 19 "github.com/aws/aws-sdk-go/awstesting/unit" 20 "github.com/aws/aws-sdk-go/internal/sdkio" 21 "github.com/aws/aws-sdk-go/service/s3" 22 "github.com/aws/aws-sdk-go/service/s3/internal/s3testing" 23) 24 25const respBody = `<?xml version="1.0" encoding="UTF-8"?> 26<CompleteMultipartUploadOutput> 27 <Location>mockValue</Location> 28 <Bucket>mockValue</Bucket> 29 <Key>mockValue</Key> 30 <ETag>mockValue</ETag> 31</CompleteMultipartUploadOutput>` 32 33type testReader struct { 34 br *bytes.Reader 35 m sync.Mutex 36} 37 38func (r *testReader) Read(p []byte) (n int, err error) { 39 r.m.Lock() 40 defer r.m.Unlock() 41 return r.br.Read(p) 42} 43 44func TestUploadByteSlicePool(t *testing.T) { 45 cases := map[string]struct { 46 PartSize int64 47 FileSize int64 48 Concurrency int 49 ExAllocations uint64 50 }{ 51 "single part, single concurrency": { 52 PartSize: sdkio.MebiByte * 5, 53 FileSize: sdkio.MebiByte * 5, 54 ExAllocations: 2, 55 Concurrency: 1, 56 }, 57 "multi-part, single concurrency": { 58 PartSize: sdkio.MebiByte * 5, 59 FileSize: sdkio.MebiByte * 10, 60 ExAllocations: 2, 61 Concurrency: 1, 62 }, 63 "multi-part, multiple concurrency": { 64 PartSize: sdkio.MebiByte * 5, 65 FileSize: sdkio.MebiByte * 20, 66 ExAllocations: 3, 67 Concurrency: 2, 68 }, 69 } 70 71 for name, tt := range cases { 72 t.Run(name, func(t *testing.T) { 73 var p *recordedPartPool 74 75 unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { 76 p = newRecordedPartPool(sliceSize) 77 return p 78 }) 79 defer unswap() 80 81 sess := unit.Session.Copy() 82 svc := s3.New(sess) 83 svc.Handlers.Unmarshal.Clear() 84 svc.Handlers.UnmarshalMeta.Clear() 85 svc.Handlers.UnmarshalError.Clear() 86 svc.Handlers.Send.Clear() 87 svc.Handlers.Send.PushFront(func(r *request.Request) { 88 if r.Body != nil { 89 io.Copy(ioutil.Discard, r.Body) 90 } 91 92 r.HTTPResponse = &http.Response{ 93 StatusCode: 200, 94 Body: ioutil.NopCloser(bytes.NewReader([]byte(respBody))), 95 } 96 97 switch data := r.Data.(type) { 98 case *s3.CreateMultipartUploadOutput: 99 data.UploadId = aws.String("UPLOAD-ID") 100 case *s3.UploadPartOutput: 101 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 102 case *s3.CompleteMultipartUploadOutput: 103 data.Location = aws.String("https://location") 104 data.VersionId = aws.String("VERSION-ID") 105 case *s3.PutObjectOutput: 106 data.VersionId = aws.String("VERSION-ID") 107 } 108 }) 109 110 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 111 u.PartSize = tt.PartSize 112 u.Concurrency = tt.Concurrency 113 }) 114 115 expected := s3testing.GetTestBytes(int(tt.FileSize)) 116 _, err := uploader.Upload(&UploadInput{ 117 Bucket: aws.String("bucket"), 118 Key: aws.String("key"), 119 Body: &testReader{br: bytes.NewReader(expected)}, 120 }) 121 if err != nil { 122 t.Errorf("expected no error, but got %v", err) 123 } 124 125 if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 { 126 t.Fatalf("expected zero outsnatding pool parts, got %d", v) 127 } 128 129 gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs) 130 131 t.Logf("total gets %v, total allocations %v", gets, allocs) 132 if e, a := tt.ExAllocations, allocs; a > e { 133 t.Errorf("expected %v allocations, got %v", e, a) 134 } 135 }) 136 } 137} 138 139func TestUploadByteSlicePool_Failures(t *testing.T) { 140 cases := map[string]struct { 141 PartSize int64 142 FileSize int64 143 Operations []string 144 }{ 145 "single part": { 146 PartSize: sdkio.MebiByte * 5, 147 FileSize: sdkio.MebiByte * 4, 148 Operations: []string{ 149 "PutObject", 150 }, 151 }, 152 "multi-part": { 153 PartSize: sdkio.MebiByte * 5, 154 FileSize: sdkio.MebiByte * 10, 155 Operations: []string{ 156 "CreateMultipartUpload", 157 "UploadPart", 158 "CompleteMultipartUpload", 159 }, 160 }, 161 } 162 163 for name, tt := range cases { 164 t.Run(name, func(t *testing.T) { 165 for _, operation := range tt.Operations { 166 t.Run(operation, func(t *testing.T) { 167 var p *recordedPartPool 168 169 unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { 170 p = newRecordedPartPool(sliceSize) 171 return p 172 }) 173 defer unswap() 174 175 sess := unit.Session.Copy() 176 svc := s3.New(sess) 177 svc.Handlers.Unmarshal.Clear() 178 svc.Handlers.UnmarshalMeta.Clear() 179 svc.Handlers.UnmarshalError.Clear() 180 svc.Handlers.Send.Clear() 181 svc.Handlers.Send.PushFront(func(r *request.Request) { 182 if r.Body != nil { 183 io.Copy(ioutil.Discard, r.Body) 184 } 185 186 if r.Operation.Name == operation { 187 r.Retryable = aws.Bool(false) 188 r.Error = fmt.Errorf("request error") 189 r.HTTPResponse = &http.Response{ 190 StatusCode: 500, 191 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 192 } 193 return 194 } 195 196 r.HTTPResponse = &http.Response{ 197 StatusCode: 200, 198 Body: ioutil.NopCloser(bytes.NewReader([]byte(respBody))), 199 } 200 201 switch data := r.Data.(type) { 202 case *s3.CreateMultipartUploadOutput: 203 data.UploadId = aws.String("UPLOAD-ID") 204 case *s3.UploadPartOutput: 205 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 206 case *s3.CompleteMultipartUploadOutput: 207 data.Location = aws.String("https://location") 208 data.VersionId = aws.String("VERSION-ID") 209 case *s3.PutObjectOutput: 210 data.VersionId = aws.String("VERSION-ID") 211 } 212 }) 213 214 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 215 u.Concurrency = 1 216 u.PartSize = tt.PartSize 217 }) 218 219 expected := s3testing.GetTestBytes(int(tt.FileSize)) 220 _, err := uploader.Upload(&UploadInput{ 221 Bucket: aws.String("bucket"), 222 Key: aws.String("key"), 223 Body: &testReader{br: bytes.NewReader(expected)}, 224 }) 225 if err == nil { 226 t.Fatalf("expected error but got none") 227 } 228 229 if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 { 230 t.Fatalf("expected zero outsnatding pool parts, got %d", v) 231 } 232 }) 233 } 234 }) 235 } 236} 237 238func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) { 239 var ( 240 pools []*recordedPartPool 241 mtx sync.Mutex 242 ) 243 244 unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool { 245 mtx.Lock() 246 defer mtx.Unlock() 247 b := newRecordedPartPool(sliceSize) 248 pools = append(pools, b) 249 return b 250 }) 251 defer unswap() 252 253 sess := unit.Session.Copy() 254 svc := s3.New(sess) 255 svc.Handlers.Unmarshal.Clear() 256 svc.Handlers.UnmarshalMeta.Clear() 257 svc.Handlers.UnmarshalError.Clear() 258 svc.Handlers.Send.Clear() 259 svc.Handlers.Send.PushFront(func(r *request.Request) { 260 if r.Body != nil { 261 io.Copy(ioutil.Discard, r.Body) 262 } 263 264 r.HTTPResponse = &http.Response{ 265 StatusCode: 200, 266 Body: ioutil.NopCloser(bytes.NewReader([]byte(respBody))), 267 } 268 269 switch data := r.Data.(type) { 270 case *s3.CreateMultipartUploadOutput: 271 data.UploadId = aws.String("UPLOAD-ID") 272 case *s3.UploadPartOutput: 273 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 274 case *s3.CompleteMultipartUploadOutput: 275 data.Location = aws.String("https://location") 276 data.VersionId = aws.String("VERSION-ID") 277 case *s3.PutObjectOutput: 278 data.VersionId = aws.String("VERSION-ID") 279 } 280 }) 281 282 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 283 u.PartSize = 5 * sdkio.MebiByte 284 u.Concurrency = 2 285 }) 286 287 var wg sync.WaitGroup 288 for i := 0; i < 2; i++ { 289 wg.Add(2) 290 go func() { 291 defer wg.Done() 292 expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte)) 293 _, err := uploader.Upload(&UploadInput{ 294 Bucket: aws.String("bucket"), 295 Key: aws.String("key"), 296 Body: &testReader{br: bytes.NewReader(expected)}, 297 }) 298 if err != nil { 299 t.Errorf("expected no error, but got %v", err) 300 } 301 }() 302 go func() { 303 defer wg.Done() 304 expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte)) 305 _, err := uploader.Upload(&UploadInput{ 306 Bucket: aws.String("bucket"), 307 Key: aws.String("key"), 308 Body: &testReader{br: bytes.NewReader(expected)}, 309 }, func(u *Uploader) { 310 u.PartSize = 6 * sdkio.MebiByte 311 }) 312 if err != nil { 313 t.Errorf("expected no error, but got %v", err) 314 } 315 }() 316 } 317 318 wg.Wait() 319 320 if e, a := 3, len(pools); e != a { 321 t.Errorf("expected %v, got %v", e, a) 322 } 323 324 for _, p := range pools { 325 if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 { 326 t.Fatalf("expected zero outsnatding pool parts, got %d", v) 327 } 328 329 t.Logf("total gets %v, total allocations %v", 330 atomic.LoadUint64(&p.recordedGets), 331 atomic.LoadUint64(&p.recordedAllocs)) 332 } 333} 334 335func BenchmarkPools(b *testing.B) { 336 cases := []struct { 337 PartSize int64 338 FileSize int64 339 Concurrency int 340 ExAllocations uint64 341 }{ 342 0: { 343 PartSize: sdkio.MebiByte * 5, 344 FileSize: sdkio.MebiByte * 5, 345 Concurrency: 1, 346 }, 347 1: { 348 PartSize: sdkio.MebiByte * 5, 349 FileSize: sdkio.MebiByte * 10, 350 Concurrency: 1, 351 }, 352 2: { 353 PartSize: sdkio.MebiByte * 5, 354 FileSize: sdkio.MebiByte * 20, 355 Concurrency: 2, 356 }, 357 3: { 358 PartSize: sdkio.MebiByte * 5, 359 FileSize: sdkio.MebiByte * 250, 360 Concurrency: 10, 361 }, 362 } 363 364 sess := unit.Session.Copy() 365 svc := s3.New(sess) 366 svc.Handlers.Unmarshal.Clear() 367 svc.Handlers.UnmarshalMeta.Clear() 368 svc.Handlers.UnmarshalError.Clear() 369 svc.Handlers.Send.Clear() 370 svc.Handlers.Send.PushFront(func(r *request.Request) { 371 if r.Body != nil { 372 io.Copy(ioutil.Discard, r.Body) 373 } 374 375 r.HTTPResponse = &http.Response{ 376 StatusCode: 200, 377 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 378 } 379 380 switch data := r.Data.(type) { 381 case *s3.CreateMultipartUploadOutput: 382 data.UploadId = aws.String("UPLOAD-ID") 383 case *s3.UploadPartOutput: 384 data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int())) 385 case *s3.CompleteMultipartUploadOutput: 386 data.Location = aws.String("https://location") 387 data.VersionId = aws.String("VERSION-ID") 388 case *s3.PutObjectOutput: 389 data.VersionId = aws.String("VERSION-ID") 390 } 391 }) 392 393 pools := map[string]func(sliceSize int64) byteSlicePool{ 394 "sync.Pool": func(sliceSize int64) byteSlicePool { 395 return newSyncSlicePool(sliceSize) 396 }, 397 "custom": func(sliceSize int64) byteSlicePool { 398 return newMaxSlicePool(sliceSize) 399 }, 400 } 401 402 for name, poolFunc := range pools { 403 b.Run(name, func(b *testing.B) { 404 unswap := swapByteSlicePool(poolFunc) 405 defer unswap() 406 for i, c := range cases { 407 b.Run(strconv.Itoa(i), func(b *testing.B) { 408 uploader := NewUploaderWithClient(svc, func(u *Uploader) { 409 u.PartSize = c.PartSize 410 u.Concurrency = c.Concurrency 411 }) 412 413 expected := s3testing.GetTestBytes(int(c.FileSize)) 414 b.ResetTimer() 415 _, err := uploader.Upload(&UploadInput{ 416 Bucket: aws.String("bucket"), 417 Key: aws.String("key"), 418 Body: &testReader{br: bytes.NewReader(expected)}, 419 }) 420 if err != nil { 421 b.Fatalf("expected no error, but got %v", err) 422 } 423 }) 424 } 425 }) 426 } 427} 428