1package autorest
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"context"
19	"crypto/tls"
20	"fmt"
21	"log"
22	"math"
23	"net/http"
24	"net/http/cookiejar"
25	"strconv"
26	"sync"
27	"time"
28
29	"github.com/Azure/go-autorest/logger"
30	"github.com/Azure/go-autorest/tracing"
31)
32
33// there is one sender per TLS renegotiation type, i.e. count of tls.RenegotiationSupport enums
34const defaultSendersCount = 3
35
36type defaultSender struct {
37	sender Sender
38	init   *sync.Once
39}
40
41// each type of sender will be created on demand in sender()
42var defaultSenders [defaultSendersCount]defaultSender
43
44func init() {
45	for i := 0; i < defaultSendersCount; i++ {
46		defaultSenders[i].init = &sync.Once{}
47	}
48}
49
50// used as a key type in context.WithValue()
51type ctxSendDecorators struct{}
52
53// WithSendDecorators adds the specified SendDecorators to the provided context.
54// If no SendDecorators are provided the context is unchanged.
55func WithSendDecorators(ctx context.Context, sendDecorator []SendDecorator) context.Context {
56	if len(sendDecorator) == 0 {
57		return ctx
58	}
59	return context.WithValue(ctx, ctxSendDecorators{}, sendDecorator)
60}
61
62// GetSendDecorators returns the SendDecorators in the provided context or the provided default SendDecorators.
63func GetSendDecorators(ctx context.Context, defaultSendDecorators ...SendDecorator) []SendDecorator {
64	inCtx := ctx.Value(ctxSendDecorators{})
65	if sd, ok := inCtx.([]SendDecorator); ok {
66		return sd
67	}
68	return defaultSendDecorators
69}
70
// Sender is implemented by anything that can send an *http.Request and report
// the resulting *http.Response or an error.
//
// *http.Client has a matching Do method, so a standard client can be used
// wherever a Sender is expected.
type Sender interface {
	Do(req *http.Request) (*http.Response, error)
}
77
// SenderFunc adapts an ordinary function with the signature of Sender.Do so it
// can be used as a Sender.
type SenderFunc func(*http.Request) (*http.Response, error)

