1//go:build go1.7
2// +build go1.7
3
4package s3manager
5
6import (
7	"bytes"
8	"fmt"
9	"io"
10	"io/ioutil"
11	random "math/rand"
12	"net/http"
13	"strconv"
14	"sync"
15	"sync/atomic"
16	"testing"
17
18	"github.com/aws/aws-sdk-go/aws"
19	"github.com/aws/aws-sdk-go/aws/request"
20	"github.com/aws/aws-sdk-go/awstesting/unit"
21	"github.com/aws/aws-sdk-go/internal/sdkio"
22	"github.com/aws/aws-sdk-go/service/s3"
23	"github.com/aws/aws-sdk-go/service/s3/internal/s3testing"
24)
25
// respBody is the canned CompleteMultipartUpload XML body that the mocked
// Send handlers return for every successful request; the individual output
// fields asserted by the tests are overwritten via r.Data in the handlers.
const respBody = `<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadOutput>
   <Location>mockValue</Location>
   <Bucket>mockValue</Bucket>
   <Key>mockValue</Key>
   <ETag>mockValue</ETag>
</CompleteMultipartUploadOutput>`
33
// testReader wraps a *bytes.Reader with a mutex so that a single upload
// body can be read by multiple uploader goroutines without a data race.
type testReader struct {
	br *bytes.Reader // underlying data source shared across goroutines
	m  sync.Mutex    // serializes concurrent Read calls
}
38
39func (r *testReader) Read(p []byte) (n int, err error) {
40	r.m.Lock()
41	defer r.m.Unlock()
42	return r.br.Read(p)
43}
44
45func TestUploadByteSlicePool(t *testing.T) {
46	cases := map[string]struct {
47		PartSize      int64
48		FileSize      int64
49		Concurrency   int
50		ExAllocations uint64
51	}{
52		"single part, single concurrency": {
53			PartSize:      sdkio.MebiByte * 5,
54			FileSize:      sdkio.MebiByte * 5,
55			ExAllocations: 2,
56			Concurrency:   1,
57		},
58		"multi-part, single concurrency": {
59			PartSize:      sdkio.MebiByte * 5,
60			FileSize:      sdkio.MebiByte * 10,
61			ExAllocations: 2,
62			Concurrency:   1,
63		},
64		"multi-part, multiple concurrency": {
65			PartSize:      sdkio.MebiByte * 5,
66			FileSize:      sdkio.MebiByte * 20,
67			ExAllocations: 3,
68			Concurrency:   2,
69		},
70	}
71
72	for name, tt := range cases {
73		t.Run(name, func(t *testing.T) {
74			var p *recordedPartPool
75
76			unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
77				p = newRecordedPartPool(sliceSize)
78				return p
79			})
80			defer unswap()
81
82			sess := unit.Session.Copy()
83			svc := s3.New(sess)
84			svc.Handlers.Unmarshal.Clear()
85			svc.Handlers.UnmarshalMeta.Clear()
86			svc.Handlers.UnmarshalError.Clear()
87			svc.Handlers.Send.Clear()
88			svc.Handlers.Send.PushFront(func(r *request.Request) {
89				if r.Body != nil {
90					io.Copy(ioutil.Discard, r.Body)
91				}
92
93				r.HTTPResponse = &http.Response{
94					StatusCode: 200,
95					Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
96				}
97
98				switch data := r.Data.(type) {
99				case *s3.CreateMultipartUploadOutput:
100					data.UploadId = aws.String("UPLOAD-ID")
101				case *s3.UploadPartOutput:
102					data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
103				case *s3.CompleteMultipartUploadOutput:
104					data.Location = aws.String("https://location")
105					data.VersionId = aws.String("VERSION-ID")
106				case *s3.PutObjectOutput:
107					data.VersionId = aws.String("VERSION-ID")
108				}
109			})
110
111			uploader := NewUploaderWithClient(svc, func(u *Uploader) {
112				u.PartSize = tt.PartSize
113				u.Concurrency = tt.Concurrency
114			})
115
116			expected := s3testing.GetTestBytes(int(tt.FileSize))
117			_, err := uploader.Upload(&UploadInput{
118				Bucket: aws.String("bucket"),
119				Key:    aws.String("key"),
120				Body:   &testReader{br: bytes.NewReader(expected)},
121			})
122			if err != nil {
123				t.Errorf("expected no error, but got %v", err)
124			}
125
126			if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
127				t.Fatalf("expected zero outsnatding pool parts, got %d", v)
128			}
129
130			gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs)
131
132			t.Logf("total gets %v, total allocations %v", gets, allocs)
133			if e, a := tt.ExAllocations, allocs; a > e {
134				t.Errorf("expected %v allocations, got %v", e, a)
135			}
136		})
137	}
138}
139
140func TestUploadByteSlicePool_Failures(t *testing.T) {
141	cases := map[string]struct {
142		PartSize   int64
143		FileSize   int64
144		Operations []string
145	}{
146		"single part": {
147			PartSize: sdkio.MebiByte * 5,
148			FileSize: sdkio.MebiByte * 4,
149			Operations: []string{
150				"PutObject",
151			},
152		},
153		"multi-part": {
154			PartSize: sdkio.MebiByte * 5,
155			FileSize: sdkio.MebiByte * 10,
156			Operations: []string{
157				"CreateMultipartUpload",
158				"UploadPart",
159				"CompleteMultipartUpload",
160			},
161		},
162	}
163
164	for name, tt := range cases {
165		t.Run(name, func(t *testing.T) {
166			for _, operation := range tt.Operations {
167				t.Run(operation, func(t *testing.T) {
168					var p *recordedPartPool
169
170					unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
171						p = newRecordedPartPool(sliceSize)
172						return p
173					})
174					defer unswap()
175
176					sess := unit.Session.Copy()
177					svc := s3.New(sess)
178					svc.Handlers.Unmarshal.Clear()
179					svc.Handlers.UnmarshalMeta.Clear()
180					svc.Handlers.UnmarshalError.Clear()
181					svc.Handlers.Send.Clear()
182					svc.Handlers.Send.PushFront(func(r *request.Request) {
183						if r.Body != nil {
184							io.Copy(ioutil.Discard, r.Body)
185						}
186
187						if r.Operation.Name == operation {
188							r.Retryable = aws.Bool(false)
189							r.Error = fmt.Errorf("request error")
190							r.HTTPResponse = &http.Response{
191								StatusCode: 500,
192								Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
193							}
194							return
195						}
196
197						r.HTTPResponse = &http.Response{
198							StatusCode: 200,
199							Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
200						}
201
202						switch data := r.Data.(type) {
203						case *s3.CreateMultipartUploadOutput:
204							data.UploadId = aws.String("UPLOAD-ID")
205						case *s3.UploadPartOutput:
206							data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
207						case *s3.CompleteMultipartUploadOutput:
208							data.Location = aws.String("https://location")
209							data.VersionId = aws.String("VERSION-ID")
210						case *s3.PutObjectOutput:
211							data.VersionId = aws.String("VERSION-ID")
212						}
213					})
214
215					uploader := NewUploaderWithClient(svc, func(u *Uploader) {
216						u.Concurrency = 1
217						u.PartSize = tt.PartSize
218					})
219
220					expected := s3testing.GetTestBytes(int(tt.FileSize))
221					_, err := uploader.Upload(&UploadInput{
222						Bucket: aws.String("bucket"),
223						Key:    aws.String("key"),
224						Body:   &testReader{br: bytes.NewReader(expected)},
225					})
226					if err == nil {
227						t.Fatalf("expected error but got none")
228					}
229
230					if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
231						t.Fatalf("expected zero outsnatding pool parts, got %d", v)
232					}
233				})
234			}
235		})
236	}
237}
238
239func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) {
240	var (
241		pools []*recordedPartPool
242		mtx   sync.Mutex
243	)
244
245	unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
246		mtx.Lock()
247		defer mtx.Unlock()
248		b := newRecordedPartPool(sliceSize)
249		pools = append(pools, b)
250		return b
251	})
252	defer unswap()
253
254	sess := unit.Session.Copy()
255	svc := s3.New(sess)
256	svc.Handlers.Unmarshal.Clear()
257	svc.Handlers.UnmarshalMeta.Clear()
258	svc.Handlers.UnmarshalError.Clear()
259	svc.Handlers.Send.Clear()
260	svc.Handlers.Send.PushFront(func(r *request.Request) {
261		if r.Body != nil {
262			io.Copy(ioutil.Discard, r.Body)
263		}
264
265		r.HTTPResponse = &http.Response{
266			StatusCode: 200,
267			Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
268		}
269
270		switch data := r.Data.(type) {
271		case *s3.CreateMultipartUploadOutput:
272			data.UploadId = aws.String("UPLOAD-ID")
273		case *s3.UploadPartOutput:
274			data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
275		case *s3.CompleteMultipartUploadOutput:
276			data.Location = aws.String("https://location")
277			data.VersionId = aws.String("VERSION-ID")
278		case *s3.PutObjectOutput:
279			data.VersionId = aws.String("VERSION-ID")
280		}
281	})
282
283	uploader := NewUploaderWithClient(svc, func(u *Uploader) {
284		u.PartSize = 5 * sdkio.MebiByte
285		u.Concurrency = 2
286	})
287
288	var wg sync.WaitGroup
289	for i := 0; i < 2; i++ {
290		wg.Add(2)
291		go func() {
292			defer wg.Done()
293			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
294			_, err := uploader.Upload(&UploadInput{
295				Bucket: aws.String("bucket"),
296				Key:    aws.String("key"),
297				Body:   &testReader{br: bytes.NewReader(expected)},
298			})
299			if err != nil {
300				t.Errorf("expected no error, but got %v", err)
301			}
302		}()
303		go func() {
304			defer wg.Done()
305			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
306			_, err := uploader.Upload(&UploadInput{
307				Bucket: aws.String("bucket"),
308				Key:    aws.String("key"),
309				Body:   &testReader{br: bytes.NewReader(expected)},
310			}, func(u *Uploader) {
311				u.PartSize = 6 * sdkio.MebiByte
312			})
313			if err != nil {
314				t.Errorf("expected no error, but got %v", err)
315			}
316		}()
317	}
318
319	wg.Wait()
320
321	if e, a := 3, len(pools); e != a {
322		t.Errorf("expected %v, got %v", e, a)
323	}
324
325	for _, p := range pools {
326		if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
327			t.Fatalf("expected zero outsnatding pool parts, got %d", v)
328		}
329
330		t.Logf("total gets %v, total allocations %v",
331			atomic.LoadUint64(&p.recordedGets),
332			atomic.LoadUint64(&p.recordedAllocs))
333	}
334}
335
336func BenchmarkPools(b *testing.B) {
337	cases := []struct {
338		PartSize      int64
339		FileSize      int64
340		Concurrency   int
341		ExAllocations uint64
342	}{
343		0: {
344			PartSize:    sdkio.MebiByte * 5,
345			FileSize:    sdkio.MebiByte * 5,
346			Concurrency: 1,
347		},
348		1: {
349			PartSize:    sdkio.MebiByte * 5,
350			FileSize:    sdkio.MebiByte * 10,
351			Concurrency: 1,
352		},
353		2: {
354			PartSize:    sdkio.MebiByte * 5,
355			FileSize:    sdkio.MebiByte * 20,
356			Concurrency: 2,
357		},
358		3: {
359			PartSize:    sdkio.MebiByte * 5,
360			FileSize:    sdkio.MebiByte * 250,
361			Concurrency: 10,
362		},
363	}
364
365	sess := unit.Session.Copy()
366	svc := s3.New(sess)
367	svc.Handlers.Unmarshal.Clear()
368	svc.Handlers.UnmarshalMeta.Clear()
369	svc.Handlers.UnmarshalError.Clear()
370	svc.Handlers.Send.Clear()
371	svc.Handlers.Send.PushFront(func(r *request.Request) {
372		if r.Body != nil {
373			io.Copy(ioutil.Discard, r.Body)
374		}
375
376		r.HTTPResponse = &http.Response{
377			StatusCode: 200,
378			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
379		}
380
381		switch data := r.Data.(type) {
382		case *s3.CreateMultipartUploadOutput:
383			data.UploadId = aws.String("UPLOAD-ID")
384		case *s3.UploadPartOutput:
385			data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
386		case *s3.CompleteMultipartUploadOutput:
387			data.Location = aws.String("https://location")
388			data.VersionId = aws.String("VERSION-ID")
389		case *s3.PutObjectOutput:
390			data.VersionId = aws.String("VERSION-ID")
391		}
392	})
393
394	pools := map[string]func(sliceSize int64) byteSlicePool{
395		"sync.Pool": func(sliceSize int64) byteSlicePool {
396			return newSyncSlicePool(sliceSize)
397		},
398		"custom": func(sliceSize int64) byteSlicePool {
399			return newMaxSlicePool(sliceSize)
400		},
401	}
402
403	for name, poolFunc := range pools {
404		b.Run(name, func(b *testing.B) {
405			unswap := swapByteSlicePool(poolFunc)
406			defer unswap()
407			for i, c := range cases {
408				b.Run(strconv.Itoa(i), func(b *testing.B) {
409					uploader := NewUploaderWithClient(svc, func(u *Uploader) {
410						u.PartSize = c.PartSize
411						u.Concurrency = c.Concurrency
412					})
413
414					expected := s3testing.GetTestBytes(int(c.FileSize))
415					b.ResetTimer()
416					_, err := uploader.Upload(&UploadInput{
417						Bucket: aws.String("bucket"),
418						Key:    aws.String("key"),
419						Body:   &testReader{br: bytes.NewReader(expected)},
420					})
421					if err != nil {
422						b.Fatalf("expected no error, but got %v", err)
423					}
424				})
425			}
426		})
427	}
428}
429