// Package retryablehttp provides a familiar HTTP client interface with
2// automatic retries and exponential backoff. It is a thin wrapper over the
3// standard net/http client library and exposes nearly the same public API.
4// This makes retryablehttp very easy to drop into existing programs.
5//
6// retryablehttp performs automatic retries under certain conditions. Mainly, if
7// an error is returned by the client (connection errors etc), or if a 500-range
8// response is received, then a retry is invoked. Otherwise, the response is
9// returned and left to the caller to interpret.
10//
11// Requests which take a request body should provide a non-nil function
12// parameter. The best choice is to provide either a function satisfying
13// ReaderFunc which provides multiple io.Readers in an efficient manner, a
14// *bytes.Buffer (the underlying raw byte slice will be used) or a raw byte
15// slice. As it is a reference type, and we will wrap it as needed by readers,
16// we can efficiently re-use the request body without needing to copy it. If an
17// io.Reader (such as a *bytes.Reader) is provided, the full body will be read
18// prior to the first request, and will be efficiently re-used for any retries.
// An io.ReadSeeker can be used, but some users have observed occasional data
// races between the net/http library and the Seek functionality of some
// implementations of ReadSeeker, so it should be avoided if possible.
22package retryablehttp
23
24import (
25	"bytes"
26	"context"
27	"fmt"
28	"io"
29	"io/ioutil"
30	"log"
31	"math"
32	"math/rand"
33	"net/http"
34	"net/url"
35	"os"
36	"strings"
37	"time"
38
39	cleanhttp "github.com/hashicorp/go-cleanhttp"
40)
41
42var (
43	// Default retry configuration
44	defaultRetryWaitMin = 1 * time.Second
45	defaultRetryWaitMax = 30 * time.Second
46	defaultRetryMax     = 4
47
48	// defaultClient is used for performing requests without explicitly making
49	// a new client. It is purposely private to avoid modifications.
50	defaultClient = NewClient()
51
52	// We need to consume response bodies to maintain http connections, but
53	// limit the size we consume to respReadLimit.
54	respReadLimit = int64(4096)
55)
56
// ReaderFunc is the type of function that can be given natively to
// NewRequest. The Client calls it before every attempt to obtain a reader
// over the request body, so each call should return a reader over the whole
// payload.
type ReaderFunc func() (io.Reader, error)

// LenReader is an interface implemented by many in-memory io.Readers. It is
// used to send the right Content-Length header automatically when possible.
type LenReader interface {
	Len() int
}

// Request wraps the metadata needed to create HTTP requests.
type Request struct {
	// body produces a fresh reader over the request body payload. It is
	// called before each attempt so the request data is rewound between
	// retries. A nil body means the request has no body.
	body ReaderFunc

	// Embed an HTTP request directly. This makes a *Request act exactly
	// like an *http.Request so that all meta methods are supported.
	*http.Request
}

// WithContext returns a new Request with the following properties:
//   - It shares r's body source.
//   - It wraps a shallow copy of r's *http.Request whose context is set
//     to ctx.
//
// r itself is left unmodified, mirroring http.Request.WithContext. The
// provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
	return &Request{
		body:    r.body,
		Request: r.Request.WithContext(ctx),
	}
}

