1package autorest
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/tls"
20	"fmt"
21	"log"
22	"math"
23	"net/http"
24	"net/http/cookiejar"
25	"strconv"
26	"sync"
27	"time"
28
29	"github.com/Azure/go-autorest/tracing"
30)
31
32// there is one sender per TLS renegotiation type, i.e. count of tls.RenegotiationSupport enums
33const defaultSendersCount = 3
34
35type defaultSender struct {
36	sender Sender
37	init   *sync.Once
38}
39
40// each type of sender will be created on demand in sender()
41var defaultSenders [defaultSendersCount]defaultSender
42
43func init() {
44	for i := 0; i < defaultSendersCount; i++ {
45		defaultSenders[i].init = &sync.Once{}
46	}
47}
48
49// used as a key type in context.WithValue()
50type ctxSendDecorators struct{}
51
52// WithSendDecorators adds the specified SendDecorators to the provided context.
53// If no SendDecorators are provided the context is unchanged.
54func WithSendDecorators(ctx context.Context, sendDecorator []SendDecorator) context.Context {
55	if len(sendDecorator) == 0 {
56		return ctx
57	}
58	return context.WithValue(ctx, ctxSendDecorators{}, sendDecorator)
59}
60
61// GetSendDecorators returns the SendDecorators in the provided context or the provided default SendDecorators.
62func GetSendDecorators(ctx context.Context, defaultSendDecorators ...SendDecorator) []SendDecorator {
63	inCtx := ctx.Value(ctxSendDecorators{})
64	if sd, ok := inCtx.([]SendDecorator); ok {
65		return sd
66	}
67	return defaultSendDecorators
68}
69
// Sender is anything that can send an HTTP request and return its response.
//
// *http.Client satisfies this interface.
type Sender interface {
	Do(req *http.Request) (*http.Response, error)
}

// SenderFunc adapts a plain function to the Sender interface.
type SenderFunc func(*http.Request) (*http.Response, error)

// Do calls fn with r, satisfying the Sender interface.
func (fn SenderFunc) Do(r *http.Request) (*http.Response, error) {
	return fn(r)
}

