1package corehandlers_test 2 3import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "net/http/httptest" 9 "net/url" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/aws/aws-sdk-go/aws" 15 "github.com/aws/aws-sdk-go/aws/awserr" 16 "github.com/aws/aws-sdk-go/aws/client" 17 "github.com/aws/aws-sdk-go/aws/client/metadata" 18 "github.com/aws/aws-sdk-go/aws/corehandlers" 19 "github.com/aws/aws-sdk-go/aws/credentials" 20 "github.com/aws/aws-sdk-go/aws/request" 21 "github.com/aws/aws-sdk-go/awstesting" 22 "github.com/aws/aws-sdk-go/awstesting/unit" 23 "github.com/aws/aws-sdk-go/internal/sdktesting" 24 "github.com/aws/aws-sdk-go/service/s3" 25) 26 27func TestValidateEndpointHandler(t *testing.T) { 28 restoreEnvFn := sdktesting.StashEnv() 29 defer restoreEnvFn() 30 svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2")) 31 svc.Handlers.Clear() 32 svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler) 33 34 req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 35 err := req.Build() 36 37 if err != nil { 38 t.Errorf("expect no error, got %v", err) 39 } 40} 41 42func TestValidateEndpointHandlerErrorRegion(t *testing.T) { 43 restoreEnvFn := sdktesting.StashEnv() 44 defer restoreEnvFn() 45 svc := awstesting.NewClient() 46 svc.Handlers.Clear() 47 svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler) 48 49 req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 50 err := req.Build() 51 52 if err == nil { 53 t.Errorf("expect error, got none") 54 } 55 if e, a := aws.ErrMissingRegion, err; e != a { 56 t.Errorf("expect %v to be %v", e, a) 57 } 58} 59 60type mockCredsProvider struct { 61 expired bool 62 retrieveCalled bool 63} 64 65func (m *mockCredsProvider) Retrieve() (credentials.Value, error) { 66 m.retrieveCalled = true 67 return credentials.Value{ 68 AccessKeyID: "AKID", 69 SecretAccessKey: "SECRET", 70 ProviderName: "mockCredsProvider", 71 }, nil 72} 73 74func (m *mockCredsProvider) IsExpired() bool { 75 return m.expired 76} 77 78func TestAfterRetryRefreshCreds(t *testing.T) { 79 restoreEnvFn := sdktesting.StashEnv() 80 defer restoreEnvFn() 81 82 credProvider := &mockCredsProvider{} 83 84 sess := unit.Session.Copy(&aws.Config{ 85 Credentials: credentials.NewCredentials(credProvider), 86 MaxRetries: aws.Int(2), 87 }) 88 clientInfo := metadata.ClientInfo{ 89 Endpoint: "http://endpoint", 90 SigningName: "", 91 } 92 svc := client.New(*sess.Config, clientInfo, sess.Handlers) 93 94 svc.Handlers.Sign.PushBack(func(r *request.Request) { 95 if !svc.Config.Credentials.IsExpired() { 96 t.Errorf("expect credentials of of been expired before request attempt") 97 } 98 _, err := svc.Config.Credentials.Get() 99 r.Error = err 100 }) 101 102 var respID int 103 resps := []struct { 104 Resp *http.Response 105 Err error 106 }{ 107 { 108 Resp: &http.Response{ 109 StatusCode: 403, 110 Header: http.Header{}, 111 Body: ioutil.NopCloser(bytes.NewBuffer([]byte{})), 112 }, 113 Err: awserr.New("ExpiredToken", "", nil), 114 }, 115 { 116 Resp: &http.Response{ 117 StatusCode: 403, 118 Header: http.Header{}, 119 Body: ioutil.NopCloser(bytes.NewBuffer([]byte{})), 120 }, 121 Err: awserr.New("ExpiredToken", "", nil), 122 }, 123 { 124 Resp: &http.Response{ 125 StatusCode: 200, 126 Header: http.Header{}, 127 Body: ioutil.NopCloser(bytes.NewBuffer([]byte{})), 128 }, 129 }, 130 } 131 svc.Handlers.Send.Clear() 132 svc.Handlers.Send.PushBack(func(r *request.Request) { 133 r.HTTPResponse = resps[respID].Resp 134 }) 135 svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) { 136 r.Error = resps[respID].Err 137 }) 138 svc.Handlers.CompleteAttempt.PushBack(func(r *request.Request) { 139 respID++ 140 }) 141 142 if !svc.Config.Credentials.IsExpired() { 143 t.Fatalf("expect to start out expired") 144 } 145 if credProvider.retrieveCalled { 146 t.Fatalf("expect retrieve not yet called") 147 } 148 149 req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 150 if err := req.Send(); err != nil { 151 t.Fatalf("expect no error, got %v", err) 152 } 153 if e, a := len(resps)-1, req.RetryCount; e != a { 154 t.Errorf("expect %v retries, got %v", e, a) 155 } 156 if svc.Config.Credentials.IsExpired() { 157 t.Errorf("expect credentials not to be expired") 158 } 159 if !credProvider.retrieveCalled { 160 t.Errorf("expect retrieve to be called") 161 } 162} 163 164func TestAfterRetryWithContextCanceled(t *testing.T) { 165 c := awstesting.NewClient() 166 167 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 168 169 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 170 req.SetContext(ctx) 171 172 req.Error = fmt.Errorf("some error") 173 req.Retryable = aws.Bool(true) 174 req.HTTPResponse = &http.Response{ 175 StatusCode: 500, 176 } 177 178 close(ctx.DoneCh) 179 ctx.Error = fmt.Errorf("context canceled") 180 181 corehandlers.AfterRetryHandler.Fn(req) 182 183 if req.Error == nil { 184 t.Fatalf("expect error but didn't receive one") 185 } 186 187 aerr := req.Error.(awserr.Error) 188 189 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 190 t.Errorf("expect %q, error code got %q", e, a) 191 } 192} 193 194func TestAfterRetryWithContext(t *testing.T) { 195 c := awstesting.NewClient() 196 197 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 198 199 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 200 req.SetContext(ctx) 201 202 req.Error = fmt.Errorf("some error") 203 req.Retryable = aws.Bool(true) 204 req.HTTPResponse = &http.Response{ 205 StatusCode: 500, 206 } 207 208 corehandlers.AfterRetryHandler.Fn(req) 209 210 if req.Error != nil { 211 t.Fatalf("expect no error, got %v", req.Error) 212 } 213 if e, a := 1, req.RetryCount; e != a { 214 t.Errorf("expect retry count to be %d, got %d", e, a) 215 } 216} 217 218func TestSendWithContextCanceled(t *testing.T) { 219 c := awstesting.NewClient(&aws.Config{ 220 SleepDelay: func(dur time.Duration) { 221 t.Errorf("SleepDelay should not be called") 222 }, 223 }) 224 225 req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 226 227 ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} 228 req.SetContext(ctx) 229 230 req.Error = fmt.Errorf("some error") 231 req.Retryable = aws.Bool(true) 232 req.HTTPResponse = &http.Response{ 233 StatusCode: 500, 234 } 235 236 close(ctx.DoneCh) 237 ctx.Error = fmt.Errorf("context canceled") 238 239 corehandlers.SendHandler.Fn(req) 240 241 if req.Error == nil { 242 t.Fatalf("expect error but didn't receive one") 243 } 244 245 aerr := req.Error.(awserr.Error) 246 247 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 248 t.Errorf("expect %q, error code got %q", e, a) 249 } 250} 251 252type testSendHandlerTransport struct{} 253 254func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) { 255 return nil, fmt.Errorf("mock error") 256} 257 258func TestSendHandlerError(t *testing.T) { 259 svc := awstesting.NewClient(&aws.Config{ 260 HTTPClient: &http.Client{ 261 Transport: &testSendHandlerTransport{}, 262 }, 263 }) 264 svc.Handlers.Clear() 265 svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler) 266 r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 267 268 r.Send() 269 270 if r.Error == nil { 271 t.Errorf("expect error, got none") 272 } 273 if r.HTTPResponse == nil { 274 t.Errorf("expect response, got none") 275 } 276} 277 278func TestSendWithoutFollowRedirects(t *testing.T) { 279 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 280 switch r.URL.Path { 281 case "/original": 282 w.Header().Set("Location", "/redirected") 283 w.WriteHeader(301) 284 case "/redirected": 285 t.Fatalf("expect not to redirect, but was") 286 } 287 })) 288 defer server.Close() 289 290 svc := awstesting.NewClient(&aws.Config{ 291 DisableSSL: aws.Bool(true), 292 Endpoint: aws.String(server.URL), 293 }) 294 svc.Handlers.Clear() 295 svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler) 296 297 r := svc.NewRequest(&request.Operation{ 298 Name: "Operation", 299 HTTPPath: "/original", 300 }, nil, nil) 301 r.DisableFollowRedirects = true 302 303 err := r.Send() 304 if err != nil { 305 t.Errorf("expect no error, got %v", err) 306 } 307 if e, a := 301, r.HTTPResponse.StatusCode; e != a { 308 t.Errorf("expect %d status code, got %d", e, a) 309 } 310} 311 312func TestValidateReqSigHandler(t *testing.T) { 313 cases := []struct { 314 Req *request.Request 315 Resign bool 316 }{ 317 { 318 Req: &request.Request{ 319 Config: aws.Config{Credentials: credentials.AnonymousCredentials}, 320 Time: time.Now().Add(-15 * time.Minute), 321 }, 322 Resign: false, 323 }, 324 { 325 Req: &request.Request{ 326 Time: time.Now().Add(-15 * time.Minute), 327 }, 328 Resign: true, 329 }, 330 { 331 Req: &request.Request{ 332 Time: time.Now().Add(-1 * time.Minute), 333 }, 334 Resign: false, 335 }, 336 } 337 338 for i, c := range cases { 339 c.Req.HTTPRequest = &http.Request{URL: &url.URL{}} 340 341 resigned := false 342 c.Req.Handlers.Sign.PushBack(func(r *request.Request) { 343 resigned = true 344 }) 345 346 corehandlers.ValidateReqSigHandler.Fn(c.Req) 347 348 if c.Req.Error != nil { 349 t.Errorf("expect no error, got %v", c.Req.Error) 350 } 351 if e, a := c.Resign, resigned; e != a { 352 t.Errorf("%d, expect %v to be %v", i, e, a) 353 } 354 } 355} 356 357func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server { 358 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 359 _, ok := r.Header["Content-Length"] 360 if e, a := hasContentLength, ok; e != a { 361 t.Errorf("expect %v to be %v", e, a) 362 } 363 if hasContentLength { 364 if e, a := contentLength, r.ContentLength; e != a { 365 t.Errorf("expect %v to be %v", e, a) 366 } 367 } 368 369 b, err := ioutil.ReadAll(r.Body) 370 if err != nil { 371 t.Errorf("expect no error, got %v", err) 372 } 373 r.Body.Close() 374 375 authHeader := r.Header.Get("Authorization") 376 if hasContentLength { 377 if e, a := "content-length", authHeader; !strings.Contains(a, e) { 378 t.Errorf("expect %v to be in %v", e, a) 379 } 380 } else { 381 if e, a := "content-length", authHeader; strings.Contains(a, e) { 382 t.Errorf("expect %v to not be in %v", e, a) 383 } 384 } 385 386 if e, a := contentLength, int64(len(b)); e != a { 387 t.Errorf("expect %v to be %v", e, a) 388 } 389 })) 390 391 return server 392} 393 394func TestBuildContentLength_ZeroBody(t *testing.T) { 395 server := setupContentLengthTestServer(t, false, 0) 396 defer server.Close() 397 398 svc := s3.New(unit.Session, &aws.Config{ 399 Endpoint: aws.String(server.URL), 400 S3ForcePathStyle: aws.Bool(true), 401 DisableSSL: aws.Bool(true), 402 }) 403 _, err := svc.GetObject(&s3.GetObjectInput{ 404 Bucket: aws.String("bucketname"), 405 Key: aws.String("keyname"), 406 }) 407 408 if err != nil { 409 t.Errorf("expect no error, got %v", err) 410 } 411} 412 413func TestBuildContentLength_NegativeBody(t *testing.T) { 414 server := setupContentLengthTestServer(t, false, 0) 415 defer server.Close() 416 417 svc := s3.New(unit.Session, &aws.Config{ 418 Endpoint: aws.String(server.URL), 419 S3ForcePathStyle: aws.Bool(true), 420 DisableSSL: aws.Bool(true), 421 }) 422 req, _ := svc.GetObjectRequest(&s3.GetObjectInput{ 423 Bucket: aws.String("bucketname"), 424 Key: aws.String("keyname"), 425 }) 426 427 req.HTTPRequest.Header.Set("Content-Length", "-1") 428 429 if req.Error != nil { 430 t.Errorf("expect no error, got %v", req.Error) 431 } 432} 433 434func TestBuildContentLength_WithBody(t *testing.T) { 435 server := setupContentLengthTestServer(t, true, 1024) 436 defer server.Close() 437 438 svc := s3.New(unit.Session, &aws.Config{ 439 Endpoint: aws.String(server.URL), 440 S3ForcePathStyle: aws.Bool(true), 441 DisableSSL: aws.Bool(true), 442 }) 443 _, err := svc.PutObject(&s3.PutObjectInput{ 444 Bucket: aws.String("bucketname"), 445 Key: aws.String("keyname"), 446 Body: bytes.NewReader(make([]byte, 1024)), 447 }) 448 449 if err != nil { 450 t.Errorf("expect no error, got %v", err) 451 } 452} 453