// BodyBytes allows accessing the request body. It is an analogue to
// http.Request's Body variable, but it returns a copy of the underlying data
// rather than consuming it. It returns (nil, nil) when the request has no
// body.
//
// If the reader produced by the body function is also an io.Closer, it is
// closed once read. This matches how the probe reader is treated when the
// request is built.
//
// This function is not thread-safe; do not call it at the same time as another
// call, or at the same time this request is being used with Client.Do.
func (r *Request) BodyBytes() ([]byte, error) {
	if r.body == nil {
		return nil, nil
	}
	body, err := r.body()
	if err != nil {
		return nil, err
	}
	if c, ok := body.(io.Closer); ok {
		// Best effort: the data has already been read, so a Close error
		// is not actionable here.
		defer c.Close()
	}
	buf := new(bytes.Buffer)
	if _, err := buf.ReadFrom(body); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
105
106func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) {
107	var bodyReader ReaderFunc
108	var contentLength int64
109
110	if rawBody != nil {
111		switch body := rawBody.(type) {
112		// If they gave us a function already, great! Use it.
113		case ReaderFunc:
114			bodyReader = body
115			tmp, err := body()
116			if err != nil {
117				return nil, 0, err
118			}
119			if lr, ok := tmp.(LenReader); ok {
120				contentLength = int64(lr.Len())
121			}
122			if c, ok := tmp.(io.Closer); ok {
123				c.Close()
124			}
125
126		case func() (io.Reader, error):
127			bodyReader = body
128			tmp, err := body()
129			if err != nil {
130				return nil, 0, err
131			}
132			if lr, ok := tmp.(LenReader); ok {
133				contentLength = int64(lr.Len())
134			}
135			if c, ok := tmp.(io.Closer); ok {
136				c.Close()
137			}
138
139		// If a regular byte slice, we can read it over and over via new
140		// readers
141		case []byte:
142			buf := body
143			bodyReader = func() (io.Reader, error) {
144				return bytes.NewReader(buf), nil
145			}
146			contentLength = int64(len(buf))
147
148		// If a bytes.Buffer we can read the underlying byte slice over and
149		// over
150		case *bytes.Buffer:
151			buf := body
152			bodyReader = func() (io.Reader, error) {
153				return bytes.NewReader(buf.Bytes()), nil
154			}
155			contentLength = int64(buf.Len())
156
157		// We prioritize *bytes.Reader here because we don't really want to
158		// deal with it seeking so want it to match here instead of the
159		// io.ReadSeeker case.
160		case *bytes.Reader:
161			buf, err := ioutil.ReadAll(body)
162			if err != nil {
163				return nil, 0, err
164			}
165			bodyReader = func() (io.Reader, error) {
166				return bytes.NewReader(buf), nil
167			}
168			contentLength = int64(len(buf))
169
170		// Compat case
171		case io.ReadSeeker:
172			raw := body
173			bodyReader = func() (io.Reader, error) {
174				_, err := raw.Seek(0, 0)
175				return ioutil.NopCloser(raw), err
176			}
177			if lr, ok := raw.(LenReader); ok {
178				contentLength = int64(lr.Len())
179			}
180
181		// Read all in so we can reset
182		case io.Reader:
183			buf, err := ioutil.ReadAll(body)
184			if err != nil {
185				return nil, 0, err
186			}
187			bodyReader = func() (io.Reader, error) {
188				return bytes.NewReader(buf), nil
189			}
190			contentLength = int64(len(buf))
191
192		default:
193			return nil, 0, fmt.Errorf("cannot handle type %T", rawBody)
194		}
195	}
196	return bodyReader, contentLength, nil
197}
198
199// FromRequest wraps an http.Request in a retryablehttp.Request
200func FromRequest(r *http.Request) (*Request, error) {
201	bodyReader, _, err := getBodyReaderAndContentLength(r.Body)
202	if err != nil {
203		return nil, err
204	}
205	// Could assert contentLength == r.ContentLength
206	return &Request{bodyReader, r}, nil
207}
208
209// NewRequest creates a new wrapped request.
210func NewRequest(method, url string, rawBody interface{}) (*Request, error) {
211	bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody)
212	if err != nil {
213		return nil, err
214	}
215
216	httpReq, err := http.NewRequest(method, url, nil)
217	if err != nil {
218		return nil, err
219	}
220	httpReq.ContentLength = contentLength
221
222	return &Request{bodyReader, httpReq}, nil
223}
224
// Logger is the minimal logging interface used by the Client, so loggers
// other than the standard *log.Logger can be plugged in.
type Logger interface {
	Printf(format string, args ...interface{})
}
230
231// RequestLogHook allows a function to run before each retry. The HTTP
232// request which will be made, and the retry number (0 for the initial
233// request) are available to users. The internal logger is exposed to
234// consumers.
235type RequestLogHook func(Logger, *http.Request, int)
236
237// ResponseLogHook is like RequestLogHook, but allows running a function
238// on each HTTP response. This function will be invoked at the end of
239// every HTTP request executed, regardless of whether a subsequent retry
240// needs to be performed or not. If the response body is read or closed
241// from this method, this will affect the response returned from Do().
242type ResponseLogHook func(Logger, *http.Response)
243
// CheckRetry specifies a policy for handling retries. It is called after each
// request with the response and error values returned by the http.Client:
//   - If CheckRetry returns false, the Client stops retrying and returns the
//     response to the caller.
//   - If CheckRetry returns an error, that error value is returned in lieu of
//     the error from the request.
//
// The Client closes any response body when retrying. If the retry is aborted,
// it is up to the CheckRetry callback to properly close any response body
// before returning.
type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error)

