1package corehandlers_test
2
3import (
4	"bytes"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"net/http/httptest"
9	"net/url"
10	"strings"
11	"testing"
12	"time"
13
14	"github.com/aws/aws-sdk-go/aws"
15	"github.com/aws/aws-sdk-go/aws/awserr"
16	"github.com/aws/aws-sdk-go/aws/client"
17	"github.com/aws/aws-sdk-go/aws/client/metadata"
18	"github.com/aws/aws-sdk-go/aws/corehandlers"
19	"github.com/aws/aws-sdk-go/aws/credentials"
20	"github.com/aws/aws-sdk-go/aws/request"
21	"github.com/aws/aws-sdk-go/awstesting"
22	"github.com/aws/aws-sdk-go/awstesting/unit"
23	"github.com/aws/aws-sdk-go/internal/sdktesting"
24	"github.com/aws/aws-sdk-go/service/s3"
25)
26
27func TestValidateEndpointHandler(t *testing.T) {
28	restoreEnvFn := sdktesting.StashEnv()
29	defer restoreEnvFn()
30	svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2"))
31	svc.Handlers.Clear()
32	svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
33
34	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
35	err := req.Build()
36
37	if err != nil {
38		t.Errorf("expect no error, got %v", err)
39	}
40}
41
42func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
43	restoreEnvFn := sdktesting.StashEnv()
44	defer restoreEnvFn()
45	svc := awstesting.NewClient()
46	svc.Handlers.Clear()
47	svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
48
49	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
50	err := req.Build()
51
52	if err == nil {
53		t.Errorf("expect error, got none")
54	}
55	if e, a := aws.ErrMissingRegion, err; e != a {
56		t.Errorf("expect %v to be %v", e, a)
57	}
58}
59
60type mockCredsProvider struct {
61	expired        bool
62	retrieveCalled bool
63}
64
65func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
66	m.retrieveCalled = true
67	return credentials.Value{
68		AccessKeyID:     "AKID",
69		SecretAccessKey: "SECRET",
70		ProviderName:    "mockCredsProvider",
71	}, nil
72}
73
74func (m *mockCredsProvider) IsExpired() bool {
75	return m.expired
76}
77
78func TestAfterRetryRefreshCreds(t *testing.T) {
79	restoreEnvFn := sdktesting.StashEnv()
80	defer restoreEnvFn()
81
82	credProvider := &mockCredsProvider{}
83
84	sess := unit.Session.Copy(&aws.Config{
85		Credentials: credentials.NewCredentials(credProvider),
86		MaxRetries:  aws.Int(2),
87	})
88	clientInfo := metadata.ClientInfo{
89		Endpoint:    "http://endpoint",
90		SigningName: "",
91	}
92	svc := client.New(*sess.Config, clientInfo, sess.Handlers)
93
94	svc.Handlers.Sign.PushBack(func(r *request.Request) {
95		if !svc.Config.Credentials.IsExpired() {
96			t.Errorf("expect credentials of of been expired before request attempt")
97		}
98		_, err := svc.Config.Credentials.Get()
99		r.Error = err
100	})
101
102	var respID int
103	resps := []struct {
104		Resp *http.Response
105		Err  error
106	}{
107		{
108			Resp: &http.Response{
109				StatusCode: 403,
110				Header:     http.Header{},
111				Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{})),
112			},
113			Err: awserr.New("ExpiredToken", "", nil),
114		},
115		{
116			Resp: &http.Response{
117				StatusCode: 403,
118				Header:     http.Header{},
119				Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{})),
120			},
121			Err: awserr.New("ExpiredToken", "", nil),
122		},
123		{
124			Resp: &http.Response{
125				StatusCode: 200,
126				Header:     http.Header{},
127				Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{})),
128			},
129		},
130	}
131	svc.Handlers.Send.Clear()
132	svc.Handlers.Send.PushBack(func(r *request.Request) {
133		r.HTTPResponse = resps[respID].Resp
134	})
135	svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
136		r.Error = resps[respID].Err
137	})
138	svc.Handlers.CompleteAttempt.PushBack(func(r *request.Request) {
139		respID++
140	})
141
142	if !svc.Config.Credentials.IsExpired() {
143		t.Fatalf("expect to start out expired")
144	}
145	if credProvider.retrieveCalled {
146		t.Fatalf("expect retrieve not yet called")
147	}
148
149	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
150	if err := req.Send(); err != nil {
151		t.Fatalf("expect no error, got %v", err)
152	}
153	if e, a := len(resps)-1, req.RetryCount; e != a {
154		t.Errorf("expect %v retries, got %v", e, a)
155	}
156	if svc.Config.Credentials.IsExpired() {
157		t.Errorf("expect credentials not to be expired")
158	}
159	if !credProvider.retrieveCalled {
160		t.Errorf("expect retrieve to be called")
161	}
162}
163
164func TestAfterRetryWithContextCanceled(t *testing.T) {
165	c := awstesting.NewClient()
166
167	req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
168
169	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
170	req.SetContext(ctx)
171
172	req.Error = fmt.Errorf("some error")
173	req.Retryable = aws.Bool(true)
174	req.HTTPResponse = &http.Response{
175		StatusCode: 500,
176	}
177
178	close(ctx.DoneCh)
179	ctx.Error = fmt.Errorf("context canceled")
180
181	corehandlers.AfterRetryHandler.Fn(req)
182
183	if req.Error == nil {
184		t.Fatalf("expect error but didn't receive one")
185	}
186
187	aerr := req.Error.(awserr.Error)
188
189	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
190		t.Errorf("expect %q, error code got %q", e, a)
191	}
192}
193
194func TestAfterRetryWithContext(t *testing.T) {
195	c := awstesting.NewClient()
196
197	req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
198
199	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
200	req.SetContext(ctx)
201
202	req.Error = fmt.Errorf("some error")
203	req.Retryable = aws.Bool(true)
204	req.HTTPResponse = &http.Response{
205		StatusCode: 500,
206	}
207
208	corehandlers.AfterRetryHandler.Fn(req)
209
210	if req.Error != nil {
211		t.Fatalf("expect no error, got %v", req.Error)
212	}
213	if e, a := 1, req.RetryCount; e != a {
214		t.Errorf("expect retry count to be %d, got %d", e, a)
215	}
216}
217
218func TestSendWithContextCanceled(t *testing.T) {
219	c := awstesting.NewClient(&aws.Config{
220		SleepDelay: func(dur time.Duration) {
221			t.Errorf("SleepDelay should not be called")
222		},
223	})
224
225	req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
226
227	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
228	req.SetContext(ctx)
229
230	req.Error = fmt.Errorf("some error")
231	req.Retryable = aws.Bool(true)
232	req.HTTPResponse = &http.Response{
233		StatusCode: 500,
234	}
235
236	close(ctx.DoneCh)
237	ctx.Error = fmt.Errorf("context canceled")
238
239	corehandlers.SendHandler.Fn(req)
240
241	if req.Error == nil {
242		t.Fatalf("expect error but didn't receive one")
243	}
244
245	aerr := req.Error.(awserr.Error)
246
247	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
248		t.Errorf("expect %q, error code got %q", e, a)
249	}
250}
251
// testSendHandlerTransport is a transport stub whose round trips always fail.
type testSendHandlerTransport struct{}

