1package endpointcreds_test
2
3import (
4	"encoding/json"
5	"fmt"
6	"net/http"
7	"net/http/httptest"
8	"testing"
9	"time"
10
11	"github.com/aws/aws-sdk-go/aws/awserr"
12	"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
13	"github.com/aws/aws-sdk-go/awstesting/unit"
14)
15
16func TestRetrieveRefreshableCredentials(t *testing.T) {
17	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18		if e, a := "/path/to/endpoint", r.URL.Path; e != a {
19			t.Errorf("expect %v, got %v", e, a)
20		}
21		if e, a := "application/json", r.Header.Get("Accept"); e != a {
22			t.Errorf("expect %v, got %v", e, a)
23		}
24		if e, a := "else", r.URL.Query().Get("something"); e != a {
25			t.Errorf("expect %v, got %v", e, a)
26		}
27
28		encoder := json.NewEncoder(w)
29		err := encoder.Encode(map[string]interface{}{
30			"AccessKeyID":     "AKID",
31			"SecretAccessKey": "SECRET",
32			"Token":           "TOKEN",
33			"Expiration":      time.Now().Add(1 * time.Hour),
34		})
35
36		if err != nil {
37			fmt.Println("failed to write out creds", err)
38		}
39	}))
40	defer server.Close()
41
42	client := endpointcreds.NewProviderClient(*unit.Session.Config,
43		unit.Session.Handlers,
44		server.URL+"/path/to/endpoint?something=else",
45	)
46	creds, err := client.Retrieve()
47
48	if err != nil {
49		t.Errorf("expect no error, got %v", err)
50	}
51
52	if e, a := "AKID", creds.AccessKeyID; e != a {
53		t.Errorf("expect %v, got %v", e, a)
54	}
55	if e, a := "SECRET", creds.SecretAccessKey; e != a {
56		t.Errorf("expect %v, got %v", e, a)
57	}
58	if e, a := "TOKEN", creds.SessionToken; e != a {
59		t.Errorf("expect %v, got %v", e, a)
60	}
61	if client.IsExpired() {
62		t.Errorf("expect not expired, was")
63	}
64
65	client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
66		return time.Now().Add(2 * time.Hour)
67	}
68
69	if !client.IsExpired() {
70		t.Errorf("expect expired, wasn't")
71	}
72}
73
74func TestRetrieveStaticCredentials(t *testing.T) {
75	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76		encoder := json.NewEncoder(w)
77		err := encoder.Encode(map[string]interface{}{
78			"AccessKeyID":     "AKID",
79			"SecretAccessKey": "SECRET",
80		})
81
82		if err != nil {
83			fmt.Println("failed to write out creds", err)
84		}
85	}))
86	defer server.Close()
87
88	client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
89	creds, err := client.Retrieve()
90
91	if err != nil {
92		t.Errorf("expect no error, got %v", err)
93	}
94
95	if e, a := "AKID", creds.AccessKeyID; e != a {
96		t.Errorf("expect %v, got %v", e, a)
97	}
98	if e, a := "SECRET", creds.SecretAccessKey; e != a {
99		t.Errorf("expect %v, got %v", e, a)
100	}
101	if v := creds.SessionToken; len(v) != 0 {
102		t.Errorf("Expect no SessionToken, got %#v", v)
103	}
104	if client.IsExpired() {
105		t.Errorf("expect not expired, was")
106	}
107}
108
109func TestFailedRetrieveCredentials(t *testing.T) {
110	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
111		w.WriteHeader(400)
112		encoder := json.NewEncoder(w)
113		err := encoder.Encode(map[string]interface{}{
114			"Code":    "Error",
115			"Message": "Message",
116		})
117
118		if err != nil {
119			fmt.Println("failed to write error", err)
120		}
121	}))
122	defer server.Close()
123
124	client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
125	creds, err := client.Retrieve()
126
127	if err == nil {
128		t.Errorf("expect error, got none")
129	}
130	aerr := err.(awserr.Error)
131
132	if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
133		t.Errorf("expect %v, got %v", e, a)
134	}
135	if e, a := "failed to load credentials", aerr.Message(); e != a {
136		t.Errorf("expect %v, got %v", e, a)
137	}
138
139	aerr = aerr.OrigErr().(awserr.Error)
140	if e, a := "Error", aerr.Code(); e != a {
141		t.Errorf("expect %v, got %v", e, a)
142	}
143	if e, a := "Message", aerr.Message(); e != a {
144		t.Errorf("expect %v, got %v", e, a)
145	}
146
147	if v := creds.AccessKeyID; len(v) != 0 {
148		t.Errorf("expect empty, got %#v", v)
149	}
150	if v := creds.SecretAccessKey; len(v) != 0 {
151		t.Errorf("expect empty, got %#v", v)
152	}
153	if v := creds.SessionToken; len(v) != 0 {
154		t.Errorf("expect empty, got %#v", v)
155	}
156	if !client.IsExpired() {
157		t.Errorf("expect expired, wasn't")
158	}
159}
160
161func TestAuthorizationToken(t *testing.T) {
162	const expectAuthToken = "Basic abc123"
163
164	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
165		if e, a := "/path/to/endpoint", r.URL.Path; e != a {
166			t.Errorf("expect %v, got %v", e, a)
167		}
168		if e, a := "application/json", r.Header.Get("Accept"); e != a {
169			t.Errorf("expect %v, got %v", e, a)
170		}
171		if e, a := expectAuthToken, r.Header.Get("Authorization"); e != a {
172			t.Fatalf("expect %v, got %v", e, a)
173		}
174
175		encoder := json.NewEncoder(w)
176		err := encoder.Encode(map[string]interface{}{
177			"AccessKeyID":     "AKID",
178			"SecretAccessKey": "SECRET",
179			"Token":           "TOKEN",
180			"Expiration":      time.Now().Add(1 * time.Hour),
181		})
182
183		if err != nil {
184			fmt.Println("failed to write out creds", err)
185		}
186	}))
187	defer server.Close()
188
189	client := endpointcreds.NewProviderClient(*unit.Session.Config,
190		unit.Session.Handlers,
191		server.URL+"/path/to/endpoint?something=else",
192		func(p *endpointcreds.Provider) {
193			p.AuthorizationToken = expectAuthToken
194		},
195	)
196	creds, err := client.Retrieve()
197
198	if err != nil {
199		t.Errorf("expect no error, got %v", err)
200	}
201
202	if e, a := "AKID", creds.AccessKeyID; e != a {
203		t.Errorf("expect %v, got %v", e, a)
204	}
205	if e, a := "SECRET", creds.SecretAccessKey; e != a {
206		t.Errorf("expect %v, got %v", e, a)
207	}
208	if e, a := "TOKEN", creds.SessionToken; e != a {
209		t.Errorf("expect %v, got %v", e, a)
210	}
211	if client.IsExpired() {
212		t.Errorf("expect not expired, was")
213	}
214
215	client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
216		return time.Now().Add(2 * time.Hour)
217	}
218
219	if !client.IsExpired() {
220		t.Errorf("expect expired, wasn't")
221	}
222}
223