1package autorest
2
3// Copyright 2017 Microsoft Corporation
4//
5//  Licensed under the Apache License, Version 2.0 (the "License");
6//  you may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at
8//
9//      http://www.apache.org/licenses/LICENSE-2.0
10//
11//  Unless required by applicable law or agreed to in writing, software
12//  distributed under the License is distributed on an "AS IS" BASIS,
13//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14//  See the License for the specific language governing permissions and
15//  limitations under the License.
16
17import (
18	"fmt"
19	"log"
20	"math"
21	"net/http"
22	"strconv"
23	"time"
24
25	"github.com/Azure/go-autorest/tracing"
26)
27
// Sender is the interface that wraps the Do method to send HTTP requests.
//
// The standard http.Client conforms to this interface, so an *http.Client can be used
// anywhere a Sender is expected.
type Sender interface {
	// Do sends req and returns the resulting response or an error.
	Do(req *http.Request) (*http.Response, error)
}

// SenderFunc adapts an ordinary function with the Do signature to the Sender interface.
type SenderFunc func(req *http.Request) (*http.Response, error)

// Do implements the Sender interface on SenderFunc by calling the function itself with r.
func (f SenderFunc) Do(r *http.Request) (*http.Response, error) {
	return f(r)
}

// SendDecorator takes a Sender and returns a Sender that may wrap it. A decorator can act on the
// http.Request before passing it along (pre-decorator), or pass the request along first and then
// react to the resulting http.Response (post-decorator).
type SendDecorator func(next Sender) Sender
47
48// CreateSender creates, decorates, and returns, as a Sender, the default http.Client.
49func CreateSender(decorators ...SendDecorator) Sender {
50	return DecorateSender(&http.Client{}, decorators...)
51}
52
53// DecorateSender accepts a Sender and a, possibly empty, set of SendDecorators, which is applies to
54// the Sender. Decorators are applied in the order received, but their affect upon the request
55// depends on whether they are a pre-decorator (change the http.Request and then pass it along) or a
56// post-decorator (pass the http.Request along and react to the results in http.Response).
57func DecorateSender(s Sender, decorators ...SendDecorator) Sender {
58	for _, decorate := range decorators {
59		s = decorate(s)
60	}
61	return s
62}
63
64// Send sends, by means of the default http.Client, the passed http.Request, returning the
65// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
66// it will apply the http.Client before invoking the Do method.
67//
68// Send is a convenience method and not recommended for production. Advanced users should use
69// SendWithSender, passing and sharing their own Sender (e.g., instance of http.Client).
70//
71// Send will not poll or retry requests.
72func Send(r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
73	return SendWithSender(&http.Client{Transport: tracing.Transport}, r, decorators...)
74}
75
76// SendWithSender sends the passed http.Request, through the provided Sender, returning the
77// http.Response and possible error. It also accepts a, possibly empty, set of SendDecorators which
78// it will apply the http.Client before invoking the Do method.
79//
80// SendWithSender will not poll or retry requests.
81func SendWithSender(s Sender, r *http.Request, decorators ...SendDecorator) (*http.Response, error) {
82	return DecorateSender(s, decorators...).Do(r)
83}
84
85// AfterDelay returns a SendDecorator that delays for the passed time.Duration before
86// invoking the Sender. The delay may be terminated by closing the optional channel on the
87// http.Request. If canceled, no further Senders are invoked.
88func AfterDelay(d time.Duration) SendDecorator {
89	return func(s Sender) Sender {
90		return SenderFunc(func(r *http.Request) (*http.Response, error) {
91			if !DelayForBackoff(d, 0, r.Context().Done()) {
92				return nil, fmt.Errorf("autorest: AfterDelay canceled before full delay")
93			}
94			return s.Do(r)
95		})
96	}
97}
98
99// AsIs returns a SendDecorator that invokes the passed Sender without modifying the http.Request.
100func AsIs() SendDecorator {
101	return func(s Sender) Sender {
102		return SenderFunc(func(r *http.Request) (*http.Response, error) {
103			return s.Do(r)
104		})
105	}
106}
107
108// DoCloseIfError returns a SendDecorator that first invokes the passed Sender after which
109// it closes the response if the passed Sender returns an error and the response body exists.
110func DoCloseIfError() SendDecorator {
111	return func(s Sender) Sender {
112		return SenderFunc(func(r *http.Request) (*http.Response, error) {
113			resp, err := s.Do(r)
114			if err != nil {
115				Respond(resp, ByDiscardingBody(), ByClosing())
116			}
117			return resp, err
118		})
119	}
120}
121
122// DoErrorIfStatusCode returns a SendDecorator that emits an error if the response StatusCode is
123// among the set passed. Since these are artificial errors, the response body may still require
124// closing.
125func DoErrorIfStatusCode(codes ...int) SendDecorator {
126	return func(s Sender) Sender {
127		return SenderFunc(func(r *http.Request) (*http.Response, error) {
128			resp, err := s.Do(r)
129			if err == nil && ResponseHasStatusCode(resp, codes...) {
130				err = NewErrorWithResponse("autorest", "DoErrorIfStatusCode", resp, "%v %v failed with %s",
131					resp.Request.Method,
132					resp.Request.URL,
133					resp.Status)
134			}
135			return resp, err
136		})
137	}
138}
139
140// DoErrorUnlessStatusCode returns a SendDecorator that emits an error unless the response
141// StatusCode is among the set passed. Since these are artificial errors, the response body
142// may still require closing.
143func DoErrorUnlessStatusCode(codes ...int) SendDecorator {
144	return func(s Sender) Sender {
145		return SenderFunc(func(r *http.Request) (*http.Response, error) {
146			resp, err := s.Do(r)
147			if err == nil && !ResponseHasStatusCode(resp, codes...) {
148				err = NewErrorWithResponse("autorest", "DoErrorUnlessStatusCode", resp, "%v %v failed with %s",
149					resp.Request.Method,
150					resp.Request.URL,
151					resp.Status)
152			}
153			return resp, err
154		})
155	}
156}
157
158// DoPollForStatusCodes returns a SendDecorator that polls if the http.Response contains one of the
159// passed status codes. It expects the http.Response to contain a Location header providing the
160// URL at which to poll (using GET) and will poll until the time passed is equal to or greater than
161// the supplied duration. It will delay between requests for the duration specified in the
162// RetryAfter header or, if the header is absent, the passed delay. Polling may be canceled by
163// closing the optional channel on the http.Request.
164func DoPollForStatusCodes(duration time.Duration, delay time.Duration, codes ...int) SendDecorator {
165	return func(s Sender) Sender {
166		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
167			resp, err = s.Do(r)
168
169			if err == nil && ResponseHasStatusCode(resp, codes...) {
170				r, err = NewPollingRequestWithContext(r.Context(), resp)
171
172				for err == nil && ResponseHasStatusCode(resp, codes...) {
173					Respond(resp,
174						ByDiscardingBody(),
175						ByClosing())
176					resp, err = SendWithSender(s, r,
177						AfterDelay(GetRetryAfter(resp, delay)))
178				}
179			}
180
181			return resp, err
182		})
183	}
184}
185
186// DoRetryForAttempts returns a SendDecorator that retries a failed request for up to the specified
187// number of attempts, exponentially backing off between requests using the supplied backoff
188// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on
189// the http.Request.
190func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
191	return func(s Sender) Sender {
192		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
193			rr := NewRetriableRequest(r)
194			for attempt := 0; attempt < attempts; attempt++ {
195				err = rr.Prepare()
196				if err != nil {
197					return resp, err
198				}
199				resp, err = s.Do(rr.Request())
200				if err == nil {
201					return resp, err
202				}
203				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
204					return nil, r.Context().Err()
205				}
206			}
207			return resp, err
208		})
209	}
210}
211
212// DoRetryForStatusCodes returns a SendDecorator that retries for specified statusCodes for up to the specified
213// number of attempts, exponentially backing off between requests using the supplied backoff
214// time.Duration (which may be zero). Retrying may be canceled by closing the optional channel on
215// the http.Request.
216func DoRetryForStatusCodes(attempts int, backoff time.Duration, codes ...int) SendDecorator {
217	return func(s Sender) Sender {
218		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
219			rr := NewRetriableRequest(r)
220			// Increment to add the first call (attempts denotes number of retries)
221			for attempt := 0; attempt < attempts+1; {
222				err = rr.Prepare()
223				if err != nil {
224					return resp, err
225				}
226				resp, err = s.Do(rr.Request())
227				// if the error isn't temporary don't bother retrying
228				if err != nil && !IsTemporaryNetworkError(err) {
229					return nil, err
230				}
231				// we want to retry if err is not nil (e.g. transient network failure).  note that for failed authentication
232				// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
233				if err == nil && !ResponseHasStatusCode(resp, codes...) || IsTokenRefreshError(err) {
234					return resp, err
235				}
236				delayed := DelayWithRetryAfter(resp, r.Context().Done())
237				if !delayed && !DelayForBackoff(backoff, attempt, r.Context().Done()) {
238					return resp, r.Context().Err()
239				}
240				// don't count a 429 against the number of attempts
241				// so that we continue to retry until it succeeds
242				if resp == nil || resp.StatusCode != http.StatusTooManyRequests {
243					attempt++
244				}
245			}
246			return resp, err
247		})
248	}
249}
250
// DelayWithRetryAfter waits for the delay specified in the "Retry-After" header of a response
// with status code 429 (Too Many Requests). It returns true when it waited for the full delay.
// It returns false when no delay applies (nil response, a status other than 429, or a missing,
// unparsable or non-positive Retry-After value). It also returns false when cancel is closed
// before the delay elapses.
//
// Both forms of Retry-After defined by RFC 7231 are accepted: a number of seconds ("120") and
// an HTTP-date ("Fri, 31 Dec 1999 23:59:59 GMT"). An HTTP-date is converted to the time
// remaining until that date.
func DelayWithRetryAfter(resp *http.Response, cancel <-chan struct{}) bool {
	if resp == nil || resp.StatusCode != http.StatusTooManyRequests {
		return false
	}
	header := resp.Header.Get("Retry-After")
	var wait time.Duration
	if secs, err := strconv.Atoi(header); err == nil {
		wait = time.Duration(secs) * time.Second
	} else if at, err := http.ParseTime(header); err == nil {
		wait = time.Until(at)
	}
	if wait <= 0 {
		return false
	}
	// An explicit timer (rather than time.After) is stopped if cancel fires first.
	timer := time.NewTimer(wait)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-cancel:
		return false
	}
}
268
269// DoRetryForDuration returns a SendDecorator that retries the request until the total time is equal
270// to or greater than the specified duration, exponentially backing off between requests using the
271// supplied backoff time.Duration (which may be zero). Retrying may be canceled by closing the
272// optional channel on the http.Request.
273func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
274	return func(s Sender) Sender {
275		return SenderFunc(func(r *http.Request) (resp *http.Response, err error) {
276			rr := NewRetriableRequest(r)
277			end := time.Now().Add(d)
278			for attempt := 0; time.Now().Before(end); attempt++ {
279				err = rr.Prepare()
280				if err != nil {
281					return resp, err
282				}
283				resp, err = s.Do(rr.Request())
284				if err == nil {
285					return resp, err
286				}
287				if !DelayForBackoff(backoff, attempt, r.Context().Done()) {
288					return nil, r.Context().Err()
289				}
290			}
291			return resp, err
292		})
293	}
294}
295
296// WithLogging returns a SendDecorator that implements simple before and after logging of the
297// request.
298func WithLogging(logger *log.Logger) SendDecorator {
299	return func(s Sender) Sender {
300		return SenderFunc(func(r *http.Request) (*http.Response, error) {
301			logger.Printf("Sending %s %s", r.Method, r.URL)
302			resp, err := s.Do(r)
303			if err != nil {
304				logger.Printf("%s %s received error '%v'", r.Method, r.URL, err)
305			} else {
306				logger.Printf("%s %s received %s", r.Method, r.URL, resp.Status)
307			}
308			return resp, err
309		})
310	}
311}
312
// DelayForBackoff waits for an exponential backoff delay of backoff * 2^attempt, or until the
// passed cancel channel is closed, whichever comes first. It returns true when the full delay
// elapsed and false when it was canceled early.
//
// The scaled delay is computed in seconds and truncated to a whole number of seconds. Fractional
// seconds are therefore dropped, and a product below one second results in no delay. Backoff can
// be set to zero for no delay. attempt is zero-based: attempt 0 waits backoff, and attempt 1
// doubles it.
func DelayForBackoff(backoff time.Duration, attempt int, cancel <-chan struct{}) bool {
	factor := math.Pow(2, float64(attempt))
	wait := time.Second * time.Duration(backoff.Seconds()*factor)
	select {
	case <-cancel:
		return false
	case <-time.After(wait):
		return true
	}
}
327