1package getter
2
3import (
4	"net/url"
5	"os"
6	"path/filepath"
7	"testing"
8
9	"github.com/aws/aws-sdk-go/aws/awserr"
10)
11
func init() {
	// Publish a pair of well-known, restricted IAM credentials for the test
	// run. They belong to a HashiCorp-managed bucket in a private AWS account
	// that holds only the open source test fixtures. The keys are locked down
	// by a read-only IAM policy, so exposing them here is intentional.
	//
	// Each value is assembled from two literals so that AWS's automatic key
	// detection does not flag the source.
	testCreds := [...]struct{ name, value string }{
		{"AWS_ACCESS_KEY", "AKIAITTDR" + "WY2STXOZE2A"},
		{"AWS_SECRET_KEY", "oMwSyqdass2kPF" + "/7ORZA9dlb/iegz+89B0Cy01Ea"},
	}
	for _, kv := range testCreds {
		// The error is deliberately ignored: both names are constant,
		// well-formed environment variable names.
		os.Setenv(kv.name, kv.value)
	}
}
23
24func TestS3Getter_impl(t *testing.T) {
25	var _ Getter = new(S3Getter)
26}
27
28func TestS3Getter(t *testing.T) {
29	g := new(S3Getter)
30	dst := tempDir(t)
31
32	// With a dir that doesn't exist
33	err := g.Get(
34		dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder"))
35	if err != nil {
36		t.Fatalf("err: %s", err)
37	}
38
39	// Verify the main file exists
40	mainPath := filepath.Join(dst, "main.tf")
41	if _, err := os.Stat(mainPath); err != nil {
42		t.Fatalf("err: %s", err)
43	}
44}
45
46func TestS3Getter_subdir(t *testing.T) {
47	g := new(S3Getter)
48	dst := tempDir(t)
49
50	// With a dir that doesn't exist
51	err := g.Get(
52		dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/subfolder"))
53	if err != nil {
54		t.Fatalf("err: %s", err)
55	}
56
57	// Verify the main file exists
58	subPath := filepath.Join(dst, "sub.tf")
59	if _, err := os.Stat(subPath); err != nil {
60		t.Fatalf("err: %s", err)
61	}
62}
63
64func TestS3Getter_GetFile(t *testing.T) {
65	g := new(S3Getter)
66	dst := tempTestFile(t)
67	defer os.RemoveAll(filepath.Dir(dst))
68
69	// Download
70	err := g.GetFile(
71		dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf"))
72	if err != nil {
73		t.Fatalf("err: %s", err)
74	}
75
76	// Verify the main file exists
77	if _, err := os.Stat(dst); err != nil {
78		t.Fatalf("err: %s", err)
79	}
80	assertContents(t, dst, "# Main\n")
81}
82
83func TestS3Getter_GetFile_badParams(t *testing.T) {
84	g := new(S3Getter)
85	dst := tempTestFile(t)
86	defer os.RemoveAll(filepath.Dir(dst))
87
88	// Download
89	err := g.GetFile(
90		dst,
91		testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf?aws_access_key_id=foo&aws_access_key_secret=bar&aws_access_token=baz"))
92	if err == nil {
93		t.Fatalf("expected error, got none")
94	}
95
96	if reqerr, ok := err.(awserr.RequestFailure); !ok || reqerr.StatusCode() != 403 {
97		t.Fatalf("expected InvalidAccessKeyId error")
98	}
99}
100
101func TestS3Getter_GetFile_notfound(t *testing.T) {
102	g := new(S3Getter)
103	dst := tempTestFile(t)
104	defer os.RemoveAll(filepath.Dir(dst))
105
106	// Download
107	err := g.GetFile(
108		dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/404.tf"))
109	if err == nil {
110		t.Fatalf("expected error, got none")
111	}
112}
113
114func TestS3Getter_ClientMode_dir(t *testing.T) {
115	g := new(S3Getter)
116
117	// Check client mode on a key prefix with only a single key.
118	mode, err := g.ClientMode(
119		testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder"))
120	if err != nil {
121		t.Fatalf("err: %s", err)
122	}
123	if mode != ClientModeDir {
124		t.Fatal("expect ClientModeDir")
125	}
126}
127
128func TestS3Getter_ClientMode_file(t *testing.T) {
129	g := new(S3Getter)
130
131	// Check client mode on a key prefix which contains sub-keys.
132	mode, err := g.ClientMode(
133		testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf"))
134	if err != nil {
135		t.Fatalf("err: %s", err)
136	}
137	if mode != ClientModeFile {
138		t.Fatal("expect ClientModeFile")
139	}
140}
141
142func TestS3Getter_ClientMode_notfound(t *testing.T) {
143	g := new(S3Getter)
144
145	// Check the client mode when a non-existent key is looked up. This does not
146	// return an error, but rather should just return the file mode so that S3
147	// can return an appropriate error later on. This also checks that the
148	// prefix is handled properly (e.g., "/fold" and "/folder" don't put the
149	// client mode into "dir".
150	mode, err := g.ClientMode(
151		testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/fold"))
152	if err != nil {
153		t.Fatalf("err: %s", err)
154	}
155	if mode != ClientModeFile {
156		t.Fatal("expect ClientModeFile")
157	}
158}
159
160func TestS3Getter_ClientMode_collision(t *testing.T) {
161	g := new(S3Getter)
162
163	// Check that the client mode is "file" if there is both an object and a
164	// folder with a common prefix (i.e., a "collision" in the namespace).
165	mode, err := g.ClientMode(
166		testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/collision/foo"))
167	if err != nil {
168		t.Fatalf("err: %s", err)
169	}
170	if mode != ClientModeFile {
171		t.Fatal("expect ClientModeFile")
172	}
173}
174
175func TestS3Getter_Url(t *testing.T) {
176	var s3tests = []struct {
177		name    string
178		url     string
179		region  string
180		bucket  string
181		path    string
182		version string
183	}{
184		{
185			name:    "AWSv1234",
186			url:     "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo/bar.baz?version=1234",
187			region:  "eu-west-1",
188			bucket:  "bucket",
189			path:    "foo/bar.baz",
190			version: "1234",
191		},
192		{
193			name:    "localhost-1",
194			url:     "s3::http://127.0.0.1:9000/test-bucket/hello.txt?aws_access_key_id=TESTID&aws_access_key_secret=TestSecret&region=us-east-2&version=1",
195			region:  "us-east-2",
196			bucket:  "test-bucket",
197			path:    "hello.txt",
198			version: "1",
199		},
200		{
201			name:    "localhost-2",
202			url:     "s3::http://127.0.0.1:9000/test-bucket/hello.txt?aws_access_key_id=TESTID&aws_access_key_secret=TestSecret&version=1",
203			region:  "us-east-1",
204			bucket:  "test-bucket",
205			path:    "hello.txt",
206			version: "1",
207		},
208		{
209			name:    "localhost-3",
210			url:     "s3::http://127.0.0.1:9000/test-bucket/hello.txt?aws_access_key_id=TESTID&aws_access_key_secret=TestSecret",
211			region:  "us-east-1",
212			bucket:  "test-bucket",
213			path:    "hello.txt",
214			version: "",
215		},
216	}
217
218	for i, pt := range s3tests {
219		t.Run(pt.name, func(t *testing.T) {
220			g := new(S3Getter)
221			forced, src := getForcedGetter(pt.url)
222			u, err := url.Parse(src)
223
224			if err != nil {
225				t.Errorf("test %d: unexpected error: %s", i, err)
226			}
227			if forced != "s3" {
228				t.Fatalf("expected forced protocol to be s3")
229			}
230
231			region, bucket, path, version, creds, err := g.parseUrl(u)
232
233			if err != nil {
234				t.Fatalf("err: %s", err)
235			}
236			if region != pt.region {
237				t.Fatalf("expected %s, got %s", pt.region, region)
238			}
239			if bucket != pt.bucket {
240				t.Fatalf("expected %s, got %s", pt.bucket, bucket)
241			}
242			if path != pt.path {
243				t.Fatalf("expected %s, got %s", pt.path, path)
244			}
245			if version != pt.version {
246				t.Fatalf("expected %s, got %s", pt.version, version)
247			}
248			if &creds == nil {
249				t.Fatalf("expected to not be nil")
250			}
251		})
252	}
253}
254