1package processcreds_test
2
3import (
4	"bytes"
5	"encoding/json"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"os"
10	"os/exec"
11	"runtime"
12	"strings"
13	"testing"
14	"time"
15
16	"github.com/aws/aws-sdk-go/aws"
17	"github.com/aws/aws-sdk-go/aws/awserr"
18	"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
19	"github.com/aws/aws-sdk-go/aws/session"
20	"github.com/aws/aws-sdk-go/internal/sdktesting"
21)
22
23func TestProcessProviderFromSessionCfg(t *testing.T) {
24	restoreEnvFn := sdktesting.StashEnv()
25	defer restoreEnvFn()
26
27	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
28	if runtime.GOOS == "windows" {
29		os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
30	} else {
31		os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
32	}
33
34	sess, err := session.NewSession(&aws.Config{
35		Region: aws.String("region")},
36	)
37
38	if err != nil {
39		t.Errorf("error getting session: %v", err)
40	}
41
42	creds, err := sess.Config.Credentials.Get()
43	if err != nil {
44		t.Errorf("error getting credentials: %v", err)
45	}
46
47	if e, a := "accessKey", creds.AccessKeyID; e != a {
48		t.Errorf("expected %v, got %v", e, a)
49	}
50
51	if e, a := "secret", creds.SecretAccessKey; e != a {
52		t.Errorf("expected %v, got %v", e, a)
53	}
54
55	if e, a := "tokenDefault", creds.SessionToken; e != a {
56		t.Errorf("expected %v, got %v", e, a)
57	}
58
59}
60
61func TestProcessProviderFromSessionWithProfileCfg(t *testing.T) {
62	restoreEnvFn := sdktesting.StashEnv()
63	defer restoreEnvFn()
64
65	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
66	os.Setenv("AWS_PROFILE", "non_expire")
67	if runtime.GOOS == "windows" {
68		os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
69	} else {
70		os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
71	}
72
73	sess, err := session.NewSession(&aws.Config{
74		Region: aws.String("region")},
75	)
76
77	if err != nil {
78		t.Errorf("error getting session: %v", err)
79	}
80
81	creds, err := sess.Config.Credentials.Get()
82	if err != nil {
83		t.Errorf("error getting credentials: %v", err)
84	}
85
86	if e, a := "nonDefaultToken", creds.SessionToken; e != a {
87		t.Errorf("expected %v, got %v", e, a)
88	}
89
90}
91
92func TestProcessProviderNotFromCredProcCfg(t *testing.T) {
93	restoreEnvFn := sdktesting.StashEnv()
94	defer restoreEnvFn()
95
96	os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
97	os.Setenv("AWS_PROFILE", "not_alone")
98	if runtime.GOOS == "windows" {
99		os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
100	} else {
101		os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
102	}
103
104	sess, err := session.NewSession(&aws.Config{
105		Region: aws.String("region")},
106	)
107
108	if err != nil {
109		t.Errorf("error getting session: %v", err)
110	}
111
112	creds, err := sess.Config.Credentials.Get()
113	if err != nil {
114		t.Errorf("error getting credentials: %v", err)
115	}
116
117	if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
118		t.Errorf("expected %v, got %v", e, a)
119	}
120
121	if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
122		t.Errorf("expected %v, got %v", e, a)
123	}
124
125}
126
127func TestProcessProviderFromSessionCrd(t *testing.T) {
128	restoreEnvFn := sdktesting.StashEnv()
129	defer restoreEnvFn()
130
131	if runtime.GOOS == "windows" {
132		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
133	} else {
134		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
135	}
136
137	sess, err := session.NewSession(&aws.Config{
138		Region: aws.String("region")},
139	)
140
141	if err != nil {
142		t.Errorf("error getting session: %v", err)
143	}
144
145	creds, err := sess.Config.Credentials.Get()
146	if err != nil {
147		t.Errorf("error getting credentials: %v", err)
148	}
149
150	if e, a := "accessKey", creds.AccessKeyID; e != a {
151		t.Errorf("expected %v, got %v", e, a)
152	}
153
154	if e, a := "secret", creds.SecretAccessKey; e != a {
155		t.Errorf("expected %v, got %v", e, a)
156	}
157
158	if e, a := "tokenDefault", creds.SessionToken; e != a {
159		t.Errorf("expected %v, got %v", e, a)
160	}
161
162}
163
164func TestProcessProviderFromSessionWithProfileCrd(t *testing.T) {
165	restoreEnvFn := sdktesting.StashEnv()
166	defer restoreEnvFn()
167
168	os.Setenv("AWS_PROFILE", "non_expire")
169	if runtime.GOOS == "windows" {
170		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
171	} else {
172		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
173	}
174
175	sess, err := session.NewSession(&aws.Config{
176		Region: aws.String("region")},
177	)
178
179	if err != nil {
180		t.Errorf("error getting session: %v", err)
181	}
182
183	creds, err := sess.Config.Credentials.Get()
184	if err != nil {
185		t.Errorf("error getting credentials: %v", err)
186	}
187
188	if e, a := "nonDefaultToken", creds.SessionToken; e != a {
189		t.Errorf("expected %v, got %v", e, a)
190	}
191
192}
193
194func TestProcessProviderNotFromCredProcCrd(t *testing.T) {
195	restoreEnvFn := sdktesting.StashEnv()
196	defer restoreEnvFn()
197
198	os.Setenv("AWS_PROFILE", "not_alone")
199	if runtime.GOOS == "windows" {
200		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
201	} else {
202		os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
203	}
204
205	sess, err := session.NewSession(&aws.Config{
206		Region: aws.String("region")},
207	)
208
209	if err != nil {
210		t.Errorf("error getting session: %v", err)
211	}
212
213	creds, err := sess.Config.Credentials.Get()
214	if err != nil {
215		t.Errorf("error getting credentials: %v", err)
216	}
217
218	if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
219		t.Errorf("expected %v, got %v", e, a)
220	}
221
222	if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
223		t.Errorf("expected %v, got %v", e, a)
224	}
225
226}
227
228func TestProcessProviderBadCommand(t *testing.T) {
229	restoreEnvFn := sdktesting.StashEnv()
230	defer restoreEnvFn()
231
232	creds := processcreds.NewCredentials("/bad/process")
233	_, err := creds.Get()
234	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
235		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
236	}
237}
238
239func TestProcessProviderMoreEmptyCommands(t *testing.T) {
240	restoreEnvFn := sdktesting.StashEnv()
241	defer restoreEnvFn()
242
243	creds := processcreds.NewCredentials("")
244	_, err := creds.Get()
245	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
246		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
247	}
248
249}
250
251func TestProcessProviderExpectErrors(t *testing.T) {
252	restoreEnvFn := sdktesting.StashEnv()
253	defer restoreEnvFn()
254
255	creds := processcreds.NewCredentials(
256		fmt.Sprintf(
257			"%s %s",
258			getOSCat(),
259			strings.Join(
260				[]string{"testdata", "malformed.json"},
261				string(os.PathSeparator))))
262	_, err := creds.Get()
263	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderParse {
264		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderParse, err)
265	}
266
267	creds = processcreds.NewCredentials(
268		fmt.Sprintf("%s %s",
269			getOSCat(),
270			strings.Join(
271				[]string{"testdata", "wrongversion.json"},
272				string(os.PathSeparator))))
273	_, err = creds.Get()
274	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderVersion {
275		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderVersion, err)
276	}
277
278	creds = processcreds.NewCredentials(
279		fmt.Sprintf(
280			"%s %s",
281			getOSCat(),
282			strings.Join(
283				[]string{"testdata", "missingkey.json"},
284				string(os.PathSeparator))))
285	_, err = creds.Get()
286	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
287		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
288	}
289
290	creds = processcreds.NewCredentials(
291		fmt.Sprintf(
292			"%s %s",
293			getOSCat(),
294			strings.Join(
295				[]string{"testdata", "missingsecret.json"},
296				string(os.PathSeparator))))
297	_, err = creds.Get()
298	if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
299		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
300	}
301
302}
303
304func TestProcessProviderTimeout(t *testing.T) {
305	restoreEnvFn := sdktesting.StashEnv()
306	defer restoreEnvFn()
307
308	command := "/bin/sleep 2"
309	if runtime.GOOS == "windows" {
310		// "timeout" command does not work due to pipe redirection
311		command = "ping -n 2 127.0.0.1>nul"
312	}
313
314	creds := processcreds.NewCredentialsTimeout(
315		command,
316		time.Duration(1)*time.Second)
317	if _, err := creds.Get(); err == nil || err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution || err.(awserr.Error).Message() != "credential process timed out" {
318		t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
319	}
320
321}
322
323func TestProcessProviderWithLongSessionToken(t *testing.T) {
324	restoreEnvFn := sdktesting.StashEnv()
325	defer restoreEnvFn()
326
327	creds := processcreds.NewCredentials(
328		fmt.Sprintf(
329			"%s %s",
330			getOSCat(),
331			strings.Join(
332				[]string{"testdata", "longsessiontoken.json"},
333				string(os.PathSeparator))))
334	v, err := creds.Get()
335	if err != nil {
336		t.Errorf("expected %v, got %v", "no error", err)
337	}
338
339	// Text string same length as session token returned by AWS for AssumeRoleWithWebIdentity
340	e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
341	if a := v.SessionToken; e != a {
342		t.Errorf("expected %v, got %v", e, a)
343	}
344}
345
// credentialTest mirrors the JSON document a credential process prints. The
// tests marshal it to a temporary file that the provider then reads back.
// Field order is significant only in that it fixes the order of keys in the
// marshaled output.
type credentialTest struct {
	Version         int    `json:"Version"`
	AccessKeyID     string `json:"AccessKeyId"` // wire name uses "Id", not Go's "ID"
	SecretAccessKey string `json:"SecretAccessKey"`
	Expiration      string `json:"Expiration"` // RFC 3339 timestamp in the tests
}
352
353func TestProcessProviderStatic(t *testing.T) {
354	restoreEnvFn := sdktesting.StashEnv()
355	defer restoreEnvFn()
356
357	// static
358	creds := processcreds.NewCredentials(
359		fmt.Sprintf(
360			"%s %s",
361			getOSCat(),
362			strings.Join(
363				[]string{"testdata", "static.json"},
364				string(os.PathSeparator))))
365	_, err := creds.Get()
366	if err != nil {
367		t.Errorf("expected %v, got %v", "no error", err)
368	}
369	if creds.IsExpired() {
370		t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
371	}
372
373}
374
375func TestProcessProviderNotExpired(t *testing.T) {
376	restoreEnvFn := sdktesting.StashEnv()
377	defer restoreEnvFn()
378
379	// non-static, not expired
380	exp := &credentialTest{}
381	exp.Version = 1
382	exp.AccessKeyID = "accesskey"
383	exp.SecretAccessKey = "secretkey"
384	exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
385	b, err := json.Marshal(exp)
386	if err != nil {
387		t.Errorf("expected %v, got %v", "no error", err)
388	}
389
390	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expiring")
391	if err != nil {
392		t.Errorf("expected %v, got %v", "no error", err)
393	}
394	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
395		t.Errorf("expected %v, got %v", "no error", err)
396	}
397	defer func() {
398		if err = tmpFile.Close(); err != nil {
399			t.Errorf("expected %v, got %v", "no error", err)
400		}
401		if err = os.Remove(tmpFile.Name()); err != nil {
402			t.Errorf("expected %v, got %v", "no error", err)
403		}
404	}()
405	creds := processcreds.NewCredentials(
406		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
407	_, err = creds.Get()
408	if err != nil {
409		t.Errorf("expected %v, got %v", "no error", err)
410	}
411	if creds.IsExpired() {
412		t.Errorf("expected %v, got %v", "not expired", "expired")
413	}
414}
415
416func TestProcessProviderExpired(t *testing.T) {
417	restoreEnvFn := sdktesting.StashEnv()
418	defer restoreEnvFn()
419
420	// non-static, expired
421	exp := &credentialTest{}
422	exp.Version = 1
423	exp.AccessKeyID = "accesskey"
424	exp.SecretAccessKey = "secretkey"
425	exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339)
426	b, err := json.Marshal(exp)
427	if err != nil {
428		t.Errorf("expected %v, got %v", "no error", err)
429	}
430
431	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expired")
432	if err != nil {
433		t.Errorf("expected %v, got %v", "no error", err)
434	}
435	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
436		t.Errorf("expected %v, got %v", "no error", err)
437	}
438	defer func() {
439		if err = tmpFile.Close(); err != nil {
440			t.Errorf("expected %v, got %v", "no error", err)
441		}
442		if err = os.Remove(tmpFile.Name()); err != nil {
443			t.Errorf("expected %v, got %v", "no error", err)
444		}
445	}()
446	creds := processcreds.NewCredentials(
447		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
448	_, err = creds.Get()
449	if err != nil {
450		t.Errorf("expected %v, got %v", "no error", err)
451	}
452	if !creds.IsExpired() {
453		t.Errorf("expected %v, got %v", "expired", "not expired")
454	}
455}
456
457func TestProcessProviderForceExpire(t *testing.T) {
458	restoreEnvFn := sdktesting.StashEnv()
459	defer restoreEnvFn()
460
461	// non-static, not expired
462
463	// setup test credentials file
464	exp := &credentialTest{}
465	exp.Version = 1
466	exp.AccessKeyID = "accesskey"
467	exp.SecretAccessKey = "secretkey"
468	exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
469	b, err := json.Marshal(exp)
470	if err != nil {
471		t.Errorf("expected %v, got %v", "no error", err)
472	}
473	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_force_expire")
474	if err != nil {
475		t.Errorf("expected %v, got %v", "no error", err)
476	}
477	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
478		t.Errorf("expected %v, got %v", "no error", err)
479	}
480	defer func() {
481		if err = tmpFile.Close(); err != nil {
482			t.Errorf("expected %v, got %v", "no error", err)
483		}
484		if err = os.Remove(tmpFile.Name()); err != nil {
485			t.Errorf("expected %v, got %v", "no error", err)
486		}
487	}()
488
489	// get credentials from file
490	creds := processcreds.NewCredentials(
491		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
492	if _, err = creds.Get(); err != nil {
493		t.Errorf("expected %v, got %v", "no error", err)
494	}
495	if creds.IsExpired() {
496		t.Errorf("expected %v, got %v", "not expired", "expired")
497	}
498
499	// force expire creds
500	creds.Expire()
501	if !creds.IsExpired() {
502		t.Errorf("expected %v, got %v", "expired", "not expired")
503	}
504
505	// renew creds
506	if _, err = creds.Get(); err != nil {
507		t.Errorf("expected %v, got %v", "no error", err)
508	}
509	if creds.IsExpired() {
510		t.Errorf("expected %v, got %v", "not expired", "expired")
511	}
512
513}
514
515func TestProcessProviderAltConstruct(t *testing.T) {
516	restoreEnvFn := sdktesting.StashEnv()
517	defer restoreEnvFn()
518
519	// constructing with exec.Cmd instead of string
520	myCommand := exec.Command(
521		fmt.Sprintf(
522			"%s %s",
523			getOSCat(),
524			strings.Join(
525				[]string{"testdata", "static.json"},
526				string(os.PathSeparator))))
527	creds := processcreds.NewCredentialsCommand(myCommand, func(opt *processcreds.ProcessProvider) {
528		opt.Timeout = time.Duration(1) * time.Second
529	})
530	_, err := creds.Get()
531	if err != nil {
532		t.Errorf("expected %v, got %v", "no error", err)
533	}
534	if creds.IsExpired() {
535		t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
536	}
537}
538
539func BenchmarkProcessProvider(b *testing.B) {
540	restoreEnvFn := sdktesting.StashEnv()
541	defer restoreEnvFn()
542
543	creds := processcreds.NewCredentials(
544		fmt.Sprintf(
545			"%s %s",
546			getOSCat(),
547			strings.Join(
548				[]string{"testdata", "static.json"},
549				string(os.PathSeparator))))
550	_, err := creds.Get()
551	if err != nil {
552		b.Fatal(err)
553	}
554
555	b.ResetTimer()
556	for i := 0; i < b.N; i++ {
557		_, err := creds.Get()
558		if err != nil {
559			b.Fatal(err)
560		}
561	}
562}
563
// getOSCat returns the command the tests use to print a file's contents as a
// credential process: "type" on Windows and "cat" on every other GOOS.
func getOSCat() string {
	switch runtime.GOOS {
	case "windows":
		return "type"
	default:
		return "cat"
	}
}
570