1// Copyright 2018 The Go Cloud Development Kit Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package s3blob
16
17import (
18	"context"
19	"errors"
20	"fmt"
21	"net/http"
22	"testing"
23
24	"github.com/aws/aws-sdk-go/aws"
25	"github.com/aws/aws-sdk-go/aws/awserr"
26	"github.com/aws/aws-sdk-go/aws/client"
27	"github.com/aws/aws-sdk-go/aws/session"
28	"github.com/aws/aws-sdk-go/service/s3"
29	"github.com/aws/aws-sdk-go/service/s3/s3manager"
30	"gocloud.dev/blob"
31	"gocloud.dev/blob/driver"
32	"gocloud.dev/blob/drivertest"
33	"gocloud.dev/internal/testing/setup"
34)
35
// The bucket and region below are the ones that were used for the most
// recent --record run of these tests.
//
// To use --record mode:
//  1. Create a bucket in your AWS project from the S3 management console
//     (https://s3.console.aws.amazon.com/s3/home).
//  2. Set bucketName to that bucket's name (and region, if the bucket lives
//     in a different region).
//
// TODO(issue #300): Use Terraform to provision a bucket, and get the bucket
// name from the Terraform output instead (saving a copy of it for replay).
const (
	// bucketName is the S3 bucket opened by the tests and the benchmark.
	bucketName = "go-cloud-testing"

	// region is the AWS region passed when creating test and benchmark sessions.
	region = "us-west-1"
)
47
48type harness struct {
49	session *session.Session
50	opts    *Options
51	rt      http.RoundTripper
52	closer  func()
53}
54
// newHarness builds a drivertest.Harness whose bucket is opened with the
// default (nil) Options, using a session from setup.NewAWSSession.
func newHarness(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
	sess, rt, done, _ := setup.NewAWSSession(ctx, t, region)
	h := &harness{
		session: sess,
		rt:      rt,
		closer:  done,
	}
	return h, nil
}
59
// newHarnessUsingLegacyList is like newHarness, but opens the bucket with
// Options.UseLegacyList set so that listing goes through the legacy
// ListObjects request type instead of ListObjectsV2.
func newHarnessUsingLegacyList(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
	legacyOpts := &Options{UseLegacyList: true}
	sess, rt, done, _ := setup.NewAWSSession(ctx, t, region)
	return &harness{
		session: sess,
		opts:    legacyOpts,
		rt:      rt,
		closer:  done,
	}, nil
}
64
// HTTPClient returns an http.Client whose Transport is the RoundTripper that
// setup.NewAWSSession handed to the harness.
func (h *harness) HTTPClient() *http.Client {
	hc := new(http.Client)
	hc.Transport = h.rt
	return hc
}
68
// MakeDriver opens the bucketName bucket with the harness's session and
// options and returns it as a driver.Bucket.
//
// openBucket returns a concrete pointer type; returning it directly would
// wrap a nil pointer in a non-nil driver.Bucket interface on failure, so the
// error path returns a literal nil instead.
func (h *harness) MakeDriver(ctx context.Context) (driver.Bucket, error) {
	drv, err := openBucket(ctx, h.session, bucketName, h.opts)
	if err != nil {
		return nil, err
	}
	return drv, nil
}
72
73func (h *harness) Close() {
74	h.closer()
75}
76
// TestConformance runs the drivertest conformance suite against s3blob with
// default Options, including the ContentLanguage As checks.
func TestConformance(t *testing.T) {
	asTests := []drivertest.AsTest{
		verifyContentLanguage{usingLegacyList: false},
	}
	drivertest.RunConformanceTests(t, newHarness, asTests)
}
80
// TestConformanceUsingLegacyList runs the conformance suite with
// Options.UseLegacyList enabled; the As checks expect legacy list requests.
func TestConformanceUsingLegacyList(t *testing.T) {
	asTests := []drivertest.AsTest{
		verifyContentLanguage{usingLegacyList: true},
	}
	drivertest.RunConformanceTests(t, newHarnessUsingLegacyList, asTests)
}
84
// BenchmarkS3blob runs the drivertest benchmarks against bucketName in region.
// Unlike the tests, it builds its session with session.NewSession rather than
// setup.NewAWSSession, so it talks to AWS directly (presumably requiring
// valid credentials).
func BenchmarkS3blob(b *testing.B) {
	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(region),
	})
	if err != nil {
		b.Fatal(err)
	}
	bkt, err := OpenBucket(context.Background(), sess, bucketName, nil)
	if err != nil {
		b.Fatal(err)
	}
	// Release the bucket once the benchmarks finish instead of leaking it.
	defer func() {
		if err := bkt.Close(); err != nil {
			b.Error(err)
		}
	}()
	drivertest.RunBenchmarks(b, bkt)
}
98
// language is the ContentLanguage value that BeforeWrite sets on uploads and
// that AttributesCheck and ReaderCheck expect to read back.
const language = "nl"