// Backoff specifies a policy for how long to wait between retries. It is
// called after a failing request to determine the amount of time that should
// pass before trying again. Its arguments are:
//   - min, max: the Client's RetryWaitMin and RetryWaitMax.
//   - attemptNum: zero-based, so it is 0 for the wait after the first
//     attempt.
//   - resp: the last response, which may be nil if the request failed
//     outright.
type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration

// ErrorHandler is called once retries are exhausted. It receives the last
// response and error from the http library, and numTries is the total number
// of attempts made. If no handler is specified, the default behavior is to
// close the body and return an error indicating how many tries were
// attempted. If overriding this, be sure to close the body if needed.
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)
264
265// Client is used to make HTTP requests. It adds additional functionality
266// like automatic retries to tolerate minor outages.
267type Client struct {
268	HTTPClient *http.Client // Internal HTTP client.
269	Logger     Logger       // Customer logger instance.
270
271	RetryWaitMin time.Duration // Minimum time to wait
272	RetryWaitMax time.Duration // Maximum time to wait
273	RetryMax     int           // Maximum number of retries
274
275	// RequestLogHook allows a user-supplied function to be called
276	// before each retry.
277	RequestLogHook RequestLogHook
278
279	// ResponseLogHook allows a user-supplied function to be called
280	// with the response from each HTTP request executed.
281	ResponseLogHook ResponseLogHook
282
283	// CheckRetry specifies the policy for handling retries, and is called
284	// after each request. The default policy is DefaultRetryPolicy.
285	CheckRetry CheckRetry
286
287	// Backoff specifies the policy for how long to wait between retries
288	Backoff Backoff
289
290	// ErrorHandler specifies the custom error handler to use, if any
291	ErrorHandler ErrorHandler
292}
293
294// NewClient creates a new Client with default settings.
295func NewClient() *Client {
296	return &Client{
297		HTTPClient:   cleanhttp.DefaultClient(),
298		Logger:       log.New(os.Stderr, "", log.LstdFlags),
299		RetryWaitMin: defaultRetryWaitMin,
300		RetryWaitMax: defaultRetryWaitMax,
301		RetryMax:     defaultRetryMax,
302		CheckRetry:   DefaultRetryPolicy,
303		Backoff:      DefaultBackoff,
304	}
305}
306
// DefaultRetryPolicy is the stock Client.CheckRetry callback. It retries on
// connection errors and on server errors, but never once the request's
// context is done.
func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
	// A canceled or expired context ends retrying and surfaces its error.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return false, ctxErr
	}

	// Transport-level failures are always worth another try.
	if err != nil {
		return true, err
	}

	// 500-range responses usually indicate a transient server-side problem,
	// so they are retried, except 501 Not Implemented. Invalid codes such as
	// 0 and 999 are treated the same way.
	switch code := resp.StatusCode; {
	case code == 0:
		return true, nil
	case code >= 500 && code != http.StatusNotImplemented:
		return true, nil
	default:
		return false, nil
	}
}
328
// DefaultBackoff provides a default callback for Client.Backoff. It performs
// exponential backoff based on the attempt number, limited by the provided
// durations: the wait is min * 2^attemptNum, capped at max. resp is not
// consulted.
func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	mult := math.Pow(2, float64(attemptNum)) * float64(min)
	// Reject values that cannot fit in a time.Duration (int64) before
	// converting. The Go spec leaves an out-of-range float-to-integer
	// conversion implementation-dependent, so the round-trip check below
	// cannot reliably detect overflow on its own. NaN (e.g. 0 * +Inf when
	// min is zero) falls back to max as well.
	if math.IsNaN(mult) || mult >= float64(math.MaxInt64) || mult < float64(math.MinInt64) {
		return max
	}
	sleep := time.Duration(mult)
	// A product that does not survive the round trip (e.g. a fractional value
	// from a negative attemptNum) also falls back to max, as before.
	if float64(sleep) != mult || sleep > max {
		sleep = max
	}
	return sleep
}
340
// LinearJitterBackoff provides a callback for Client.Backoff which will
// perform linear backoff based on the attempt number and with jitter to
// prevent a thundering herd.
//
// min and max here are *not* absolute values. The number to be multiplied by
// the attempt number will be chosen at random from between them, thus they are
// bounding the jitter.
//
// For instance:
//   - To get strictly linear backoff of one second increasing each retry, set
//     both to one second (1s, 2s, 3s, 4s, ...).
//   - To get a small amount of jitter centered around one second increasing
//     each retry, set to around one second, such as a min of 800ms and max of
//     1200ms (892ms, 2102ms, 2945ms, 4312ms, ...).
//   - To get extreme jitter, set to a very wide spread, such as a min of 100ms
//     and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...).
//
// If max <= min, no jitter is applied and min * (attemptNum+1) is returned.
func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	// attemptNum always starts at zero but we want to start at 1 for multiplication
	attemptNum++

	if max <= min {
		// Unclear what to do here, or they are the same, so return min *
		// attemptNum
		return min * time.Duration(attemptNum)
	}

	// A private source per call keeps this safe for concurrent use without
	// locking. The seed is the full UnixNano timestamp rather than only the
	// sub-second Nanosecond() component, which repeats every second. The
	// variable is named source so it no longer shadows the math/rand package.
	source := rand.New(rand.NewSource(time.Now().UnixNano()))

	// Pick a random number that lies somewhere between the min and max and
	// multiply by the attemptNum. We first get a random percentage, then
	// apply that to the difference between min and max, and add to min.
	jitter := source.Float64() * float64(max-min)
	jitterMin := int64(jitter) + int64(min)
	return time.Duration(jitterMin * int64(attemptNum))
}
378
// PassthroughErrorHandler is an ErrorHandler that hands back, untouched, the
// response and error that net/http produced for the final attempt. The
// response body is left open for the caller to consume and close; the attempt
// count is ignored.
func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) {
	last, lastErr := resp, err
	return last, lastErr
}
385
386// Do wraps calling an HTTP method with retries.
387func (c *Client) Do(req *Request) (*http.Response, error) {
388	if c.Logger != nil {
389		c.Logger.Printf("[DEBUG] %s %s", req.Method, req.URL)
390	}
391
392	var resp *http.Response
393	var err error
394
395	for i := 0; ; i++ {
396		var code int // HTTP response code
397
398		// Always rewind the request body when non-nil.
399		if req.body != nil {
400			body, err := req.body()
401			if err != nil {
402				return resp, err
403			}
404			if c, ok := body.(io.ReadCloser); ok {
405				req.Body = c
406			} else {
407				req.Body = ioutil.NopCloser(body)
408			}
409		}
410
411		if c.RequestLogHook != nil {
412			c.RequestLogHook(c.Logger, req.Request, i)
413		}
414
415		// Attempt the request
416		resp, err = c.HTTPClient.Do(req.Request)
417		if resp != nil {
418			code = resp.StatusCode
419		}
420
421		// Check if we should continue with retries.
422		checkOK, checkErr := c.CheckRetry(req.Context(), resp, err)
423
424		if err != nil {
425			if c.Logger != nil {
426				c.Logger.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
427			}
428		} else {
429			// Call this here to maintain the behavior of logging all requests,
430			// even if CheckRetry signals to stop.
431			if c.ResponseLogHook != nil {
432				// Call the response logger function if provided.
433				c.ResponseLogHook(c.Logger, resp)
434			}
435		}
436
437		// Now decide if we should continue.
438		if !checkOK {
439			if checkErr != nil {
440				err = checkErr
441			}
442			return resp, err
443		}
444
445		// We do this before drainBody beause there's no need for the I/O if
446		// we're breaking out
447		remain := c.RetryMax - i
448		if remain <= 0 {
449			break
450		}
451
452		// We're going to retry, consume any response to reuse the connection.
453		if err == nil && resp != nil {
454			c.drainBody(resp.Body)
455		}
456
457		wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp)
458		desc := fmt.Sprintf("%s %s", req.Method, req.URL)
459		if code > 0 {
460			desc = fmt.Sprintf("%s (status: %d)", desc, code)
461		}
462		if c.Logger != nil {
463			c.Logger.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain)
464		}
465		select {
466		case <-req.Context().Done():
467			return nil, req.Context().Err()
468		case <-time.After(wait):
469		}
470	}
471
472	if c.ErrorHandler != nil {
473		return c.ErrorHandler(resp, err, c.RetryMax+1)
474	}
475
476	// By default, we close the response body and return an error without
477	// returning the response
478	if resp != nil {
479		resp.Body.Close()
480	}
481	return nil, fmt.Errorf("%s %s giving up after %d attempts",
482		req.Method, req.URL, c.RetryMax+1)
483}
484
485// Try to read the response body so we can reuse this connection.
486func (c *Client) drainBody(body io.ReadCloser) {
487	defer body.Close()
488	_, err := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit))
489	if err != nil {
490		if c.Logger != nil {
491			c.Logger.Printf("[ERR] error reading response body: %v", err)
492		}
493	}
494}
495
496// Get is a shortcut for doing a GET request without making a new client.
497func Get(url string) (*http.Response, error) {
498	return defaultClient.Get(url)
499}
500
501// Get is a convenience helper for doing simple GET requests.
502func (c *Client) Get(url string) (*http.Response, error) {
503	req, err := NewRequest("GET", url, nil)
504	if err != nil {
505		return nil, err
506	}
507	return c.Do(req)
508}
509
510// Head is a shortcut for doing a HEAD request without making a new client.
511func Head(url string) (*http.Response, error) {
512	return defaultClient.Head(url)
513}
514
515// Head is a convenience method for doing simple HEAD requests.
516func (c *Client) Head(url string) (*http.Response, error) {
517	req, err := NewRequest("HEAD", url, nil)
518	if err != nil {
519		return nil, err
520	}
521	return c.Do(req)
522}
523
524// Post is a shortcut for doing a POST request without making a new client.
525func Post(url, bodyType string, body interface{}) (*http.Response, error) {
526	return defaultClient.Post(url, bodyType, body)
527}
528
529// Post is a convenience method for doing simple POST requests.
530func (c *Client) Post(url, bodyType string, body interface{}) (*http.Response, error) {
531	req, err := NewRequest("POST", url, body)
532	if err != nil {
533		return nil, err
534	}
535	req.Header.Set("Content-Type", bodyType)
536	return c.Do(req)
537}
538
539// PostForm is a shortcut to perform a POST with form data without creating
540// a new client.
541func PostForm(url string, data url.Values) (*http.Response, error) {
542	return defaultClient.PostForm(url, data)
543}
544
545// PostForm is a convenience method for doing simple POST operations using
546// pre-filled url.Values form data.
547func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) {
548	return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
549}
550