// RoundTrip ignores the request and returns a nil response together with a
// "mock error" error.
func (*testSendHandlerTransport) RoundTrip(*http.Request) (*http.Response, error) {
	err := fmt.Errorf("mock error")
	return nil, err
}
257
258func TestSendHandlerError(t *testing.T) {
259	svc := awstesting.NewClient(&aws.Config{
260		HTTPClient: &http.Client{
261			Transport: &testSendHandlerTransport{},
262		},
263	})
264	svc.Handlers.Clear()
265	svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
266	r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
267
268	r.Send()
269
270	if r.Error == nil {
271		t.Errorf("expect error, got none")
272	}
273	if r.HTTPResponse == nil {
274		t.Errorf("expect response, got none")
275	}
276}
277
278func TestSendWithoutFollowRedirects(t *testing.T) {
279	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
280		switch r.URL.Path {
281		case "/original":
282			w.Header().Set("Location", "/redirected")
283			w.WriteHeader(301)
284		case "/redirected":
285			t.Fatalf("expect not to redirect, but was")
286		}
287	}))
288	defer server.Close()
289
290	svc := awstesting.NewClient(&aws.Config{
291		DisableSSL: aws.Bool(true),
292		Endpoint:   aws.String(server.URL),
293	})
294	svc.Handlers.Clear()
295	svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
296
297	r := svc.NewRequest(&request.Operation{
298		Name:     "Operation",
299		HTTPPath: "/original",
300	}, nil, nil)
301	r.DisableFollowRedirects = true
302
303	err := r.Send()
304	if err != nil {
305		t.Errorf("expect no error, got %v", err)
306	}
307	if e, a := 301, r.HTTPResponse.StatusCode; e != a {
308		t.Errorf("expect %d status code, got %d", e, a)
309	}
310}
311
312func TestValidateReqSigHandler(t *testing.T) {
313	cases := []struct {
314		Req    *request.Request
315		Resign bool
316	}{
317		{
318			Req: &request.Request{
319				Config: aws.Config{Credentials: credentials.AnonymousCredentials},
320				Time:   time.Now().Add(-15 * time.Minute),
321			},
322			Resign: false,
323		},
324		{
325			Req: &request.Request{
326				Time: time.Now().Add(-15 * time.Minute),
327			},
328			Resign: true,
329		},
330		{
331			Req: &request.Request{
332				Time: time.Now().Add(-1 * time.Minute),
333			},
334			Resign: false,
335		},
336	}
337
338	for i, c := range cases {
339		c.Req.HTTPRequest = &http.Request{URL: &url.URL{}}
340
341		resigned := false
342		c.Req.Handlers.Sign.PushBack(func(r *request.Request) {
343			resigned = true
344		})
345
346		corehandlers.ValidateReqSigHandler.Fn(c.Req)
347
348		if c.Req.Error != nil {
349			t.Errorf("expect no error, got %v", c.Req.Error)
350		}
351		if e, a := c.Resign, resigned; e != a {
352			t.Errorf("%d, expect %v to be %v", i, e, a)
353		}
354	}
355}
356
// setupContentLengthTestServer starts an httptest server whose handler checks
// every incoming request against the given expectations and reports any
// mismatch through t.Errorf.
//
// hasContentLength states whether the request must carry a Content-Length
// header and whether "content-length" must appear in its Authorization
// header. contentLength is the expected body size; it is also compared with
// the request's ContentLength field when hasContentLength is true.
//
// The caller is responsible for closing the returned server.
func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
	handler := func(w http.ResponseWriter, r *http.Request) {
		_, present := r.Header["Content-Length"]
		if hasContentLength != present {
			t.Errorf("expect %v to be %v", hasContentLength, present)
		}
		if hasContentLength && contentLength != r.ContentLength {
			t.Errorf("expect %v to be %v", contentLength, r.ContentLength)
		}

		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			t.Errorf("expect no error, got %v", err)
		}
		r.Body.Close()

		// The header name must appear in the Authorization header exactly
		// when a Content-Length header is expected.
		const headerName = "content-length"
		auth := r.Header.Get("Authorization")
		listed := strings.Contains(auth, headerName)
		switch {
		case hasContentLength && !listed:
			t.Errorf("expect %v to be in %v", headerName, auth)
		case !hasContentLength && listed:
			t.Errorf("expect %v to not be in %v", headerName, auth)
		}

		if got := int64(len(body)); contentLength != got {
			t.Errorf("expect %v to be %v", contentLength, got)
		}
	}

	return httptest.NewServer(http.HandlerFunc(handler))
}
393
394func TestBuildContentLength_ZeroBody(t *testing.T) {
395	server := setupContentLengthTestServer(t, false, 0)
396	defer server.Close()
397
398	svc := s3.New(unit.Session, &aws.Config{
399		Endpoint:         aws.String(server.URL),
400		S3ForcePathStyle: aws.Bool(true),
401		DisableSSL:       aws.Bool(true),
402	})
403	_, err := svc.GetObject(&s3.GetObjectInput{
404		Bucket: aws.String("bucketname"),
405		Key:    aws.String("keyname"),
406	})
407
408	if err != nil {
409		t.Errorf("expect no error, got %v", err)
410	}
411}
412
413func TestBuildContentLength_NegativeBody(t *testing.T) {
414	server := setupContentLengthTestServer(t, false, 0)
415	defer server.Close()
416
417	svc := s3.New(unit.Session, &aws.Config{
418		Endpoint:         aws.String(server.URL),
419		S3ForcePathStyle: aws.Bool(true),
420		DisableSSL:       aws.Bool(true),
421	})
422	req, _ := svc.GetObjectRequest(&s3.GetObjectInput{
423		Bucket: aws.String("bucketname"),
424		Key:    aws.String("keyname"),
425	})
426
427	req.HTTPRequest.Header.Set("Content-Length", "-1")
428
429	if req.Error != nil {
430		t.Errorf("expect no error, got %v", req.Error)
431	}
432}
433
434func TestBuildContentLength_WithBody(t *testing.T) {
435	server := setupContentLengthTestServer(t, true, 1024)
436	defer server.Close()
437
438	svc := s3.New(unit.Session, &aws.Config{
439		Endpoint:         aws.String(server.URL),
440		S3ForcePathStyle: aws.Bool(true),
441		DisableSSL:       aws.Bool(true),
442	})
443	_, err := svc.PutObject(&s3.PutObjectInput{
444		Bucket: aws.String("bucketname"),
445		Key:    aws.String("keyname"),
446		Body:   bytes.NewReader(make([]byte, 1024)),
447	})
448
449	if err != nil {
450		t.Errorf("expect no error, got %v", err)
451	}
452}
453