// SendDecorator wraps a Sender in another Sender. A decorator may adjust the
// http.Request before passing it on, or pass it on first and then react to the
// resulting http.Response.
type SendDecorator func(Sender) Sender
89
90// CreateSender creates, decorates, and returns, as a Sender, the default http.Client.
91func CreateSender(decorators ...SendDecorator) Sender {
92	return DecorateSender(sender(tls.RenegotiateNever), decorators...)
93}
94
95// DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to
96// the Sender. Decorators are applied in the order received, but their affect upon the request
97// depends on whether they are a pre-decorator (change the http.Request and then pass it along) or a
98// post-decorator (pass the http.Request along and react to the results in http.Response).
99func DecorateSender(s Sender, decorators ...SendDecorator) Sender {
100	for _, decorate := range decorators {
101		s = decorate(s)
102	}
103	return s
104}
105
106// Send sends, by means of the default http.Client, the passed http.Request, returning the
107// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
108// it will apply the http.Client before invoking the Do method.
109//
110// Send is a convenience method and not recommended for production. Advanced users should use
111// SendWithSender, passing and sharing their own Sender (e.g., instance of http.Client).
112//
113// Send will not poll or retry requests.
114func Send(r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
115	return SendWithSender(sender(tls.RenegotiateNever), r, decorators...)
116}
117
118// SendWithSender sends the passed http.Request, through the provided Sender, returning the
119// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
120// it will apply the http.Client before invoking the Do method.
121//
122// SendWithSender will not poll or retry requests.
123func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
124	return DecorateSender(s, decorators...).Do(r)
125}
126
127func sender(renengotiation tls.RenegotiationSupport) Sender {
128	// note that we can't init defaultSenders in init() since it will
129	// execute before calling code has had a chance to enable tracing
130	defaultSenders[renengotiation].init.Do(func() {
131		// Use behaviour compatible with DefaultTransport, but require TLS minimum version.
132		defaultTransport := http.DefaultTransport.(*http.Transport)
133		transport := &http.Transport{
134			Proxy:                 defaultTransport.Proxy,
135			DialContext:           defaultTransport.DialContext,
136			MaxIdleConns:          defaultTransport.MaxIdleConns,
137			IdleConnTimeout:       defaultTransport.IdleConnTimeout,
138			TLSHandshakeTimeout:   defaultTransport.TLSHandshakeTimeout,
139			ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout,
140			TLSClientConfig: &tls.Config{
141				MinVersion:    tls.VersionTLS12,
142				Renegotiation: renengotiation,
143			},
144		}
145		var roundTripper http.RoundTripper = transport
146		if tracing.IsEnabled() {
147			roundTripper = tracing.NewTransport(transport)
148		}
149		j, _ := cookiejar.New(nil)
150		defaultSenders[renengotiation].sender = &http.Client{Jar: j, Transport: roundTripper}
151	})
152	return defaultSenders[renengotiation].sender
153}
154
155// AfterDelay returns a SendDecorator that delays for the passed time.Duration before
156// invoking the Sender. The delay may be terminated by closing the optional channel on the
157// http.Request. If canceled, no further Senders are invoked.
158func AfterDelay(d time.Duration) SendDecorator {
159	return func(s Sender) Sender {
160		return SenderFunc(func(r *http.Request) (*http.Response, error) {
161			if !DelayForBackoff(d, 0, r.Context().Done()) {
162				return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
163			}
164			return s.Do(r)
165		})
166	}
167}
168
169// AsIs returns a SendDecorator that invokes the passed Sender without modifying the http.Request.
170func AsIs() SendDecorator {
171	return func(s Sender) Sender {
172		return SenderFunc(func(r *http.Request) (*http.Response, error) {
173			return s.Do(r)
174		})
175	}
176}
177
178// DoCloseIfError returns a SendDecorator that first invokes the passed Sender after which
179// it closes the response if the passed Sender returns an error and the response body exists.
180func DoCloseIfError() SendDecorator {
181	return func(s Sender) Sender {
182		return SenderFunc(func(r *http.Request) (*http.Response, error) {
183			resp, err := s.Do(r)
184			if err != nil {
185				Respond(resp, ByDiscardingBody(), ByClosing())
186			}
187			return resp, err
188		})
189	}
190}
191
192// DoErrorIfStatusCode returns a SendDecorator that emits an error if the response StatusCode is
193// among the set passed. Since these are artificial errors, the response body may still require
194// closing.
195func DoErrorIfStatusCode(codes ...int) SendDecorator {
196	return func(s Sender) Sender {
197		return SenderFunc(func(r *http.Request) (*http.Response, error) {
198			resp, err := s.Do(r)
199			if err == nil && ResponseHasStatusCode(resp, codes...) {
200				err = NewErrorWithResponse("autorest", "DoErrorIfStatusCode", resp, "%v %v failed with %s",
201					resp.Request.Method,
202					resp.Request.URL,
203					resp.Status)
204			}
205			return resp, err
206		})
207	}
208}
209
210// DoErrorUnlessStatusCode returns a SendDecorator that emits an error unless the response
211// StatusCode is among the set passed. Since these are artificial errors, the response body
212// may still require closing.
213func DoErrorUnlessStatusCode(codes ...int) SendDecorator {
214	return func(s Sender) Sender {
215		return SenderFunc(func(r *http.Request) (*http.Response, error) {
216			resp, err := s.Do(r)
217			if err == nil && !ResponseHasStatusCode(resp, codes...) {
218				err = NewErrorWithResponse("autorest", "DoErrorUnlessStatusCode", resp, "%v %v failed with %s",
219					resp.Request.Method,
220					resp.Request.URL,
221					resp.Status)
222			}
223			return resp, err
224		})
225	}
226}
227
228// DoPollForStatusCodes returns a SendDecorator that polls if the http.Response contains one of the
229// passed status codes. It expects the http.Response to contain a Location header providing the
230// URL at which to poll (using GET) and will poll until the time passed is equal to or greater than
231// the supplied duration. It will delay between requests for the duration specified in the
232// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by
233// closing the optional channel on the http.Request.
234func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...int) SendDecorator {
235	return func(s Sender) Sender {
236		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
237			resp, err = s.Do(r)
238
239			if err == nil && ResponseHasStatusCode(resp, codes...) {
240				r, err = NewPollingRequestWithContext(r.Context(), resp)
241
242				for err == nil && ResponseHasStatusCode(resp, codes...) {
243					Respond(resp,
244						ByDiscardingBody(),
245						ByClosing())
246					resp, err = SendWithSender(s, r,
247						AfterDelay(GetRetryAfter(resp, delay)))
248				}
249			}
250
251			return resp, err
252		})
253	}
254}
255
256// DoRetryForAttempts returns a SendDecorator that retries a failed request for up to the specified
257// number of attempts, exponentially backing off between requests using the supplied backoff
258// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on
259// the http.Request.
260func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
261	return func(s Sender) Sender {
262		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
263			rr := NewRetriableRequest(r)
264			for attempt := 0; attempt < attempts; attempt++ {
265				err = rr.Prepare()
266				if err != nil {
267					return resp, err
268				}
269				DrainResponseBody(resp)
270				resp, err = s.Do(rr.Request())
271				if err == nil {
272					return resp, err
273				}
274				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
275					return nil, r.Context().Err()
276				}
277			}
278			return resp, err
279		})
280	}
281}
282
var (
	// Count429AsRetry indicates that a 429 response should be included as a retry attempt.
	// When true (the default), a 429 consumes one of the attempts of DoRetryForStatusCodes
	// and DoRetryForStatusCodesWithCap; when false, 429 responses are retried without
	// consuming an attempt. The value is read each time a request is sent.
	Count429AsRetry = true

	// Max429Delay is the maximum duration to wait between retries on a 429 if no Retry-After
	// header was received. The zero value (the default) means the backoff is not capped.
	Max429Delay time.Duration
)
288
289// DoRetryForStatusCodes returns a SendDecorator that retries for specified statusCodes for up to the specified
290// number of attempts, exponentially backing off between requests using the supplied backoff
291// time.Duration (which may be zero). Retrying may be canceled by cancelling the context on the http.Request.
292// NOTE: Code http.StatusTooManyRequests (429) will *not* be counted against the number of attempts.
293func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) SendDecorator {
294	return func(s Sender) Sender {
295		return SenderFunc(func(r *http.Request) (*http.Response, error) {
296			return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, 0, codes...)
297		})
298	}
299}
300
301// DoRetryForStatusCodesWithCap returns a SendDecorator that retries for specified statusCodes for up to the
302// specified number of attempts, exponentially backing off between requests using the supplied backoff
303// time.Duration (which may be zero). To cap the maximum possible delay between iterations specify a value greater
304// than zero for cap. Retrying may be canceled by cancelling the context on the http.Request.
305func DoRetryForStatusCodesWithCap(attempts int, backoff, cap time.Duration, codes ...int) SendDecorator {
306	return func(s Sender) Sender {
307		return SenderFunc(func(r *http.Request) (*http.Response, error) {
308			return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, cap, codes...)
309		})
310	}
311}
312
313func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempts int, backoff, cap time.Duration, codes ...int) (resp *http.Response, err error) {
314	rr := NewRetriableRequest(r)
315	// Increment to add the first call (attempts denotes number of retries)
316	for attempt, delayCount := 0, 0; attempt < attempts+1; {
317		err = rr.Prepare()
318		if err != nil {
319			return
320		}
321		DrainResponseBody(resp)
322		resp, err = s.Do(rr.Request())
323		// we want to retry if err is not nil (e.g. transient network failure).  note that for failed authentication
324		// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
325		if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
326			return resp, err
327		}
328		delayed := DelayWithRetryAfter(resp, r.Context().Done())
329		// if this was a 429 set the delay cap as specified.
330		// applicable only in the absence of a retry-after header.
331		if resp != nil && resp.StatusCode == http.StatusTooManyRequests {
332			cap = Max429Delay
333		}
334		if !delayed && !DelayForBackoffWithCap(backoff, cap, delayCount, r.Context().Done()) {
335			return resp, r.Context().Err()
336		}
337		// when count429 == false don't count a 429 against the number
338		// of attempts so that we continue to retry until it succeeds
339		if count429 || (resp == nil || resp.StatusCode != http.StatusTooManyRequests) {
340			attempt++
341		}
342		// delay count is tracked separately from attempts to
343		// ensure that 429 participates in exponential back-off
344		delayCount++
345	}
346	return resp, err
347}
348
// DelayWithRetryAfter waits for the duration specified in the response's "Retry-After" header.
// The value of Retry-After can be either a positive number of seconds or an HTTP-date. Dates
// are parsed as RFC1123 first and then with http.ParseTime, which also accepts the RFC 850 and
// ANSI C asctime forms that RFC 7231 requires recipients to understand.
// The function returns true after successfully waiting for the specified duration. It returns
// false if resp is nil, if the header is missing or unparsable, if the date is not in the
// future, or if cancel is closed before the wait ends.
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
	if resp == nil {
		return false
	}
	var dur time.Duration
	ra := resp.Header.Get("Retry-After")
	if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 {
		dur = time.Duration(retryAfter) * time.Second
	} else if t, err := time.Parse(time.RFC1123, ra); err == nil {
		dur = time.Until(t)
	} else if t, err := http.ParseTime(ra); err == nil {
		dur = time.Until(t)
	}
	if dur <= 0 {
		return false
	}
	// An explicit timer is stopped as soon as we return, so a cancelled wait does not
	// keep it alive for the rest of dur (as time.After would before Go 1.23).
	timer := time.NewTimer(dur)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-cancel:
		return false
	}
}
374
375// DoRetryForDuration returns a SendDecorator that retries the request until the total time is equal
376// to or greater than the specified duration, exponentially backing off between requests using the
377// supplied backoff time.Duration (which may be zero). Retrying may be canceled by closing the
378// optional channel on the http.Request.
379func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
380	return func(s Sender) Sender {
381		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
382			rr := NewRetriableRequest(r)
383			end := time.Now().Add(d)
384			for attempt := 0; time.Now().Before(end); attempt++ {
385				err = rr.Prepare()
386				if err != nil {
387					return resp, err
388				}
389				DrainResponseBody(resp)
390				resp, err = s.Do(rr.Request())
391				if err == nil {
392					return resp, err
393				}
394				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
395					return nil, r.Context().Err()
396				}
397			}
398			return resp, err
399		})
400	}
401}
402
403// WithLogging returns a SendDecorator that implements simple before and after logging of the
404// request.
405func WithLogging(logger *log.Logger) SendDecorator {
406	return func(s Sender) Sender {
407		return SenderFunc(func(r *http.Request) (*http.Response, error) {
408			logger.Printf("Sending %s %s", r.Method, r.URL)
409			resp, err := s.Do(r)
410			if err != nil {
411				logger.Printf("%s %s received error '%v'", r.Method, r.URL, err)
412			} else {
413				logger.Printf("%s %s received %s", r.Method, r.URL, resp.Status)
414			}
415			return resp, err
416		})
417	}
418}
419
420// DelayForBackoff invokes time.After for the supplied backoff duration raised to the power of
421// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set
422// to zero for no delay. The delay may be canceled by closing the passed channel. If terminated early,
423// returns false.
424// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
425// count.
426func DelayForBackoff(backoff time.Duration, attempt int, cancel <-chan struct{}) bool {
427	return DelayForBackoffWithCap(backoff, 0, attempt, cancel)
428}
429
// DelayForBackoffWithCap waits for an exponential backoff delay of backoff * 2^attempt and
// returns true once it has elapsed. The delay is computed in seconds and truncated to a
// whole number of seconds, so sub-second results round down (a zero backoff means no delay).
// To cap the maximum possible delay specify a value greater than zero for cap. Delays too
// large to represent as a time.Duration saturate at the maximum Duration, and the cap still
// applies to them, instead of overflowing. The delay may be canceled by closing the passed
// channel. If terminated early, returns false.
// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
// count.
func DelayForBackoffWithCap(backoff, cap time.Duration, attempt int, cancel <-chan struct{}) bool {
	// Largest whole number of seconds that still fits in a time.Duration.
	const maxSeconds = float64(math.MaxInt64 / int64(time.Second))

	seconds := math.Trunc(backoff.Seconds() * math.Pow(2, float64(attempt)))
	var d time.Duration
	switch {
	case math.IsNaN(seconds) || seconds < -maxSeconds:
		// 0*Inf (zero backoff with a huge attempt) or an absurdly negative backoff: no delay.
		d = 0
	case seconds > maxSeconds:
		// Converting directly would overflow int64 and could wrap to a negative delay that
		// fires immediately and bypasses cap, turning long backoffs into a tight retry loop.
		d = time.Duration(math.MaxInt64)
	default:
		d = time.Duration(seconds) * time.Second
	}
	if cap > 0 && d > cap {
		d = cap
	}
	// Stop the timer on return so a cancelled (possibly very long) wait does not keep it alive.
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-cancel:
		return false
	}
}
448