// Do satisfies Sender by calling sf with r and returning its results unchanged.
func (sf SenderFunc) Do(r *http.Request) (*http.Response, error) {
	resp, err := sf(r)
	return resp, err
}
85
86// SendDecorator takes and possibly decorates, by wrapping, a Sender. Decorators may affect the
87// http.Request and pass it along or, first, pass the http.Request along then react to the
88// http.Response result.
89type SendDecorator func(Sender) Sender
90
91// CreateSender creates, decorates, and returns, as a Sender, the default http.Client.
92func CreateSender(decorators ...SendDecorator) Sender {
93	return DecorateSender(sender(tls.RenegotiateNever), decorators...)
94}
95
96// DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to
97// the Sender. Decorators are applied in the order received, but their affect upon the request
98// depends on whether they are a pre-decorator (change the http.Request and then pass it along) or a
99// post-decorator (pass the http.Request along and react to the results in http.Response).
100func DecorateSender(s Sender, decorators ...SendDecorator) Sender {
101	for _, decorate := range decorators {
102		s = decorate(s)
103	}
104	return s
105}
106
107// Send sends, by means of the default http.Client, the passed http.Request, returning the
108// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
109// it will apply the http.Client before invoking the Do method.
110//
111// Send is a convenience method and not recommended for production. Advanced users should use
112// SendWithSender, passing and sharing their own Sender (e.g., instance of http.Client).
113//
114// Send will not poll or retry requests.
115func Send(r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
116	return SendWithSender(sender(tls.RenegotiateNever), r, decorators...)
117}
118
119// SendWithSender sends the passed http.Request, through the provided Sender, returning the
120// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
121// it will apply the http.Client before invoking the Do method.
122//
123// SendWithSender will not poll or retry requests.
124func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
125	return DecorateSender(s, decorators...).Do(r)
126}
127
128func sender(renengotiation tls.RenegotiationSupport) Sender {
129	// note that we can't init defaultSenders in init() since it will
130	// execute before calling code has had a chance to enable tracing
131	defaultSenders[renengotiation].init.Do(func() {
132		// Use behaviour compatible with DefaultTransport, but require TLS minimum version.
133		defaultTransport := http.DefaultTransport.(*http.Transport)
134		transport := &http.Transport{
135			Proxy:                 defaultTransport.Proxy,
136			DialContext:           defaultTransport.DialContext,
137			MaxIdleConns:          defaultTransport.MaxIdleConns,
138			IdleConnTimeout:       defaultTransport.IdleConnTimeout,
139			TLSHandshakeTimeout:   defaultTransport.TLSHandshakeTimeout,
140			ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout,
141			TLSClientConfig: &tls.Config{
142				MinVersion:    tls.VersionTLS12,
143				Renegotiation: renengotiation,
144			},
145		}
146		var roundTripper http.RoundTripper = transport
147		if tracing.IsEnabled() {
148			roundTripper = tracing.NewTransport(transport)
149		}
150		j, _ := cookiejar.New(nil)
151		defaultSenders[renengotiation].sender = &http.Client{Jar: j, Transport: roundTripper}
152	})
153	return defaultSenders[renengotiation].sender
154}
155
156// AfterDelay returns a SendDecorator that delays for the passed time.Duration before
157// invoking the Sender. The delay may be terminated by closing the optional channel on the
158// http.Request. If canceled, no further Senders are invoked.
159func AfterDelay(d time.Duration) SendDecorator {
160	return func(s Sender) Sender {
161		return SenderFunc(func(r *http.Request) (*http.Response, error) {
162			if !DelayForBackoff(d, 0, r.Context().Done()) {
163				return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
164			}
165			return s.Do(r)
166		})
167	}
168}
169
170// AsIs returns a SendDecorator that invokes the passed Sender without modifying the http.Request.
171func AsIs() SendDecorator {
172	return func(s Sender) Sender {
173		return SenderFunc(func(r *http.Request) (*http.Response, error) {
174			return s.Do(r)
175		})
176	}
177}
178
179// DoCloseIfError returns a SendDecorator that first invokes the passed Sender after which
180// it closes the response if the passed Sender returns an error and the response body exists.
181func DoCloseIfError() SendDecorator {
182	return func(s Sender) Sender {
183		return SenderFunc(func(r *http.Request) (*http.Response, error) {
184			resp, err := s.Do(r)
185			if err != nil {
186				Respond(resp, ByDiscardingBody(), ByClosing())
187			}
188			return resp, err
189		})
190	}
191}
192
193// DoErrorIfStatusCode returns a SendDecorator that emits an error if the response StatusCode is
194// among the set passed. Since these are artificial errors, the response body may still require
195// closing.
196func DoErrorIfStatusCode(codes ...int) SendDecorator {
197	return func(s Sender) Sender {
198		return SenderFunc(func(r *http.Request) (*http.Response, error) {
199			resp, err := s.Do(r)
200			if err == nil && ResponseHasStatusCode(resp, codes...) {
201				err = NewErrorWithResponse("autorest", "DoErrorIfStatusCode", resp, "%v %v failed with %s",
202					resp.Request.Method,
203					resp.Request.URL,
204					resp.Status)
205			}
206			return resp, err
207		})
208	}
209}
210
211// DoErrorUnlessStatusCode returns a SendDecorator that emits an error unless the response
212// StatusCode is among the set passed. Since these are artificial errors, the response body
213// may still require closing.
214func DoErrorUnlessStatusCode(codes ...int) SendDecorator {
215	return func(s Sender) Sender {
216		return SenderFunc(func(r *http.Request) (*http.Response, error) {
217			resp, err := s.Do(r)
218			if err == nil && !ResponseHasStatusCode(resp, codes...) {
219				err = NewErrorWithResponse("autorest", "DoErrorUnlessStatusCode", resp, "%v %v failed with %s",
220					resp.Request.Method,
221					resp.Request.URL,
222					resp.Status)
223			}
224			return resp, err
225		})
226	}
227}
228
229// DoPollForStatusCodes returns a SendDecorator that polls if the http.Response contains one of the
230// passed status codes. It expects the http.Response to contain a Location header providing the
231// URL at which to poll (using GET) and will poll until the time passed is equal to or greater than
232// the supplied duration. It will delay between requests for the duration specified in the
233// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by
234// closing the optional channel on the http.Request.
235func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...int) SendDecorator {
236	return func(s Sender) Sender {
237		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
238			resp, err = s.Do(r)
239
240			if err == nil && ResponseHasStatusCode(resp, codes...) {
241				r, err = NewPollingRequestWithContext(r.Context(), resp)
242
243				for err == nil && ResponseHasStatusCode(resp, codes...) {
244					Respond(resp,
245						ByDiscardingBody(),
246						ByClosing())
247					resp, err = SendWithSender(s, r,
248						AfterDelay(GetRetryAfter(resp, delay)))
249				}
250			}
251
252			return resp, err
253		})
254	}
255}
256
257// DoRetryForAttempts returns a SendDecorator that retries a failed request for up to the specified
258// number of attempts, exponentially backing off between requests using the supplied backoff
259// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on
260// the http.Request.
261func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
262	return func(s Sender) Sender {
263		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
264			rr := NewRetriableRequest(r)
265			for attempt := 0; attempt < attempts; attempt++ {
266				err = rr.Prepare()
267				if err != nil {
268					return resp, err
269				}
270				DrainResponseBody(resp)
271				resp, err = s.Do(rr.Request())
272				if err == nil {
273					return resp, err
274				}
275				logger.Instance.Writef(logger.LogError, "DoRetryForAttempts: received error for attempt %d: %v\n", attempt+1, err)
276				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
277					return nil, r.Context().Err()
278				}
279			}
280			return resp, err
281		})
282	}
283}
284
var (
	// Count429AsRetry controls whether an http.StatusTooManyRequests (429)
	// response consumes one of the attempts allowed by the status-code retry
	// decorators. It defaults to true; when false, 429 responses are retried
	// without counting against the attempt limit.
	Count429AsRetry = true

	// Max429Delay is the largest backoff delay used between retries after a 429
	// response when no Retry-After header was received. The zero value applies
	// no cap.
	Max429Delay time.Duration
)
290
291// DoRetryForStatusCodes returns a SendDecorator that retries for specified statusCodes for up to the specified
292// number of attempts, exponentially backing off between requests using the supplied backoff
293// time.Duration (which may be zero). Retrying may be canceled by cancelling the context on the http.Request.
294// NOTE: Code http.StatusTooManyRequests (429) will *not* be counted against the number of attempts.
295func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) SendDecorator {
296	return func(s Sender) Sender {
297		return SenderFunc(func(r *http.Request) (*http.Response, error) {
298			return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, 0, codes...)
299		})
300	}
301}
302
303// DoRetryForStatusCodesWithCap returns a SendDecorator that retries for specified statusCodes for up to the
304// specified number of attempts, exponentially backing off between requests using the supplied backoff
305// time.Duration (which may be zero). To cap the maximum possible delay between iterations specify a value greater
306// than zero for cap. Retrying may be canceled by cancelling the context on the http.Request.
307func DoRetryForStatusCodesWithCap(attempts int, backoff, cap time.Duration, codes ...int) SendDecorator {
308	return func(s Sender) Sender {
309		return SenderFunc(func(r *http.Request) (*http.Response, error) {
310			return doRetryForStatusCodesImpl(s, r, Count429AsRetry, attempts, backoff, cap, codes...)
311		})
312	}
313}
314
315func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempts int, backoff, cap time.Duration, codes ...int) (resp *http.Response, err error) {
316	rr := NewRetriableRequest(r)
317	// Increment to add the first call (attempts denotes number of retries)
318	for attempt, delayCount := 0, 0; attempt < attempts+1; {
319		err = rr.Prepare()
320		if err != nil {
321			return
322		}
323		DrainResponseBody(resp)
324		resp, err = s.Do(rr.Request())
325		// we want to retry if err is not nil (e.g. transient network failure).  note that for failed authentication
326		// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
327		if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
328			return resp, err
329		}
330		if err != nil {
331			logger.Instance.Writef(logger.LogError, "DoRetryForStatusCodes: received error for attempt %d: %v\n", attempt+1, err)
332		}
333		delayed := DelayWithRetryAfter(resp, r.Context().Done())
334		// if this was a 429 set the delay cap as specified.
335		// applicable only in the absence of a retry-after header.
336		if resp != nil && resp.StatusCode == http.StatusTooManyRequests {
337			cap = Max429Delay
338		}
339		if !delayed && !DelayForBackoffWithCap(backoff, cap, delayCount, r.Context().Done()) {
340			return resp, r.Context().Err()
341		}
342		// when count429 == false don't count a 429 against the number
343		// of attempts so that we continue to retry until it succeeds
344		if count429 || (resp == nil || resp.StatusCode != http.StatusTooManyRequests) {
345			attempt++
346		}
347		// delay count is tracked separately from attempts to
348		// ensure that 429 participates in exponential back-off
349		delayCount++
350	}
351	return resp, err
352}
353
// DelayWithRetryAfter waits for the duration specified by the response's
// "Retry-After" header. The header may hold a positive number of seconds or an
// HTTP-date. Dates are accepted in RFC1123 form and in any of the formats
// understood by http.ParseTime (IMF-fixdate, RFC 850 and ANSI C asctime), as
// RFC 7231 requires of recipients.
//
// It returns true after waiting the full duration. It returns false without
// waiting when resp is nil, the header is missing or unparseable, or the
// duration is not positive. It also returns false if cancel is closed before
// the wait completes.
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
	if resp == nil {
		return false
	}
	var dur time.Duration
	ra := resp.Header.Get("Retry-After")
	if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 {
		dur = time.Duration(retryAfter) * time.Second
	} else if t, err := time.Parse(time.RFC1123, ra); err == nil {
		dur = time.Until(t)
	} else if t, err := http.ParseTime(ra); err == nil {
		dur = time.Until(t)
	}
	if dur <= 0 {
		return false
	}
	// A stoppable timer releases its resources promptly when canceled.
	timer := time.NewTimer(dur)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-cancel:
		return false
	}
}
379
380// DoRetryForDuration returns a SendDecorator that retries the request until the total time is equal
381// to or greater than the specified duration, exponentially backing off between requests using the
382// supplied backoff time.Duration (which may be zero). Retrying may be canceled by closing the
383// optional channel on the http.Request.
384func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
385	return func(s Sender) Sender {
386		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
387			rr := NewRetriableRequest(r)
388			end := time.Now().Add(d)
389			for attempt := 0; time.Now().Before(end); attempt++ {
390				err = rr.Prepare()
391				if err != nil {
392					return resp, err
393				}
394				DrainResponseBody(resp)
395				resp, err = s.Do(rr.Request())
396				if err == nil {
397					return resp, err
398				}
399				logger.Instance.Writef(logger.LogError, "DoRetryForDuration: received error for attempt %d: %v\n", attempt+1, err)
400				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
401					return nil, r.Context().Err()
402				}
403			}
404			return resp, err
405		})
406	}
407}
408
409// WithLogging returns a SendDecorator that implements simple before and after logging of the
410// request.
411func WithLogging(logger *log.Logger) SendDecorator {
412	return func(s Sender) Sender {
413		return SenderFunc(func(r *http.Request) (*http.Response, error) {
414			logger.Printf("Sending %s %s", r.Method, r.URL)
415			resp, err := s.Do(r)
416			if err != nil {
417				logger.Printf("%s %s received error '%v'", r.Method, r.URL, err)
418			} else {
419				logger.Printf("%s %s received %s", r.Method, r.URL, resp.Status)
420			}
421			return resp, err
422		})
423	}
424}
425
426// DelayForBackoff invokes time.After for the supplied backoff duration raised to the power of
427// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set
428// to zero for no delay. The delay may be canceled by closing the passed channel. If terminated early,
429// returns false.
430// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
431// count.
432func DelayForBackoff(backoff time.Duration, attempt int, cancel <-chan struct{}) bool {
433	return DelayForBackoffWithCap(backoff, 0, attempt, cancel)
434}
435
436// DelayForBackoffWithCap invokes time.After for the supplied backoff duration raised to the power of
437// passed attempt (i.e., an exponential backoff delay). Backoff duration is in seconds and can set
438// to zero for no delay. To cap the maximum possible delay specify a value greater than zero for cap.
439// The delay may be canceled by closing the passed channel. If terminated early, returns false.
440// Note: Passing attempt 1 will result in doubling "backoff" duration. Treat this as a zero-based attempt
441// count.
442func DelayForBackoffWithCap(backoff, cap time.Duration, attempt int, cancel <-chan struct{}) bool {
443	d := time.Duration(backoff.Seconds()*math.Pow(2, float64(attempt))) * time.Second
444	if cap > 0 && d > cap {
445		d = cap
446	}
447	logger.Instance.Writef(logger.LogInfo, "DelayForBackoffWithCap: sleeping for %s\n", d)
448	select {
449	case <-time.After(d):
450		return true
451	case <-cancel:
452		return false
453	}
454}
455