1//go:build integration
2// +build integration
3
4package s3_test
5
6import (
7	"bytes"
8	"context"
9	"crypto/tls"
10	"flag"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"net/http"
15	"os"
16	"reflect"
17	"strings"
18	"testing"
19	"time"
20
21	"github.com/aws/aws-sdk-go/aws"
22	"github.com/aws/aws-sdk-go/aws/arn"
23	"github.com/aws/aws-sdk-go/aws/endpoints"
24	"github.com/aws/aws-sdk-go/aws/request"
25	"github.com/aws/aws-sdk-go/awstesting/integration"
26	"github.com/aws/aws-sdk-go/awstesting/integration/s3integ"
27	"github.com/aws/aws-sdk-go/service/s3"
28	"github.com/aws/aws-sdk-go/service/s3control"
29	"github.com/aws/aws-sdk-go/service/sts"
30)
31
// integBucketPrefix is a bucket-name prefix for integration resources.
// NOTE(review): not referenced in the code shown here; presumably used by
// other files in this package — confirm before changing.
const integBucketPrefix = "aws-sdk-go-integration"

// integMetadata records the identifiers of the AWS resources shared by the
// integration tests. TestMain stores the account ID (from -account or STS),
// and setupBuckets / setupAccessPoints store the generated names and
// computed ARNs of the source and target buckets and access points.
var integMetadata = struct {
	AccountID string
	Region    string // NOTE(review): never assigned in the code shown; confirm whether it is used
	Buckets   struct {
		Source struct{ Name, ARN string }
		Target struct{ Name, ARN string }
	}

	AccessPoints struct {
		Source struct{ Name, ARN string }
		Target struct{ Name, ARN string }
	}
}{}
59
60var s3Svc *s3.S3
61var s3ControlSvc *s3control.S3Control
62var stsSvc *sts.STS
63var httpClient *http.Client
64
65// TODO: (Westeros) Remove Custom Resolver Usage Before Launch
66type customS3Resolver struct {
67	endpoint string
68	withTLS  bool
69	region   string
70}
71
72func (r customS3Resolver) EndpointFor(service, _ string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
73	switch strings.ToLower(service) {
74	case "s3-control":
75	case "s3":
76	default:
77		return endpoints.ResolvedEndpoint{}, fmt.Errorf("unsupported in custom resolver")
78	}
79
80	return endpoints.ResolvedEndpoint{
81		PartitionID:   "aws",
82		SigningRegion: r.region,
83		SigningName:   "s3",
84		SigningMethod: "s3v4",
85		URL:           endpoints.AddScheme(r.endpoint, r.withTLS),
86	}, nil
87}
88
// TestMain configures the shared S3, S3 Control, and STS clients, creates
// the integration buckets and access points recorded in integMetadata, runs
// the package's tests, and then removes the created resources. The process
// exits non-zero if setup fails, a panic is recovered, or m.Run reports a
// failure.
//
// Flags:
//
//	-s3-endpoint, -s3-tls                  custom S3 endpoint and whether it uses TLS
//	-s3-control-endpoint, -s3-control-tls  custom S3 Control endpoint and whether it uses TLS
//	-account                               account ID; looked up through STS when empty
//	-verify-tls                            verify the server TLS certificate (default true)
func TestMain(m *testing.M) {
	var result int
	// Deferred first so it runs last: the resource cleanups deferred below
	// finish before os.Exit terminates the process.
	defer func() {
		// recover only observes panics raised on this goroutine (e.g. setup).
		if r := recover(); r != nil {
			fmt.Fprintln(os.Stderr, "S3 integration tests panicked,", r)
			result = 1
		}
		os.Exit(result)
	}()

	var verifyTLS bool
	var s3Endpoint, s3ControlEndpoint string
	var s3EnableTLS, s3ControlEnableTLS bool

	flag.StringVar(&s3Endpoint, "s3-endpoint", "", "integration endpoint for S3")
	flag.BoolVar(&s3EnableTLS, "s3-tls", true, "enable TLS for S3 endpoint")

	flag.StringVar(&s3ControlEndpoint, "s3-control-endpoint", "", "integration endpoint for S3 control")
	flag.BoolVar(&s3ControlEnableTLS, "s3-control-tls", true, "enable TLS for S3 control endpoint")

	flag.StringVar(&integMetadata.AccountID, "account", "", "integration account id")
	flag.BoolVar(&verifyTLS, "verify-tls", true, "verify server TLS certificate")
	flag.Parse()

	httpClient = &http.Client{
		Transport: &http.Transport{
			// InsecureSkipVerify is the inverse of -verify-tls, so the
			// default keeps certificate verification enabled.
			TLSClientConfig: &tls.Config{InsecureSkipVerify: !verifyTLS},
		},
	}

	sess := integration.SessionWithDefaultRegion("us-west-2").Copy(&aws.Config{
		HTTPClient: httpClient,
	})
	region := aws.StringValue(sess.Config.Region)

	// Only install a custom resolver when the matching endpoint flag is set.
	var s3EndpointResolver endpoints.Resolver
	if len(s3Endpoint) != 0 {
		s3EndpointResolver = customS3Resolver{
			endpoint: s3Endpoint,
			withTLS:  s3EnableTLS,
			region:   region,
		}
	}
	s3Svc = s3.New(sess, &aws.Config{
		DisableSSL:       aws.Bool(!s3EnableTLS),
		EndpointResolver: s3EndpointResolver,
	})

	// Gate on -s3-control-endpoint itself, not -s3-endpoint.
	var s3ControlEndpointResolver endpoints.Resolver
	if len(s3ControlEndpoint) != 0 {
		s3ControlEndpointResolver = customS3Resolver{
			endpoint: s3ControlEndpoint,
			withTLS:  s3ControlEnableTLS,
			region:   region,
		}
	}
	s3ControlSvc = s3control.New(sess, &aws.Config{
		DisableSSL:       aws.Bool(!s3ControlEnableTLS),
		EndpointResolver: s3ControlEndpointResolver,
	})
	stsSvc = sts.New(sess)

	var err error
	integMetadata.AccountID, err = getAccountID()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to get integration aws account id: %v\n", err)
		result = 1
		return
	}

	// The setup helpers return a non-nil cleanup even on error, so it is
	// deferred before err is checked to remove partially created resources.
	bucketCleanup, err := setupBuckets()
	defer bucketCleanup()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to setup integration test buckets: %v\n", err)
		result = 1
		return
	}

	// Deferred after bucketCleanup, so access points are removed before
	// the buckets they belong to.
	accessPointsCleanup, err := setupAccessPoints()
	defer accessPointsCleanup()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to setup integration test access points: %v\n", err)
		result = 1
		return
	}

	result = m.Run()
}
175
// getAccountID returns the AWS account ID used by the integration tests.
// A value already stored in integMetadata.AccountID (TestMain binds the
// -account flag to it) is returned as-is; otherwise the account of the
// current STS caller identity is used. It errors if the STS call fails or
// reports no account.
func getAccountID() (string, error) {
	if len(integMetadata.AccountID) != 0 {
		return integMetadata.AccountID, nil
	}

	output, err := stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{})
	if err != nil {
		// Keep the underlying cause; it was previously discarded.
		return "", fmt.Errorf("failed to get sts caller identity: %w", err)
	}

	// Avoid a nil-pointer panic if Account is absent from the response.
	account := aws.StringValue(output.Account)
	if len(account) == 0 {
		return "", fmt.Errorf("sts caller identity returned no account id")
	}

	return account, nil
}
188
// setupBuckets creates the source and target integration buckets. For each
// one it stores the generated name and a computed ARN in
// integMetadata.Buckets.
//
// The returned func deletes every bucket created so far (errors are printed
// to stderr). It is returned even alongside an error, so callers may defer
// it unconditionally.
func setupBuckets() (func(), error) {
	var undo []func()
	teardown := func() {
		for _, fn := range undo {
			fn()
		}
	}

	for _, dst := range []*struct {
		Name string
		ARN  string
	}{
		&integMetadata.Buckets.Source,
		&integMetadata.Buckets.Target,
	} {
		dst.Name = s3integ.GenerateBucketName()
		if err := s3integ.SetupBucket(s3Svc, dst.Name); err != nil {
			return teardown, err
		}

		// NOTE(review): the "bucket_name:" resource form and the hard-coded
		// "aws" partition are kept as-is; confirm they match what the tests expect.
		dst.ARN = arn.ARN{
			Partition: "aws",
			Service:   "s3",
			Region:    s3Svc.SigningRegion,
			AccountID: integMetadata.AccountID,
			Resource:  fmt.Sprintf("bucket_name:%s", dst.Name),
		}.String()

		// Copy the name so the cleanup closure does not depend on dst.
		name := dst.Name
		undo = append(undo, func() {
			if err := s3integ.CleanupBucket(s3Svc, name); err != nil {
				fmt.Fprintln(os.Stderr, err)
			}
		})
	}

	return teardown, nil
}
234
// setupAccessPoints creates one access point on each of the source and
// target integration buckets. For each one it stores the generated name and
// a computed ARN in integMetadata.AccessPoints. It reads the bucket names
// from integMetadata.Buckets, so setupBuckets must run first.
//
// The returned func deletes every access point created so far (errors are
// printed to stderr). It is returned even alongside an error, so callers may
// defer it unconditionally.
func setupAccessPoints() (func(), error) {
	var undo []func()
	teardown := func() {
		for _, fn := range undo {
			fn()
		}
	}

	type apSpec struct {
		bucket string
		dst    *struct {
			Name string
			ARN  string
		}
	}
	specs := []apSpec{
		{bucket: integMetadata.Buckets.Source.Name, dst: &integMetadata.AccessPoints.Source},
		{bucket: integMetadata.Buckets.Target.Name, dst: &integMetadata.AccessPoints.Target},
	}

	for _, spec := range specs {
		spec.dst.Name = integration.UniqueID()

		if err := s3integ.SetupAccessPoint(s3ControlSvc, integMetadata.AccountID, spec.bucket, spec.dst.Name); err != nil {
			return teardown, err
		}

		spec.dst.ARN = arn.ARN{
			Partition: "aws",
			Service:   "s3",
			Region:    s3ControlSvc.SigningRegion,
			AccountID: integMetadata.AccountID,
			Resource:  fmt.Sprintf("accesspoint/%s", spec.dst.Name),
		}.String()

		// Copy the name so the cleanup closure does not depend on spec.
		name := spec.dst.Name
		undo = append(undo, func() {
			if err := s3integ.CleanupAccessPoint(s3ControlSvc, integMetadata.AccountID, name); err != nil {
				fmt.Fprintln(os.Stderr, err)
			}
		})
	}

	return teardown, nil
}
283
// putTestFile uploads the contents of the local file filename to key in the
// integration source bucket. It fails the test immediately if the file
// cannot be opened. opts are forwarded to the PutObject request.
func putTestFile(t *testing.T, filename, key string, opts ...request.Option) {
	// Attribute failures to the calling test rather than this helper.
	t.Helper()

	f, err := os.Open(filename)
	if err != nil {
		t.Fatalf("failed to open testfile, %v", err)
	}
	defer f.Close()

	putTestContent(t, f, key, opts...)
}
293
// putTestContent uploads reader to key in the integration source bucket
// (integMetadata.Buckets.Source.Name). An upload error is reported with
// t.Errorf, so the calling test keeps running. opts are forwarded to the
// PutObject request.
func putTestContent(t *testing.T, reader io.ReadSeeker, key string, opts ...request.Option) {
	// Attribute failures to the calling test rather than this helper.
	t.Helper()

	t.Logf("uploading test file %s/%s", integMetadata.Buckets.Source.Name, key)
	_, err := s3Svc.PutObjectWithContext(context.Background(),
		&s3.PutObjectInput{
			Bucket: &integMetadata.Buckets.Source.Name,
			Key:    aws.String(key),
			Body:   reader,
		}, opts...)
	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}
}
306
// testWriteToObject puts "hello world" under a new unique key in bucket,
// reads the object back, and fails the test unless the bytes round-trip
// unchanged. opts are applied to both the PutObject and GetObject calls.
func testWriteToObject(t *testing.T, bucket string, opts ...request.Option) {
	t.Helper()
	key := integration.UniqueID()

	_, err := s3Svc.PutObjectWithContext(context.Background(),
		&s3.PutObjectInput{
			Bucket: &bucket,
			Key:    &key,
			Body:   bytes.NewReader([]byte("hello world")),
		}, opts...)
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}

	resp, err := s3Svc.GetObjectWithContext(context.Background(),
		&s3.GetObjectInput{
			Bucket: &bucket,
			Key:    &key,
		}, opts...)
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}
	// The response body was previously never closed, leaking the connection.
	defer resp.Body.Close()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("expect no error reading body, got %v", err)
	}
	if e, a := []byte("hello world"), b; !bytes.Equal(e, a) {
		t.Errorf("expect %v, got %v", e, a)
	}
}
334
335func testPresignedGetPut(t *testing.T, bucket string, opts ...request.Option) {
336	key := integration.UniqueID()
337
338	putreq, _ := s3Svc.PutObjectRequest(&s3.PutObjectInput{
339		Bucket: &bucket,
340		Key:    &key,
341	})
342	putreq.ApplyOptions(opts...)
343	var err error
344
345	// Presign a PUT request
346	var puturl string
347	puturl, err = putreq.Presign(5 * time.Minute)
348	if err != nil {
349		t.Fatalf("expect no error, got %v", err)
350	}
351
352	// PUT to the presigned URL with a body
353	var puthttpreq *http.Request
354	buf := bytes.NewReader([]byte("hello world"))
355	puthttpreq, err = http.NewRequest("PUT", puturl, buf)
356	if err != nil {
357		t.Fatalf("expect no error, got %v", err)
358	}
359
360	var putresp *http.Response
361	putresp, err = httpClient.Do(puthttpreq)
362	if err != nil {
363		t.Errorf("expect put with presign url no error, got %v", err)
364	}
365	if e, a := 200, putresp.StatusCode; e != a {
366		t.Fatalf("expect %v, got %v", e, a)
367	}
368
369	// Presign a GET on the same URL
370	getreq, _ := s3Svc.GetObjectRequest(&s3.GetObjectInput{
371		Bucket: &bucket,
372		Key:    &key,
373	})
374	getreq.ApplyOptions(opts...)
375
376	var geturl string
377	geturl, err = getreq.Presign(300 * time.Second)
378	if err != nil {
379		t.Fatalf("expect no error, got %v", err)
380	}
381
382	// Get the body
383	var getresp *http.Response
384	getresp, err = httpClient.Get(geturl)
385	if err != nil {
386		t.Fatalf("expect no error, got %v", err)
387	}
388
389	var b []byte
390	defer getresp.Body.Close()
391	b, err = ioutil.ReadAll(getresp.Body)
392	if e, a := "hello world", string(b); e != a {
393		t.Fatalf("expect %v, got %v", e, a)
394	}
395}
396
// testCopyObject puts "hello world" under a new unique key in sourceBucket
// and copies it to the same key in targetBucket using the copy source
// "/<sourceBucket>/<key>". It then reads the copy back and fails the test
// unless the bytes match. opts are applied to the Put, Copy, and Get calls.
func testCopyObject(t *testing.T, sourceBucket string, targetBucket string, opts ...request.Option) {
	t.Helper()
	key := integration.UniqueID()

	_, err := s3Svc.PutObjectWithContext(context.Background(),
		&s3.PutObjectInput{
			Bucket: &sourceBucket,
			Key:    &key,
			Body:   bytes.NewReader([]byte("hello world")),
		}, opts...)
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}

	_, err = s3Svc.CopyObjectWithContext(context.Background(),
		&s3.CopyObjectInput{
			Bucket:     &targetBucket,
			CopySource: aws.String("/" + sourceBucket + "/" + key),
			Key:        &key,
		}, opts...)
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}

	resp, err := s3Svc.GetObjectWithContext(context.Background(),
		&s3.GetObjectInput{
			Bucket: &targetBucket,
			Key:    &key,
		}, opts...)
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}
	// The response body was previously never closed, leaking the connection.
	defer resp.Body.Close()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("expect no error reading body, got %v", err)
	}
	if e, a := []byte("hello world"), b; !reflect.DeepEqual(e, a) {
		t.Errorf("expect %v, got %v", e, a)
	}
}
434