1// +build integration
2
3package s3_test
4
5import (
6	"bytes"
7	"context"
8	"crypto/tls"
9	"flag"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"net/http"
14	"os"
15	"reflect"
16	"strings"
17	"testing"
18	"time"
19
20	"github.com/aws/aws-sdk-go/aws"
21	"github.com/aws/aws-sdk-go/aws/arn"
22	"github.com/aws/aws-sdk-go/aws/endpoints"
23	"github.com/aws/aws-sdk-go/aws/request"
24	"github.com/aws/aws-sdk-go/awstesting/integration"
25	"github.com/aws/aws-sdk-go/awstesting/integration/s3integ"
26	"github.com/aws/aws-sdk-go/service/s3"
27	"github.com/aws/aws-sdk-go/service/s3control"
28	"github.com/aws/aws-sdk-go/service/sts"
29)
30
31const integBucketPrefix = "aws-sdk-go-integration"
32
33var integMetadata = struct {
34	AccountID string
35	Region    string
36	Buckets   struct {
37		Source struct {
38			Name string
39			ARN  string
40		}
41		Target struct {
42			Name string
43			ARN  string
44		}
45	}
46
47	AccessPoints struct {
48		Source struct {
49			Name string
50			ARN  string
51		}
52		Target struct {
53			Name string
54			ARN  string
55		}
56	}
57}{}
58
59var s3Svc *s3.S3
60var s3ControlSvc *s3control.S3Control
61var stsSvc *sts.STS
62var httpClient *http.Client
63
64// TODO: (Westeros) Remove Custom Resolver Usage Before Launch
65type customS3Resolver struct {
66	endpoint string
67	withTLS  bool
68	region   string
69}
70
71func (r customS3Resolver) EndpointFor(service, _ string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
72	switch strings.ToLower(service) {
73	case "s3-control":
74	case "s3":
75	default:
76		return endpoints.ResolvedEndpoint{}, fmt.Errorf("unsupported in custom resolver")
77	}
78
79	return endpoints.ResolvedEndpoint{
80		PartitionID:   "aws",
81		SigningRegion: r.region,
82		SigningName:   "s3",
83		SigningMethod: "s3v4",
84		URL:           endpoints.AddScheme(r.endpoint, r.withTLS),
85	}, nil
86}
87
88func TestMain(m *testing.M) {
89	var result int
90	defer func() {
91		if r := recover(); r != nil {
92			fmt.Fprintln(os.Stderr, "S3 integration tests paniced,", r)
93			result = 1
94		}
95		os.Exit(result)
96	}()
97
98	var verifyTLS bool
99	var s3Endpoint, s3ControlEndpoint string
100	var s3EnableTLS, s3ControlEnableTLS bool
101
102	flag.StringVar(&s3Endpoint, "s3-endpoint", "", "integration endpoint for S3")
103	flag.BoolVar(&s3EnableTLS, "s3-tls", true, "enable TLS for S3 endpoint")
104
105	flag.StringVar(&s3ControlEndpoint, "s3-control-endpoint", "", "integration endpoint for S3")
106	flag.BoolVar(&s3ControlEnableTLS, "s3-control-tls", true, "enable TLS for S3 control endpoint")
107
108	flag.StringVar(&integMetadata.AccountID, "account", "", "integration account id")
109	flag.BoolVar(&verifyTLS, "verify-tls", true, "verify server TLS certificate")
110	flag.Parse()
111
112	httpClient = &http.Client{
113		Transport: &http.Transport{
114			TLSClientConfig: &tls.Config{InsecureSkipVerify: verifyTLS},
115		}}
116
117	sess := integration.SessionWithDefaultRegion("us-west-2").Copy(&aws.Config{
118		HTTPClient: httpClient,
119	})
120
121	var s3EndpointResolver endpoints.Resolver
122	if len(s3Endpoint) != 0 {
123		s3EndpointResolver = customS3Resolver{
124			endpoint: s3Endpoint,
125			withTLS:  s3EnableTLS,
126			region:   aws.StringValue(sess.Config.Region),
127		}
128	}
129	s3Svc = s3.New(sess, &aws.Config{
130		DisableSSL:       aws.Bool(!s3EnableTLS),
131		EndpointResolver: s3EndpointResolver,
132	})
133
134	var s3ControlEndpointResolver endpoints.Resolver
135	if len(s3Endpoint) != 0 {
136		s3ControlEndpointResolver = customS3Resolver{
137			endpoint: s3ControlEndpoint,
138			withTLS:  s3ControlEnableTLS,
139			region:   aws.StringValue(sess.Config.Region),
140		}
141	}
142	s3ControlSvc = s3control.New(sess, &aws.Config{
143		DisableSSL:       aws.Bool(!s3ControlEnableTLS),
144		EndpointResolver: s3ControlEndpointResolver,
145	})
146	stsSvc = sts.New(sess)
147
148	var err error
149	integMetadata.AccountID, err = getAccountID()
150	if err != nil {
151		fmt.Fprintf(os.Stderr, "failed to get integration aws account id: %v\n", err)
152		result = 1
153		return
154	}
155
156	bucketCleanup, err := setupBuckets()
157	defer bucketCleanup()
158	if err != nil {
159		fmt.Fprintf(os.Stderr, "failed to setup integration test buckets: %v\n", err)
160		result = 1
161		return
162	}
163
164	accessPointsCleanup, err := setupAccessPoints()
165	defer accessPointsCleanup()
166	if err != nil {
167		fmt.Fprintf(os.Stderr, "failed to setup integration test access points: %v\n", err)
168		result = 1
169		return
170	}
171
172	result = m.Run()
173}
174
175func getAccountID() (string, error) {
176	if len(integMetadata.AccountID) != 0 {
177		return integMetadata.AccountID, nil
178	}
179
180	output, err := stsSvc.GetCallerIdentity(nil)
181	if err != nil {
182		return "", fmt.Errorf("faield to get sts caller identity")
183	}
184
185	return *output.Account, nil
186}
187
188func setupBuckets() (func(), error) {
189	var cleanups []func()
190
191	cleanup := func() {
192		for i := range cleanups {
193			cleanups[i]()
194		}
195	}
196
197	bucketCreates := []struct {
198		name *string
199		arn  *string
200	}{
201		{name: &integMetadata.Buckets.Source.Name, arn: &integMetadata.Buckets.Source.ARN},
202		{name: &integMetadata.Buckets.Target.Name, arn: &integMetadata.Buckets.Target.ARN},
203	}
204
205	for _, bucket := range bucketCreates {
206		*bucket.name = s3integ.GenerateBucketName()
207
208		if err := s3integ.SetupBucket(s3Svc, *bucket.name); err != nil {
209			return cleanup, err
210		}
211
212		// Compute ARN
213		bARN := arn.ARN{
214			Partition: "aws",
215			Service:   "s3",
216			Region:    s3Svc.SigningRegion,
217			AccountID: integMetadata.AccountID,
218			Resource:  fmt.Sprintf("bucket_name:%s", *bucket.name),
219		}.String()
220
221		*bucket.arn = bARN
222
223		bucketName := *bucket.name
224		cleanups = append(cleanups, func() {
225			if err := s3integ.CleanupBucket(s3Svc, bucketName); err != nil {
226				fmt.Fprintln(os.Stderr, err)
227			}
228		})
229	}
230
231	return cleanup, nil
232}
233
234func setupAccessPoints() (func(), error) {
235	var cleanups []func()
236
237	cleanup := func() {
238		for i := range cleanups {
239			cleanups[i]()
240		}
241	}
242
243	creates := []struct {
244		bucket string
245		name   *string
246		arn    *string
247	}{
248		{bucket: integMetadata.Buckets.Source.Name, name: &integMetadata.AccessPoints.Source.Name, arn: &integMetadata.AccessPoints.Source.ARN},
249		{bucket: integMetadata.Buckets.Target.Name, name: &integMetadata.AccessPoints.Target.Name, arn: &integMetadata.AccessPoints.Target.ARN},
250	}
251
252	for _, ap := range creates {
253		*ap.name = integration.UniqueID()
254
255		err := s3integ.SetupAccessPoint(s3ControlSvc, integMetadata.AccountID, ap.bucket, *ap.name)
256		if err != nil {
257			return cleanup, err
258		}
259
260		// Compute ARN
261		apARN := arn.ARN{
262			Partition: "aws",
263			Service:   "s3",
264			Region:    s3ControlSvc.SigningRegion,
265			AccountID: integMetadata.AccountID,
266			Resource:  fmt.Sprintf("accesspoint/%s", *ap.name),
267		}.String()
268
269		*ap.arn = apARN
270
271		apName := *ap.name
272		cleanups = append(cleanups, func() {
273			err := s3integ.CleanupAccessPoint(s3ControlSvc, integMetadata.AccountID, apName)
274			if err != nil {
275				fmt.Fprintln(os.Stderr, err)
276			}
277		})
278	}
279
280	return cleanup, nil
281}
282
283func putTestFile(t *testing.T, filename, key string, opts ...request.Option) {
284	f, err := os.Open(filename)
285	if err != nil {
286		t.Fatalf("failed to open testfile, %v", err)
287	}
288	defer f.Close()
289
290	putTestContent(t, f, key, opts...)
291}
292
293func putTestContent(t *testing.T, reader io.ReadSeeker, key string, opts ...request.Option) {
294	t.Logf("uploading test file %s/%s", integMetadata.Buckets.Source.Name, key)
295	_, err := s3Svc.PutObjectWithContext(context.Background(),
296		&s3.PutObjectInput{
297			Bucket: &integMetadata.Buckets.Source.Name,
298			Key:    aws.String(key),
299			Body:   reader,
300		}, opts...)
301	if err != nil {
302		t.Errorf("expect no error, got %v", err)
303	}
304}
305
306func testWriteToObject(t *testing.T, bucket string, opts ...request.Option) {
307	key := integration.UniqueID()
308
309	_, err := s3Svc.PutObjectWithContext(context.Background(),
310		&s3.PutObjectInput{
311			Bucket: &bucket,
312			Key:    &key,
313			Body:   bytes.NewReader([]byte("hello world")),
314		}, opts...)
315	if err != nil {
316		t.Fatalf("expect no error, got %v", err)
317	}
318
319	resp, err := s3Svc.GetObjectWithContext(context.Background(),
320		&s3.GetObjectInput{
321			Bucket: &bucket,
322			Key:    &key,
323		}, opts...)
324	if err != nil {
325		t.Fatalf("expect no error, got %v", err)
326	}
327
328	b, _ := ioutil.ReadAll(resp.Body)
329	if e, a := []byte("hello world"), b; !bytes.Equal(e, a) {
330		t.Errorf("expect %v, got %v", e, a)
331	}
332}
333
334func testPresignedGetPut(t *testing.T, bucket string, opts ...request.Option) {
335	key := integration.UniqueID()
336
337	putreq, _ := s3Svc.PutObjectRequest(&s3.PutObjectInput{
338		Bucket: &bucket,
339		Key:    &key,
340	})
341	putreq.ApplyOptions(opts...)
342	var err error
343
344	// Presign a PUT request
345	var puturl string
346	puturl, err = putreq.Presign(5 * time.Minute)
347	if err != nil {
348		t.Fatalf("expect no error, got %v", err)
349	}
350
351	// PUT to the presigned URL with a body
352	var puthttpreq *http.Request
353	buf := bytes.NewReader([]byte("hello world"))
354	puthttpreq, err = http.NewRequest("PUT", puturl, buf)
355	if err != nil {
356		t.Fatalf("expect no error, got %v", err)
357	}
358
359	var putresp *http.Response
360	putresp, err = httpClient.Do(puthttpreq)
361	if err != nil {
362		t.Errorf("expect put with presign url no error, got %v", err)
363	}
364	if e, a := 200, putresp.StatusCode; e != a {
365		t.Fatalf("expect %v, got %v", e, a)
366	}
367
368	// Presign a GET on the same URL
369	getreq, _ := s3Svc.GetObjectRequest(&s3.GetObjectInput{
370		Bucket: &bucket,
371		Key:    &key,
372	})
373	getreq.ApplyOptions(opts...)
374
375	var geturl string
376	geturl, err = getreq.Presign(300 * time.Second)
377	if err != nil {
378		t.Fatalf("expect no error, got %v", err)
379	}
380
381	// Get the body
382	var getresp *http.Response
383	getresp, err = httpClient.Get(geturl)
384	if err != nil {
385		t.Fatalf("expect no error, got %v", err)
386	}
387
388	var b []byte
389	defer getresp.Body.Close()
390	b, err = ioutil.ReadAll(getresp.Body)
391	if e, a := "hello world", string(b); e != a {
392		t.Fatalf("expect %v, got %v", e, a)
393	}
394}
395
396func testCopyObject(t *testing.T, sourceBucket string, targetBucket string, opts ...request.Option) {
397	key := integration.UniqueID()
398
399	_, err := s3Svc.PutObjectWithContext(context.Background(),
400		&s3.PutObjectInput{
401			Bucket: &sourceBucket,
402			Key:    &key,
403			Body:   bytes.NewReader([]byte("hello world")),
404		}, opts...)
405	if err != nil {
406		t.Fatalf("expect no error, got %v", err)
407	}
408
409	_, err = s3Svc.CopyObjectWithContext(context.Background(),
410		&s3.CopyObjectInput{
411			Bucket:     &targetBucket,
412			CopySource: aws.String("/" + sourceBucket + "/" + key),
413			Key:        &key,
414		}, opts...)
415	if err != nil {
416		t.Fatalf("expect no error, got %v", err)
417	}
418
419	resp, err := s3Svc.GetObjectWithContext(context.Background(),
420		&s3.GetObjectInput{
421			Bucket: &targetBucket,
422			Key:    &key,
423		}, opts...)
424	if err != nil {
425		t.Fatalf("expect no error, got %v", err)
426	}
427
428	b, _ := ioutil.ReadAll(resp.Body)
429	if e, a := []byte("hello world"), b; !reflect.DeepEqual(e, a) {
430		t.Errorf("expect %v, got %v", e, a)
431	}
432}
433