// verifyContentLanguage uses As to access the underlying S3 types and
// read/write the ContentLanguage field.
//
// usingLegacyList selects the list request type that BeforeList expects:
// *s3.ListObjectsInput when true, *s3.ListObjectsV2Input otherwise.
type verifyContentLanguage struct {
	usingLegacyList bool
}
106
107func (verifyContentLanguage) Name() string {
108	return "verify ContentLanguage can be written and read through As"
109}
110
// BucketCheck verifies that Bucket.As exposes the underlying *s3.S3 client.
func (verifyContentLanguage) BucketCheck(b *blob.Bucket) error {
	// Named svc rather than client so the imported aws/client package is not
	// shadowed inside this function.
	var svc *s3.S3
	if !b.As(&svc) {
		return errors.New("Bucket.As failed")
	}
	return nil
}
118
// ErrorCheck verifies that an error produced by the bucket can be converted
// to an awserr.Error through Bucket.ErrorAs.
func (verifyContentLanguage) ErrorCheck(b *blob.Bucket, err error) error {
	var awsErr awserr.Error
	if ok := b.ErrorAs(err, &awsErr); ok {
		return nil
	}
	return errors.New("blob.ErrorAs failed")
}
126
// BeforeRead verifies that the read request is exposed as *s3.GetObjectInput.
func (verifyContentLanguage) BeforeRead(as func(interface{}) bool) error {
	var in *s3.GetObjectInput
	if as(&in) {
		return nil
	}
	return errors.New("BeforeRead As failed")
}
134
// BeforeWrite verifies that the upload request is exposed as
// *s3manager.UploadInput and sets its ContentLanguage to language so the
// attribute and reader checks can read it back.
func (verifyContentLanguage) BeforeWrite(as func(interface{}) bool) error {
	var upload *s3manager.UploadInput
	if ok := as(&upload); !ok {
		return errors.New("Writer.As failed")
	}
	upload.ContentLanguage = aws.String(language)
	return nil
}
143
// BeforeCopy verifies that the copy request is exposed as *s3.CopyObjectInput.
func (verifyContentLanguage) BeforeCopy(as func(interface{}) bool) error {
	var copyIn *s3.CopyObjectInput
	if as(&copyIn) {
		return nil
	}
	return errors.New("BeforeCopy.As failed")
}
151
// BeforeList verifies that the list request is exposed with the request type
// matching the configured list API: *s3.ListObjectsInput when
// usingLegacyList is set, *s3.ListObjectsV2Input otherwise. Nothing on the
// request is modified.
func (v verifyContentLanguage) BeforeList(as func(interface{}) bool) error {
	var ok bool
	if v.usingLegacyList {
		var legacy *s3.ListObjectsInput
		ok = as(&legacy)
	} else {
		var v2 *s3.ListObjectsV2Input
		ok = as(&v2)
	}
	if !ok {
		return errors.New("List.As failed")
	}
	return nil
}
167
// AttributesCheck verifies that Attributes.As exposes an s3.HeadObjectOutput
// whose ContentLanguage equals language (the value set in BeforeWrite).
func (verifyContentLanguage) AttributesCheck(attrs *blob.Attributes) error {
	var hoo s3.HeadObjectOutput
	if !attrs.As(&hoo) {
		return errors.New("Attributes.As returned false")
	}
	// aws.StringValue maps a nil ContentLanguage to "", so a missing value is
	// reported as a mismatch instead of panicking on a nil dereference.
	if got := aws.StringValue(hoo.ContentLanguage); got != language {
		return fmt.Errorf("got %q want %q", got, language)
	}
	return nil
}
178
// ReaderCheck verifies that Reader.As exposes an s3.GetObjectOutput whose
// ContentLanguage equals language (the value set in BeforeWrite).
func (verifyContentLanguage) ReaderCheck(r *blob.Reader) error {
	var goo s3.GetObjectOutput
	if !r.As(&goo) {
		return errors.New("Reader.As returned false")
	}
	// aws.StringValue maps a nil ContentLanguage to "", so a missing value is
	// reported as a mismatch instead of panicking on a nil dereference.
	if got := aws.StringValue(goo.ContentLanguage); got != language {
		return fmt.Errorf("got %q want %q", got, language)
	}
	return nil
}
189
// ListObjectCheck verifies that ListObject.As exposes an s3.CommonPrefix for
// directory entries. For object entries, it verifies that ListObject.As
// exposes an s3.Object whose Key matches the ListObject's Key.
func (verifyContentLanguage) ListObjectCheck(o *blob.ListObject) error {
	if !o.IsDir {
		var obj s3.Object
		if ok := o.As(&obj); !ok {
			return errors.New("ListObject.As for object returned false")
		}
		// The underlying object must describe the same key we listed.
		if obj.Key != nil && *obj.Key == o.Key {
			return nil
		}
		return errors.New("ListObject.As for object returned a different item")
	}
	var prefix s3.CommonPrefix
	if !o.As(&prefix) {
		return errors.New("ListObject.As for directory returned false")
	}
	return nil
}
208
// TestOpenBucket checks argument validation in both the driver constructor
// (openBucket) and the portable constructor (OpenBucket): an empty bucket
// name or a nil session must fail, and a valid call must keep the name.
func TestOpenBucket(t *testing.T) {
	ctx := context.Background()
	cases := []struct {
		description string
		bucketName  string
		nilSession  bool
		want        string
		wantErr     bool
	}{
		{
			description: "empty bucket name results in error",
			wantErr:     true,
		},
		{
			description: "nil sess results in error",
			bucketName:  "foo",
			nilSession:  true,
			wantErr:     true,
		},
		{
			description: "success",
			bucketName:  "foo",
			want:        "foo",
		},
	}

	for _, tc := range cases {
		t.Run(tc.description, func(t *testing.T) {
			var sess client.ConfigProvider
			if !tc.nilSession {
				awsSess, _, done, _ := setup.NewAWSSession(ctx, t, region)
				defer done()
				sess = awsSess
			}

			// Driver-level constructor.
			drv, err := openBucket(ctx, sess, tc.bucketName, nil)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("got err %v want error %v", err, tc.wantErr)
			}
			if err == nil && drv != nil && drv.name != tc.want {
				t.Errorf("got %q want %q", drv.name, tc.want)
			}

			// Portable constructor.
			bkt, err := OpenBucket(ctx, sess, tc.bucketName, nil)
			if bkt != nil {
				defer bkt.Close()
			}
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("got err %v want error %v", err, tc.wantErr)
			}
		})
	}
}
264
// TestOpenBucketFromURL checks which s3:// URLs blob.OpenBucket accepts.
//
// Each case runs in its own subtest so that the deferred Close fires when
// that case finishes, rather than piling up until the whole test returns.
func TestOpenBucketFromURL(t *testing.T) {
	tests := []struct {
		URL     string
		WantErr bool
	}{
		// OK.
		{"s3://mybucket", false},
		// OK, setting region.
		{"s3://mybucket?region=us-west1", false},
		// Invalid parameter.
		{"s3://mybucket?param=value", true},
	}

	ctx := context.Background()
	for _, test := range tests {
		t.Run(test.URL, func(t *testing.T) {
			b, err := blob.OpenBucket(ctx, test.URL)
			if b != nil {
				defer b.Close()
			}
			if (err != nil) != test.WantErr {
				t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
			}
		})
	}
}
289