1package autorest
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/tls"
20	"fmt"
21	"log"
22	"math"
23	"net/http"
24	"net/http/cookiejar"
25	"strconv"
26	"time"
27
28	"github.com/Azure/go-autorest/tracing"
29)
30
// ctxSendDecorators is the unexported key type under which WithSendDecorators stores,
// and GetSendDecorators retrieves, a []SendDecorator in a context.Context (used as a
// key type in context.WithValue()). Because the type is unexported, code outside this
// package cannot construct the same key.
type ctxSendDecorators struct{}
33
// WithSendDecorators returns a context carrying sendDecorator, retrievable later with
// GetSendDecorators. When sendDecorator is empty (nil or zero length) the original
// context is returned untouched, so no empty entry is ever stored.
func WithSendDecorators(ctx context.Context, sendDecorator []SendDecorator) context.Context {
	if len(sendDecorator) > 0 {
		ctx = context.WithValue(ctx, ctxSendDecorators{}, sendDecorator)
	}
	return ctx
}
42
// GetSendDecorators looks up the SendDecorators previously attached to ctx with
// WithSendDecorators. If ctx carries none, defaultSendDecorators is returned instead.
func GetSendDecorators(ctx context.Context, defaultSendDecorators ...SendDecorator) []SendDecorator {
	stored, found := ctx.Value(ctxSendDecorators{}).([]SendDecorator)
	if !found {
		return defaultSendDecorators
	}
	return stored
}
51
// Sender is the interface that wraps the Do method to send HTTP requests.
//
// The standard http.Client conforms to this interface. A SendDecorator wraps one
// Sender in another, which is how the delays, retries, polling and logging in this
// file are layered onto a base Sender.
type Sender interface {
	// Do sends the request and returns the resulting response or an error.
	Do(*http.Request) (*http.Response, error)
}
58
59// SenderFunc is a method that implements the Sender interface.
60type SenderFunc func(*http.Request) (*http.Response, error)
61
62// Do implements the Sender interface on SenderFunc.
63func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) {
64	return sf(r)
65}
66
// SendDecorator takes and possibly decorates, by wrapping, a Sender. Decorators may affect the
// http.Request and pass it along or, first, pass the http.Request along then react to the
// http.Response result.
//
// When several decorators are applied with DecorateSender, each one wraps the Sender
// produced by the previous one, so the last decorator in the list is the outermost
// wrapper and sees the request first.
type SendDecorator func(Sender) Sender
71
// CreateSender builds a new default http.Client (TLS 1.2 minimum, no TLS renegotiation),
// wraps it with the given decorators in order, and returns the result as a Sender.
func CreateSender(decorators ...SendDecorator) Sender {
	base := sender(tls.RenegotiateNever)
	return DecorateSender(base, decorators...)
}
76
77// DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to
78// the Sender. Decorators are applied in the order received, but their affect upon the request
79// depends on whether they are a pre-decorator (change the http.Request and then pass it along) or a
80// post-decorator (pass the http.Request along and react to the results in http.Response).
81func DecorateSender(s Sender, decorators ...SendDecorator) Sender {
82	for _, decorate := range decorators {
83		s = decorate(s)
84	}
85	return s
86}
87
// Send sends the passed http.Request using a newly built default http.Client and returns the
// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators,
// which it applies to the http.Client before invoking the Do method.
//
// Send is a convenience method and not recommended for production. Advanced users should use
// SendWithSender, passing and sharing their own Sender (e.g., instance of http.Client).
//
// Send will not poll or retry requests.
func Send(r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
	client := sender(tls.RenegotiateNever)
	return SendWithSender(client, r, decorators...)
}
99
// SendWithSender sends the passed http.Request through the provided Sender and returns the
// http.Response and possible error. The optional SendDecorators are first applied to s (in
// order, see DecorateSender), and the decorated Sender's Do method is then invoked.
//
// SendWithSender will not poll or retry requests.
func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
	decorated := DecorateSender(s, decorators...)
	return decorated.Do(r)
}
108
// sender builds a fresh *http.Client, with its own cookie jar, for use as a default Sender.
// Its transport copies the connection settings of http.DefaultTransport but requires TLS 1.2
// or later and applies the given TLS renegotiation policy. When tracing is enabled, the
// transport is wrapped with tracing.NewTransport.
//
// If http.DefaultTransport has been replaced with a RoundTripper that is not an
// *http.Transport (HTTP mocking and instrumentation libraries do this), the values the
// standard library uses for DefaultTransport are applied instead of panicking on the type
// assertion. DialContext is left nil in that case, so the transport falls back to its own
// default dialer.
func sender(renegotiation tls.RenegotiationSupport) Sender {
	transport := &http.Transport{
		TLSClientConfig: &tls.Config{
			MinVersion:    tls.VersionTLS12,
			Renegotiation: renegotiation,
		},
	}
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		// Use behaviour compatible with DefaultTransport.
		transport.Proxy = dt.Proxy
		transport.DialContext = dt.DialContext
		transport.MaxIdleConns = dt.MaxIdleConns
		transport.IdleConnTimeout = dt.IdleConnTimeout
		transport.TLSHandshakeTimeout = dt.TLSHandshakeTimeout
		transport.ExpectContinueTimeout = dt.ExpectContinueTimeout
	} else {
		// DefaultTransport was swapped out; mirror the standard library's own defaults.
		transport.Proxy = http.ProxyFromEnvironment
		transport.MaxIdleConns = 100
		transport.IdleConnTimeout = 90 * time.Second
		transport.TLSHandshakeTimeout = 10 * time.Second
		transport.ExpectContinueTimeout = 1 * time.Second
	}
	var roundTripper http.RoundTripper = transport
	if tracing.IsEnabled() {
		roundTripper = tracing.NewTransport(transport)
	}
	// The error is deliberately ignored: with nil options the standard library's
	// cookiejar.New does not report one.
	j, _ := cookiejar.New(nil)
	return &http.Client{Jar: j, Transport: roundTripper}
}
131
// AfterDelay returns a SendDecorator that waits for the passed time.Duration before invoking
// the Sender. The full duration is honoured, including any sub-second component. A zero or
// negative duration means no wait. The wait may be terminated by cancelling the
// http.Request's context. If canceled, an error is returned and no further Senders are invoked.
func AfterDelay(d time.Duration) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			// A dedicated Timer (rather than DelayForBackoff, which truncates to whole
			// seconds) keeps the exact delay, and it is stopped promptly on cancellation.
			timer := time.NewTimer(d)
			select {
			case <-timer.C:
			case <-r.Context().Done():
				timer.Stop()
				return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
			}
			return s.Do(r)
		})
	}
}
145
// AsIs returns a SendDecorator whose Sender simply forwards each http.Request, unmodified,
// to the wrapped Sender and returns whatever it returns.
func AsIs() SendDecorator {
	return func(next Sender) Sender {
		passThrough := func(req *http.Request) (*http.Response, error) {
			return next.Do(req)
		}
		return SenderFunc(passThrough)
	}
}
154
// DoCloseIfError returns a SendDecorator that invokes the passed Sender and, when that Sender
// reports an error, runs the response through Respond with ByDiscardingBody and ByClosing before
// handing it and the error back to the caller. Successful responses are returned unchanged.
func DoCloseIfError() SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			resp, err := s.Do(r)
			if err == nil {
				return resp, nil
			}
			Respond(resp, ByDiscardingBody(), ByClosing())
			return resp, err
		})
	}
}
168
// DoErrorIfStatusCode returns a SendDecorator that converts a successful send into an error
// when the response StatusCode is one of codes. The error is built with NewErrorWithResponse
// and names the request method, URL and response status. The response is returned alongside
// the error, and since the error is artificial its body may still require closing.
func DoErrorIfStatusCode(codes ...int) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			resp, err := s.Do(r)
			if err != nil || !ResponseHasStatusCode(resp, codes...) {
				return resp, err
			}
			return resp, NewErrorWithResponse("autorest", "DoErrorIfStatusCode", resp, "%v %v failed with %s",
				resp.Request.Method,
				resp.Request.URL,
				resp.Status)
		})
	}
}
186
// DoErrorUnlessStatusCode returns a SendDecorator that converts a successful send into an
// error when the response StatusCode is NOT one of codes. The error is built with
// NewErrorWithResponse and names the request method, URL and response status. The response is
// returned alongside the error, and since the error is artificial its body may still require
// closing.
func DoErrorUnlessStatusCode(codes ...int) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			resp, err := s.Do(r)
			if err != nil || ResponseHasStatusCode(resp, codes...) {
				return resp, err
			}
			return resp, NewErrorWithResponse("autorest", "DoErrorUnlessStatusCode", resp, "%v %v failed with %s",
				resp.Request.Method,
				resp.Request.URL,
				resp.Status)
		})
	}
}
204
// DoPollForStatusCodes returns a SendDecorator that polls while the http.Response carries one of
// the passed status codes. The follow-up request is built from the first response with
// NewPollingRequestWithContext and is resent, with the previous response's body discarded and
// closed first. Before each poll it waits for the duration reported by GetRetryAfter for the
// previous response, falling back to the passed delay. Polling stops when a response has a
// different status code or an error occurs. Cancelling the http.Request's context ends the wait
// (see AfterDelay).
//
// NOTE(review): the duration parameter is not referenced by this implementation, so polling is
// not bounded by it; callers needing a limit should use a context deadline.
func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...int) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			resp, err := s.Do(r)
			if err != nil || !ResponseHasStatusCode(resp, codes...) {
				return resp, err
			}
			pollReq, err := NewPollingRequestWithContext(r.Context(), resp)
			for err == nil && ResponseHasStatusCode(resp, codes...) {
				Respond(resp, ByDiscardingBody(), ByClosing())
				resp, err = SendWithSender(s, pollReq, AfterDelay(GetRetryAfter(resp, delay)))
			}
			return resp, err
		})
	}
}
232
// DoRetryForAttempts returns a SendDecorator that resends a request for up to the specified
// number of attempts while the wrapped Sender returns an error. It backs off exponentially
// between attempts using the supplied backoff time.Duration (which may be zero; see
// DelayForBackoff). The first send that returns no error is passed back immediately. No delay
// is performed after the final attempt, since nothing follows it. Retrying may be canceled by
// cancelling the http.Request's context. In that case a nil response and the context's error
// are returned.
func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
			rr := NewRetriableRequest(r)
			for attempt := 0; attempt < attempts; attempt++ {
				err = rr.Prepare()
				if err != nil {
					return resp, err
				}
				// Release the previous attempt's response (if any) before resending.
				DrainResponseBody(resp)
				resp, err = s.Do(rr.Request())
				if err == nil {
					return resp, err
				}
				if attempt == attempts-1 {
					// The last attempt failed; sleeping now would only delay the error.
					break
				}
				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
					return nil, r.Context().Err()
				}
			}
			return resp, err
		})
	}
}
259
260// Count429AsRetry indicates that a 429 response should be included as a retry attempt.
261var Count429AsRetry = true
262
263// Max429Delay is the maximum duration to wait between retries on a 429 if no Retry-After header was received.
264var Max429Delay time.Duration
265
// DoRetryForStatusCodes returns a SendDecorator that retries for specified statusCodes for up to the specified
// number of attempts, exponentially backing off between requests using the supplied backoff
// time.Duration (which may be zero). A Retry-After header on a retried response is honoured in preference
// to the exponential backoff. Transport errors are retried too, except token refresh errors.
// Retrying may be canceled by cancelling the context on the http.Request.
//
// NOTE: Whether an http.StatusTooManyRequests (429) response counts against the number of attempts is
// controlled by the package-level Count429AsRetry, which is read each time a request is sent. It defaults
// to true, so by default 429 responses *are* counted. Set it to false to keep retrying 429 responses
// without consuming attempts.
func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			// A zero cap leaves the exponential backoff uncapped (see DelayForBackoffWithCap).
			return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, 0, codes...)
		})
	}
}
277
// DoRetryForStatusCodesWithCap behaves like DoRetryForStatusCodes but bounds each backoff delay.
// It retries requests whose response status is one of codes for up to attempts retries, backing
// off exponentially from backoff (which may be zero). A cap greater than zero limits the longest
// delay between iterations. For a 429 response without a Retry-After header, Max429Delay is used
// as the cap instead. Whether 429 responses consume attempts follows Count429AsRetry. Retrying may
// be canceled by cancelling the context on the http.Request.
func DoRetryForStatusCodesWithCap(attempts int, backoff, cap time.Duration, codes ...int) SendDecorator {
	retry := func(s Sender, r *http.Request) (*http.Response, error) {
		return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, cap, codes...)
	}
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			return retry(s, r)
		})
	}
}
289
// doRetryForStatusCodesImpl sends r through s and resends it while the send fails (other than
// with a token refresh error) or the response status code is one of codes. It allows at most
// attempts retries, i.e. attempts+1 sends.
//
// Between sends it first honours a Retry-After header on the response (DelayWithRetryAfter). If
// none was honoured, it falls back to an exponential backoff from backoff, capped at cap when
// cap > 0. For a 429 response the cap for that one delay is Max429Delay instead. The override no
// longer leaks into the delays of later, non-429 responses.
//
// When count429 is false, 429 responses do not consume an attempt, so throttled requests keep
// being retried. The backoff exponent (delayCount) advances on every retry regardless, so repeated
// 429s still back off exponentially.
//
// Once the attempts are exhausted the last response/error is returned immediately, without a
// pointless final delay. Cancelling r's context aborts a pending delay and returns the last
// response together with the context's error.
func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempts int, backoff, cap time.Duration, codes ...int) (resp *http.Response, err error) {
	rr := NewRetriableRequest(r)
	// Increment to add the first call (attempts denotes number of retries).
	for attempt, delayCount := 0, 0; attempt < attempts+1; delayCount++ {
		err = rr.Prepare()
		if err != nil {
			return resp, err
		}
		DrainResponseBody(resp)
		resp, err = s.Do(rr.Request())
		// we want to retry if err is not nil (e.g. transient network failure).  note that for failed authentication
		// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
		if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
			return resp, err
		}
		is429 := resp != nil && resp.StatusCode == http.StatusTooManyRequests
		// when count429 == false don't count a 429 against the number
		// of attempts so that we continue to retry until it succeeds
		if count429 || !is429 {
			attempt++
		}
		if attempt >= attempts+1 {
			// No sends remain, so there is nothing to wait for.
			break
		}
		delayed := DelayWithRetryAfter(resp, r.Context().Done())
		// if this was a 429 use the 429-specific cap for this delay only.
		// applicable only in the absence of a retry-after header.
		delayCap := cap
		if is429 {
			delayCap = Max429Delay
		}
		if !delayed && !DelayForBackoffWithCap(backoff, delayCap, delayCount, r.Context().Done()) {
			return resp, r.Context().Err()
		}
		// delayCount (advanced by the loop post statement) is tracked separately from
		// attempt to ensure that 429 participates in exponential back-off.
	}
	return resp, err
}
325
// DelayWithRetryAfter waits for the duration specified in the response's "Retry-After" header.
// The value of Retry-After can be either a positive number of seconds or an HTTP-date. Dates are
// tried in RFC1123 layout first, then in the forms accepted by http.ParseTime (which adds the
// obsolete RFC 850 and ANSI C asctime formats).
// The function returns true after successfully waiting for the specified duration. If there is
// no usable Retry-After value (missing, unparsable, or not in the future), or the wait is
// cancelled by closing cancel, the return value is false.
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
	if resp == nil {
		return false
	}
	var dur time.Duration
	ra := resp.Header.Get("Retry-After")
	if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 {
		dur = time.Duration(retryAfter) * time.Second
	} else if t, err := time.Parse(time.RFC1123, ra); err == nil {
		dur = t.Sub(time.Now())
	} else if t, err := http.ParseTime(ra); err == nil {
		dur = t.Sub(time.Now())
	}
	if dur <= 0 {
		return false
	}
	// Use a Timer that is stopped on return so a cancelled wait does not leave a
	// pending timer behind for the rest of the Retry-After period.
	timer := time.NewTimer(dur)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-cancel:
		return false
	}
}
351
352// DoRetryForDuration returns a SendDecorator that retries the request until the total time is equal
353// to or greater than the specified duration, exponentially backing off between requests using the
354// supplied backoff time.Duration (which may be zero). Retrying may be canceled by closing the
355// optional channel on the http.Request.
356func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
357	return func(s Sender) Sender {
358		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
359			rr := NewRetriableRequest(r)
360			end := time.Now().Add(d)
361			for attempt := 0; time.Now().Before(end); attempt++ {
362				err = rr.Prepare()
363				if err != nil {
364					return resp, err
365				}
366				DrainResponseBody(resp)
367				resp, err = s.Do(rr.Request())
368				if err == nil {
369					return resp, err
370				}
371				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
372					return nil, r.Context().Err()
373				}
374			}
375			return resp, err
376		})
377	}
378}
379
// WithLogging returns a SendDecorator that writes one line to logger before the request is sent
// and one line after. The second line reports either the error from the wrapped Sender or the
// response status.
func WithLogging(logger *log.Logger) SendDecorator {
	return func(s Sender) Sender {
		return SenderFunc(func(r *http.Request) (*http.Response, error) {
			logger.Printf("Sending %s %s", r.Method, r.URL)
			resp, err := s.Do(r)
			switch {
			case err != nil:
				logger.Printf("%s %s received error '%v'", r.Method, r.URL, err)
			default:
				logger.Printf("%s %s received %s", r.Method, r.URL, resp.Status)
			}
			return resp, err
		})
	}
}
396
397// DelayForBackoff invokes time.After for the supplied backoff duration raised to the power of
398// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set
399// to zero for no delay. The delay may be canceled by closing the passed channel. If terminated early,
400// returns false.
401// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
402// count.
403func DelayForBackoff(backoff time.Duration, attempt int, cancel <-chan struct{}) bool {
404	return DelayForBackoffWithCap(backoff, 0, attempt, cancel)
405}
406
407// DelayForBackoffWithCap invokes time.After for the supplied backoff duration raised to the power of
408// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set
409// to zero for no delay. To cap the maximum possible delay specify a value greater than zero for cap.
410// The delay may be canceled by closing the passed channel. If terminated early, returns false.
411// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
412// count.
413func DelayForBackoffWithCap(backoff, cap time.Duration, attempt int, cancel <-chan struct{}) bool {
414	d := time.Duration(backoff.Seconds()*math.Pow(2, float64(attempt))) * time.Second
415	if cap > 0 && d > cap {
416		d = cap
417	}
418	select {
419	case <-time.After(d):
420		return true
421	case <-cancel:
422		return false
423	}
424}
425