1package corehandlers
2
3import (
4	"bytes"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"net/url"
9	"regexp"
10	"strconv"
11	"time"
12
13	"github.com/aws/aws-sdk-go/aws"
14	"github.com/aws/aws-sdk-go/aws/awserr"
15	"github.com/aws/aws-sdk-go/aws/credentials"
16	"github.com/aws/aws-sdk-go/aws/request"
17)
18
// lener is satisfied by any type that reports its length through a Len
// method.
type lener interface {
	// Len returns the length as an int.
	Len() int
}
23
24// BuildContentLengthHandler builds the content length of a request based on the body,
25// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
26// to determine request body length and no "Content-Length" was specified it will panic.
27//
28// The Content-Length will only be added to the request if the length of the body
29// is greater than 0. If the body is empty or the current `Content-Length`
30// header is <= 0, the header will also be stripped.
31var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) {
32	var length int64
33
34	if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
35		length, _ = strconv.ParseInt(slength, 10, 64)
36	} else {
37		if r.Body != nil {
38			var err error
39			length, err = aws.SeekerLen(r.Body)
40			if err != nil {
41				r.Error = awserr.New(request.ErrCodeSerialization, "failed to get request body's length", err)
42				return
43			}
44		}
45	}
46
47	if length > 0 {
48		r.HTTPRequest.ContentLength = length
49		r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
50	} else {
51		r.HTTPRequest.ContentLength = 0
52		r.HTTPRequest.Header.Del("Content-Length")
53	}
54}}
55
// reStatusCode matches three digits at the very start of a string and captures
// them. handleSendError applies it to an error message to recover an HTTP
// status code.
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
57
58// ValidateReqSigHandler is a request handler to ensure that the request's
59// signature doesn't expire before it is sent. This can happen when a request
60// is built and signed significantly before it is sent. Or significant delays
61// occur when retrying requests that would cause the signature to expire.
62var ValidateReqSigHandler = request.NamedHandler{
63	Name: "core.ValidateReqSigHandler",
64	Fn: func(r *request.Request) {
65		// Unsigned requests are not signed
66		if r.Config.Credentials == credentials.AnonymousCredentials {
67			return
68		}
69
70		signedTime := r.Time
71		if !r.LastSignedAt.IsZero() {
72			signedTime = r.LastSignedAt
73		}
74
75		// 5 minutes to allow for some clock skew/delays in transmission.
76		// Would be improved with aws/aws-sdk-go#423
77		if signedTime.Add(5 * time.Minute).After(time.Now()) {
78			return
79		}
80
81		fmt.Println("request expired, resigning")
82		r.Sign()
83	},
84}
85
86// SendHandler is a request handler to send service request using HTTP client.
87var SendHandler = request.NamedHandler{
88	Name: "core.SendHandler",
89	Fn: func(r *request.Request) {
90		sender := sendFollowRedirects
91		if r.DisableFollowRedirects {
92			sender = sendWithoutFollowRedirects
93		}
94
95		if request.NoBody == r.HTTPRequest.Body {
96			// Strip off the request body if the NoBody reader was used as a
97			// place holder for a request body. This prevents the SDK from
98			// making requests with a request body when it would be invalid
99			// to do so.
100			//
101			// Use a shallow copy of the http.Request to ensure the race condition
102			// of transport on Body will not trigger
103			reqOrig, reqCopy := r.HTTPRequest, *r.HTTPRequest
104			reqCopy.Body = nil
105			r.HTTPRequest = &reqCopy
106			defer func() {
107				r.HTTPRequest = reqOrig
108			}()
109		}
110
111		var err error
112		r.HTTPResponse, err = sender(r)
113		if err != nil {
114			handleSendError(r, err)
115		}
116	},
117}
118
119func sendFollowRedirects(r *request.Request) (*http.Response, error) {
120	return r.Config.HTTPClient.Do(r.HTTPRequest)
121}
122
123func sendWithoutFollowRedirects(r *request.Request) (*http.Response, error) {
124	transport := r.Config.HTTPClient.Transport
125	if transport == nil {
126		transport = http.DefaultTransport
127	}
128
129	return transport.RoundTrip(r.HTTPRequest)
130}
131
132func handleSendError(r *request.Request, err error) {
133	// Prevent leaking if an HTTPResponse was returned. Clean up
134	// the body.
135	if r.HTTPResponse != nil {
136		r.HTTPResponse.Body.Close()
137	}
138	// Capture the case where url.Error is returned for error processing
139	// response. e.g. 301 without location header comes back as string
140	// error and r.HTTPResponse is nil. Other URL redirect errors will
141	// comeback in a similar method.
142	if e, ok := err.(*url.Error); ok && e.Err != nil {
143		if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
144			code, _ := strconv.ParseInt(s[1], 10, 64)
145			r.HTTPResponse = &http.Response{
146				StatusCode: int(code),
147				Status:     http.StatusText(int(code)),
148				Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
149			}
150			return
151		}
152	}
153	if r.HTTPResponse == nil {
154		// Add a dummy request response object to ensure the HTTPResponse
155		// value is consistent.
156		r.HTTPResponse = &http.Response{
157			StatusCode: int(0),
158			Status:     http.StatusText(int(0)),
159			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
160		}
161	}
162	// Catch all request errors, and let the default retrier determine
163	// if the error is retryable.
164	r.Error = awserr.New(request.ErrCodeRequestError, "send request failed", err)
165
166	// Override the error with a context canceled error, if that was canceled.
167	ctx := r.Context()
168	select {
169	case <-ctx.Done():
170		r.Error = awserr.New(request.CanceledErrorCode,
171			"request context canceled", ctx.Err())
172		r.Retryable = aws.Bool(false)
173	default:
174	}
175}
176
177// ValidateResponseHandler is a request handler to validate service response.
178var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
179	if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
180		// this may be replaced by an UnmarshalError handler
181		r.Error = awserr.New("UnknownError", "unknown error", nil)
182	}
183}}
184
185// AfterRetryHandler performs final checks to determine if the request should
186// be retried and how long to delay.
187var AfterRetryHandler = request.NamedHandler{
188	Name: "core.AfterRetryHandler",
189	Fn: func(r *request.Request) {
190		// If one of the other handlers already set the retry state
191		// we don't want to override it based on the service's state
192		if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
193			r.Retryable = aws.Bool(r.ShouldRetry(r))
194		}
195
196		if r.WillRetry() {
197			r.RetryDelay = r.RetryRules(r)
198
199			if sleepFn := r.Config.SleepDelay; sleepFn != nil {
200				// Support SleepDelay for backwards compatibility and testing
201				sleepFn(r.RetryDelay)
202			} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
203				r.Error = awserr.New(request.CanceledErrorCode,
204					"request context canceled", err)
205				r.Retryable = aws.Bool(false)
206				return
207			}
208
209			// when the expired token exception occurs the credentials
210			// need to be expired locally so that the next request to
211			// get credentials will trigger a credentials refresh.
212			if r.IsErrorExpired() {
213				r.Config.Credentials.Expire()
214			}
215
216			r.RetryCount++
217			r.Error = nil
218		}
219	}}
220
221// ValidateEndpointHandler is a request handler to validate a request had the
222// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
223// region is not valid.
224var ValidateEndpointHandler = request.NamedHandler{Name: "core.ValidateEndpointHandler", Fn: func(r *request.Request) {
225	if r.ClientInfo.SigningRegion == "" && aws.StringValue(r.Config.Region) == "" {
226		r.Error = aws.ErrMissingRegion
227	} else if r.ClientInfo.Endpoint == "" {
228		r.Error = aws.ErrMissingEndpoint
229	}
230}}
231