1package manager
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"strconv"
8	"sync"
9	"sync/atomic"
10	"testing"
11
12	"github.com/aws/aws-sdk-go-v2/aws"
13	s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
14	"github.com/aws/aws-sdk-go-v2/internal/sdkio"
15	"github.com/aws/aws-sdk-go-v2/service/s3"
16)
17
// testReader wraps a *bytes.Reader and exposes only a mutex-guarded Read.
// The reader is held in a named field rather than embedded, so none of the
// other bytes.Reader methods (Seek, ReadAt, Len, ...) are part of testReader's
// method set.
type testReader struct {
	// m serializes calls to Read.
	m sync.Mutex
	// br is the in-memory source that Read drains.
	br *bytes.Reader
}

// Read fills p from the wrapped bytes.Reader while holding the mutex and
// returns the byte count and error reported by that reader unchanged.
func (r *testReader) Read(p []byte) (n int, err error) {
	r.m.Lock()
	defer r.m.Unlock()

	n, err = r.br.Read(p)
	return n, err
}
28
29func TestUploadByteSlicePool(t *testing.T) {
30	cases := map[string]struct {
31		PartSize      int64
32		FileSize      int64
33		Concurrency   int
34		ExAllocations uint64
35	}{
36		"single part, single concurrency": {
37			PartSize:      sdkio.MebiByte * 5,
38			FileSize:      sdkio.MebiByte * 5,
39			ExAllocations: 2,
40			Concurrency:   1,
41		},
42		"multi-part, single concurrency": {
43			PartSize:      sdkio.MebiByte * 5,
44			FileSize:      sdkio.MebiByte * 10,
45			ExAllocations: 2,
46			Concurrency:   1,
47		},
48		"multi-part, multiple concurrency": {
49			PartSize:      sdkio.MebiByte * 5,
50			FileSize:      sdkio.MebiByte * 20,
51			ExAllocations: 3,
52			Concurrency:   2,
53		},
54	}
55
56	for name, tt := range cases {
57		t.Run(name, func(t *testing.T) {
58			var p *recordedPartPool
59
60			unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
61				p = newRecordedPartPool(sliceSize)
62				return p
63			})
64			defer unswap()
65
66			client, _, _ := s3testing.NewUploadLoggingClient(nil)
67
68			uploader := NewUploader(client, func(u *Uploader) {
69				u.PartSize = tt.PartSize
70				u.Concurrency = tt.Concurrency
71			})
72
73			expected := s3testing.GetTestBytes(int(tt.FileSize))
74			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
75				Bucket: aws.String("bucket"),
76				Key:    aws.String("key"),
77				Body:   &testReader{br: bytes.NewReader(expected)},
78			})
79			if err != nil {
80				t.Errorf("expected no error, but got %v", err)
81			}
82
83			if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
84				t.Fatalf("expected zero outsnatding pool parts, got %d", v)
85			}
86
87			gets, allocs := atomic.LoadUint64(&p.recordedGets), atomic.LoadUint64(&p.recordedAllocs)
88
89			t.Logf("total gets %v, total allocations %v", gets, allocs)
90			if e, a := tt.ExAllocations, allocs; a > e {
91				t.Errorf("expected %v allocations, got %v", e, a)
92			}
93		})
94	}
95}
96
97func TestUploadByteSlicePool_Failures(t *testing.T) {
98	const (
99		putObject               = "PutObject"
100		createMultipartUpload   = "CreateMultipartUpload"
101		uploadPart              = "UploadPart"
102		completeMultipartUpload = "CompleteMultipartUpload"
103	)
104
105	cases := map[string]struct {
106		PartSize   int64
107		FileSize   int64
108		Operations []string
109	}{
110		"single part": {
111			PartSize: sdkio.MebiByte * 5,
112			FileSize: sdkio.MebiByte * 4,
113			Operations: []string{
114				putObject,
115			},
116		},
117		"multi-part": {
118			PartSize: sdkio.MebiByte * 5,
119			FileSize: sdkio.MebiByte * 10,
120			Operations: []string{
121				createMultipartUpload,
122				uploadPart,
123				completeMultipartUpload,
124			},
125		},
126	}
127
128	for name, tt := range cases {
129		t.Run(name, func(t *testing.T) {
130			for _, operation := range tt.Operations {
131				t.Run(operation, func(t *testing.T) {
132					var p *recordedPartPool
133
134					unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
135						p = newRecordedPartPool(sliceSize)
136						return p
137					})
138					defer unswap()
139
140					client, _, _ := s3testing.NewUploadLoggingClient(nil)
141
142					switch operation {
143					case putObject:
144						client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
145							return nil, fmt.Errorf("put object failure")
146						}
147					case createMultipartUpload:
148						client.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
149							return nil, fmt.Errorf("create multipart upload failure")
150						}
151					case uploadPart:
152						client.UploadPartFn = func(*s3testing.UploadLoggingClient, *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
153							return nil, fmt.Errorf("upload part failure")
154						}
155					case completeMultipartUpload:
156						client.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
157							return nil, fmt.Errorf("complete multipart upload failure")
158						}
159					}
160
161					uploader := NewUploader(client, func(u *Uploader) {
162						u.Concurrency = 1
163						u.PartSize = tt.PartSize
164					})
165
166					expected := s3testing.GetTestBytes(int(tt.FileSize))
167					_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
168						Bucket: aws.String("bucket"),
169						Key:    aws.String("key"),
170						Body:   &testReader{br: bytes.NewReader(expected)},
171					})
172					if err == nil {
173						t.Fatalf("expected error but got none")
174					}
175
176					if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
177						t.Fatalf("expected zero outsnatding pool parts, got %d", v)
178					}
179				})
180			}
181		})
182	}
183}
184
185func TestUploadByteSlicePoolConcurrentMultiPartSize(t *testing.T) {
186	var (
187		pools []*recordedPartPool
188		mtx   sync.Mutex
189	)
190
191	unswap := swapByteSlicePool(func(sliceSize int64) byteSlicePool {
192		mtx.Lock()
193		defer mtx.Unlock()
194		b := newRecordedPartPool(sliceSize)
195		pools = append(pools, b)
196		return b
197	})
198	defer unswap()
199
200	client, _, _ := s3testing.NewUploadLoggingClient(nil)
201
202	uploader := NewUploader(client, func(u *Uploader) {
203		u.PartSize = 5 * sdkio.MebiByte
204		u.Concurrency = 2
205	})
206
207	var wg sync.WaitGroup
208	for i := 0; i < 2; i++ {
209		wg.Add(2)
210		go func() {
211			defer wg.Done()
212			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
213			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
214				Bucket: aws.String("bucket"),
215				Key:    aws.String("key"),
216				Body:   &testReader{br: bytes.NewReader(expected)},
217			})
218			if err != nil {
219				t.Errorf("expected no error, but got %v", err)
220			}
221		}()
222		go func() {
223			defer wg.Done()
224			expected := s3testing.GetTestBytes(int(15 * sdkio.MebiByte))
225			_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
226				Bucket: aws.String("bucket"),
227				Key:    aws.String("key"),
228				Body:   &testReader{br: bytes.NewReader(expected)},
229			}, func(u *Uploader) {
230				u.PartSize = 6 * sdkio.MebiByte
231			})
232			if err != nil {
233				t.Errorf("expected no error, but got %v", err)
234			}
235		}()
236	}
237
238	wg.Wait()
239
240	if e, a := 3, len(pools); e != a {
241		t.Errorf("expected %v, got %v", e, a)
242	}
243
244	for _, p := range pools {
245		if v := atomic.LoadInt64(&p.recordedOutstanding); v != 0 {
246			t.Fatalf("expected zero outsnatding pool parts, got %d", v)
247		}
248
249		t.Logf("total gets %v, total allocations %v",
250			atomic.LoadUint64(&p.recordedGets),
251			atomic.LoadUint64(&p.recordedAllocs))
252	}
253}
254
255func BenchmarkPools(b *testing.B) {
256	cases := []struct {
257		PartSize      int64
258		FileSize      int64
259		Concurrency   int
260		ExAllocations uint64
261	}{
262		0: {
263			PartSize:    sdkio.MebiByte * 5,
264			FileSize:    sdkio.MebiByte * 5,
265			Concurrency: 1,
266		},
267		1: {
268			PartSize:    sdkio.MebiByte * 5,
269			FileSize:    sdkio.MebiByte * 10,
270			Concurrency: 1,
271		},
272		2: {
273			PartSize:    sdkio.MebiByte * 5,
274			FileSize:    sdkio.MebiByte * 20,
275			Concurrency: 2,
276		},
277		3: {
278			PartSize:    sdkio.MebiByte * 5,
279			FileSize:    sdkio.MebiByte * 250,
280			Concurrency: 10,
281		},
282	}
283
284	client, _, _ := s3testing.NewUploadLoggingClient(nil)
285
286	pools := map[string]func(sliceSize int64) byteSlicePool{
287		"sync.Pool": func(sliceSize int64) byteSlicePool {
288			return newSyncSlicePool(sliceSize)
289		},
290		"custom": func(sliceSize int64) byteSlicePool {
291			return newMaxSlicePool(sliceSize)
292		},
293	}
294
295	for name, poolFunc := range pools {
296		b.Run(name, func(b *testing.B) {
297			unswap := swapByteSlicePool(poolFunc)
298			defer unswap()
299			for i, c := range cases {
300				b.Run(strconv.Itoa(i), func(b *testing.B) {
301					uploader := NewUploader(client, func(u *Uploader) {
302						u.PartSize = c.PartSize
303						u.Concurrency = c.Concurrency
304					})
305
306					expected := s3testing.GetTestBytes(int(c.FileSize))
307					b.ResetTimer()
308					_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
309						Bucket: aws.String("bucket"),
310						Key:    aws.String("key"),
311						Body:   &testReader{br: bytes.NewReader(expected)},
312					})
313					if err != nil {
314						b.Fatalf("expected no error, but got %v", err)
315					}
316				})
317			}
318		})
319	}
320}
321