1// +build go1.7
2
3package s3manager
4
5import (
6	"bytes"
7	"fmt"
8	"io"
9	"io/ioutil"
10	random "math/rand"
11	"net/http"
12	"strconv"
13	"sync"
14	"sync/atomic"
15	"testing"
16
17	"github.com/aws/aws-sdk-go/aws"
18	"github.com/aws/aws-sdk-go/aws/request"
19	"github.com/aws/aws-sdk-go/awstesting/unit"
20	"github.com/aws/aws-sdk-go/internal/sdkio"
21	"github.com/aws/aws-sdk-go/service/s3"
22	"github.com/aws/aws-sdk-go/service/s3/internal/s3testing"
23)
24
// respBody is a canned CompleteMultipartUpload XML payload returned by the
// stubbed Send handlers below so responses unmarshal without a network call.
const respBody = `<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadOutput>
   <Location>mockValue</Location>
   <Bucket>mockValue</Bucket>
   <Key>mockValue</Key>
   <ETag>mockValue</ETag>
</CompleteMultipartUploadOutput>`
32
33type testReader struct {
34	br *bytes.Reader
35	m  sync.Mutex
36}
37
38func (r *testReader) Read(p []byte) (n int, err error) {
39	r.m.Lock()
40	defer r.m.Unlock()
41	return r.br.Read(p)
42}
43
44func TestUploadByteSlicePool(t *testing.T) {
45	cases := map[string]struct {
46		PartSize      int64
47		FileSize      int64
48		Concurrency   int
49		ExAllocations uint64
50	}{
51		"single part, single concurrency": {
52			PartSize:      sdkio.MebiByte * 5,
53			FileSize:      sdkio.MebiByte * 5,
54			ExAllocations: 2,
55			Concurrency:   1,
56		},
57		"multi-part, single concurrency": {
58			PartSize:      sdkio.MebiByte * 5,
59			FileSize:      sdkio.MebiByte * 10,
60			ExAllocations: 2,
61			Concurrency:   1,
62		},
63		"multi-part, multiple concurrency": {
64			PartSize:      sdkio.MebiByte * 5,
65			FileSize:      sdkio.MebiByte * 20,
66			ExAllocations: 3,
67			Concurrency:   2,
68		},
69	}
70
71	for name, tt := range cases {
72		t.Run(name, func(t *testing.T) {
73			var p *recordedPartPool
74
75			unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
76				p = newRecordedPartPool(sliceSize)
77				return p
78			})
79			defer unswap()
80
81			sess := unit.Session.Copy()
82			svc := s3.New(sess)
83			svc.Handlers.Unmarshal.Clear()
84			svc.Handlers.UnmarshalMeta.Clear()
85			svc.Handlers.UnmarshalError.Clear()
86			svc.Handlers.Send.Clear()
87			svc.Handlers.Send.PushFront(func(r *request.Request) {
88				if r.Body != nil {
89					io.Copy(ioutil.Discard, r.Body)
90				}
91
92				r.HTTPResponse = &http.Response{
93					StatusCode: 200,
94					Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
95				}
96
97				switch data := r.Data.(type) {
98				case *s3.CreateMultipartUploadOutput:
99					data.UploadId = aws.String("UPLOAD-ID")
100				case *s3.UploadPartOutput:
101					data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
102				case *s3.CompleteMultipartUploadOutput:
103					data.Location = aws.String("https://location")
104					data.VersionId = aws.String("VERSION-ID")
105				case *s3.PutObjectOutput:
106					data.VersionId = aws.String("VERSION-ID")
107				}
108			})
109
110			uploader := NewUploaderWithClient(svc, func(u *Uploader) {
111				u.PartSize = tt.PartSize
112				u.Concurrency = tt.Concurrency
113			})
114
115			expected := s3testing.GetTestBytes(int(tt.FileSize))
116			_, err := uploader.Upload(&UploadInput{
117				Bucket: aws.String("bucket"),
118				Key:    aws.String("key"),
119				Body:   &testReader{br: bytes.NewReader(expected)},
120			})
121			if err != nil {
122				t.Errorf("expected no error, but got %v", err)
123			}
124
125			if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
126				t.Fatalf("expected zero outsnatding pool parts, got %d", v)
127			}
128
129			gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs)
130
131			t.Logf("total gets %v, total allocations %v", gets, allocs)
132			if e, a := tt.ExAllocations, allocs; a > e {
133				t.Errorf("expected %v allocations, got %v", e, a)
134			}
135		})
136	}
137}
138
139func TestUploadByteSlicePool_Failures(t *testing.T) {
140	cases := map[string]struct {
141		PartSize   int64
142		FileSize   int64
143		Operations []string
144	}{
145		"single part": {
146			PartSize: sdkio.MebiByte * 5,
147			FileSize: sdkio.MebiByte * 4,
148			Operations: []string{
149				"PutObject",
150			},
151		},
152		"multi-part": {
153			PartSize: sdkio.MebiByte * 5,
154			FileSize: sdkio.MebiByte * 10,
155			Operations: []string{
156				"CreateMultipartUpload",
157				"UploadPart",
158				"CompleteMultipartUpload",
159			},
160		},
161	}
162
163	for name, tt := range cases {
164		t.Run(name, func(t *testing.T) {
165			for _, operation := range tt.Operations {
166				t.Run(operation, func(t *testing.T) {
167					var p *recordedPartPool
168
169					unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
170						p = newRecordedPartPool(sliceSize)
171						return p
172					})
173					defer unswap()
174
175					sess := unit.Session.Copy()
176					svc := s3.New(sess)
177					svc.Handlers.Unmarshal.Clear()
178					svc.Handlers.UnmarshalMeta.Clear()
179					svc.Handlers.UnmarshalError.Clear()
180					svc.Handlers.Send.Clear()
181					svc.Handlers.Send.PushFront(func(r *request.Request) {
182						if r.Body != nil {
183							io.Copy(ioutil.Discard, r.Body)
184						}
185
186						if r.Operation.Name == operation {
187							r.Retryable = aws.Bool(false)
188							r.Error = fmt.Errorf("request error")
189							r.HTTPResponse = &http.Response{
190								StatusCode: 500,
191								Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
192							}
193							return
194						}
195
196						r.HTTPResponse = &http.Response{
197							StatusCode: 200,
198							Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
199						}
200
201						switch data := r.Data.(type) {
202						case *s3.CreateMultipartUploadOutput:
203							data.UploadId = aws.String("UPLOAD-ID")
204						case *s3.UploadPartOutput:
205							data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
206						case *s3.CompleteMultipartUploadOutput:
207							data.Location = aws.String("https://location")
208							data.VersionId = aws.String("VERSION-ID")
209						case *s3.PutObjectOutput:
210							data.VersionId = aws.String("VERSION-ID")
211						}
212					})
213
214					uploader := NewUploaderWithClient(svc, func(u *Uploader) {
215						u.Concurrency = 1
216						u.PartSize = tt.PartSize
217					})
218
219					expected := s3testing.GetTestBytes(int(tt.FileSize))
220					_, err := uploader.Upload(&UploadInput{
221						Bucket: aws.String("bucket"),
222						Key:    aws.String("key"),
223						Body:   &testReader{br: bytes.NewReader(expected)},
224					})
225					if err == nil {
226						t.Fatalf("expected error but got none")
227					}
228
229					if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
230						t.Fatalf("expected zero outsnatding pool parts, got %d", v)
231					}
232				})
233			}
234		})
235	}
236}
237
238func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) {
239	var (
240		pools []*recordedPartPool
241		mtx   sync.Mutex
242	)
243
244	unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
245		mtx.Lock()
246		defer mtx.Unlock()
247		b := newRecordedPartPool(sliceSize)
248		pools = append(pools, b)
249		return b
250	})
251	defer unswap()
252
253	sess := unit.Session.Copy()
254	svc := s3.New(sess)
255	svc.Handlers.Unmarshal.Clear()
256	svc.Handlers.UnmarshalMeta.Clear()
257	svc.Handlers.UnmarshalError.Clear()
258	svc.Handlers.Send.Clear()
259	svc.Handlers.Send.PushFront(func(r *request.Request) {
260		if r.Body != nil {
261			io.Copy(ioutil.Discard, r.Body)
262		}
263
264		r.HTTPResponse = &http.Response{
265			StatusCode: 200,
266			Body:       ioutil.NopCloser(bytes.NewReader([]byte(respBody))),
267		}
268
269		switch data := r.Data.(type) {
270		case *s3.CreateMultipartUploadOutput:
271			data.UploadId = aws.String("UPLOAD-ID")
272		case *s3.UploadPartOutput:
273			data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
274		case *s3.CompleteMultipartUploadOutput:
275			data.Location = aws.String("https://location")
276			data.VersionId = aws.String("VERSION-ID")
277		case *s3.PutObjectOutput:
278			data.VersionId = aws.String("VERSION-ID")
279		}
280	})
281
282	uploader := NewUploaderWithClient(svc, func(u *Uploader) {
283		u.PartSize = 5 * sdkio.MebiByte
284		u.Concurrency = 2
285	})
286
287	var wg sync.WaitGroup
288	for i := 0; i < 2; i++ {
289		wg.Add(2)
290		go func() {
291			defer wg.Done()
292			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
293			_, err := uploader.Upload(&UploadInput{
294				Bucket: aws.String("bucket"),
295				Key:    aws.String("key"),
296				Body:   &testReader{br: bytes.NewReader(expected)},
297			})
298			if err != nil {
299				t.Errorf("expected no error, but got %v", err)
300			}
301		}()
302		go func() {
303			defer wg.Done()
304			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
305			_, err := uploader.Upload(&UploadInput{
306				Bucket: aws.String("bucket"),
307				Key:    aws.String("key"),
308				Body:   &testReader{br: bytes.NewReader(expected)},
309			}, func(u *Uploader) {
310				u.PartSize = 6 * sdkio.MebiByte
311			})
312			if err != nil {
313				t.Errorf("expected no error, but got %v", err)
314			}
315		}()
316	}
317
318	wg.Wait()
319
320	if e, a := 3, len(pools); e != a {
321		t.Errorf("expected %v, got %v", e, a)
322	}
323
324	for _, p := range pools {
325		if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
326			t.Fatalf("expected zero outsnatding pool parts, got %d", v)
327		}
328
329		t.Logf("total gets %v, total allocations %v",
330			atomic.LoadUint64(&p.recordedGets),
331			atomic.LoadUint64(&p.recordedAllocs))
332	}
333}
334
335func BenchmarkPools(b *testing.B) {
336	cases := []struct {
337		PartSize      int64
338		FileSize      int64
339		Concurrency   int
340		ExAllocations uint64
341	}{
342		0: {
343			PartSize:    sdkio.MebiByte * 5,
344			FileSize:    sdkio.MebiByte * 5,
345			Concurrency: 1,
346		},
347		1: {
348			PartSize:    sdkio.MebiByte * 5,
349			FileSize:    sdkio.MebiByte * 10,
350			Concurrency: 1,
351		},
352		2: {
353			PartSize:    sdkio.MebiByte * 5,
354			FileSize:    sdkio.MebiByte * 20,
355			Concurrency: 2,
356		},
357		3: {
358			PartSize:    sdkio.MebiByte * 5,
359			FileSize:    sdkio.MebiByte * 250,
360			Concurrency: 10,
361		},
362	}
363
364	sess := unit.Session.Copy()
365	svc := s3.New(sess)
366	svc.Handlers.Unmarshal.Clear()
367	svc.Handlers.UnmarshalMeta.Clear()
368	svc.Handlers.UnmarshalError.Clear()
369	svc.Handlers.Send.Clear()
370	svc.Handlers.Send.PushFront(func(r *request.Request) {
371		if r.Body != nil {
372			io.Copy(ioutil.Discard, r.Body)
373		}
374
375		r.HTTPResponse = &http.Response{
376			StatusCode: 200,
377			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
378		}
379
380		switch data := r.Data.(type) {
381		case *s3.CreateMultipartUploadOutput:
382			data.UploadId = aws.String("UPLOAD-ID")
383		case *s3.UploadPartOutput:
384			data.ETag = aws.String(fmt.Sprintf("ETAG%d", random.Int()))
385		case *s3.CompleteMultipartUploadOutput:
386			data.Location = aws.String("https://location")
387			data.VersionId = aws.String("VERSION-ID")
388		case *s3.PutObjectOutput:
389			data.VersionId = aws.String("VERSION-ID")
390		}
391	})
392
393	pools := map[string]func(sliceSize int64) byteSlicePool{
394		"sync.Pool": func(sliceSize int64) byteSlicePool {
395			return newSyncSlicePool(sliceSize)
396		},
397		"custom": func(sliceSize int64) byteSlicePool {
398			return newMaxSlicePool(sliceSize)
399		},
400	}
401
402	for name, poolFunc := range pools {
403		b.Run(name, func(b *testing.B) {
404			unswap := swapByteSlicePool(poolFunc)
405			defer unswap()
406			for i, c := range cases {
407				b.Run(strconv.Itoa(i), func(b *testing.B) {
408					uploader := NewUploaderWithClient(svc, func(u *Uploader) {
409						u.PartSize = c.PartSize
410						u.Concurrency = c.Concurrency
411					})
412
413					expected := s3testing.GetTestBytes(int(c.FileSize))
414					b.ResetTimer()
415					_, err := uploader.Upload(&UploadInput{
416						Bucket: aws.String("bucket"),
417						Key:    aws.String("key"),
418						Body:   &testReader{br: bytes.NewReader(expected)},
419					})
420					if err != nil {
421						b.Fatalf("expected no error, but got %v", err)
422					}
423				})
424			}
425		})
426	}
427}
428