1// Package pester provides additional resiliency over the standard http client methods by
2// allowing you to control concurrency, retries, and a backoff strategy.
3package pester
4
5import (
6	"bytes"
7	"errors"
8	"fmt"
9	"io"
10	"io/ioutil"
11	"math/rand"
12	"net/http"
13	"net/url"
14	"sync"
15	"time"
16)
17
// ErrUnexpectedMethod occurs when a pester client call cannot be mapped to one
// of the http.Client methods that pester knows how to dispatch.
var ErrUnexpectedMethod = errors.New("unexpected client method, must be one of Do, Get, Head, Post, or PostForm")

// ErrReadingBody is returned when the body passed to Post cannot be read into
// memory for replay across attempts.
var ErrReadingBody = errors.New("error reading body")

// ErrReadingRequestBody is returned when the body of the *http.Request passed
// to Do cannot be read into memory for replay across attempts.
var ErrReadingRequestBody = errors.New("error reading request body")
26
// Client wraps the http client and exposes all the functionality of the http.Client.
// Additionally, Client provides pester specific values for handling resiliency.
type Client struct {
	// wrap it to provide access to http built ins.
	// When nil, the first request builds one from the Transport,
	// CheckRedirect, Jar and Timeout fields below.
	hc *http.Client

	// Settings copied into the lazily created http.Client; they are only
	// consulted when no http.Client has been embedded.
	Transport     http.RoundTripper
	CheckRedirect func(req *http.Request, via []*http.Request) error
	Jar           http.CookieJar
	Timeout       time.Duration

	// pester specific

	// Concurrency is the number of goroutines issuing the same request in
	// parallel. It only applies to GET; every other verb uses one goroutine.
	Concurrency int
	// MaxRetries is the maximum number of attempts made per goroutine
	// (values <= 0 are treated as 1).
	MaxRetries int
	// Backoff is called with the attempt number to decide how long to wait
	// before the next attempt.
	Backoff BackoffStrategy
	// KeepLog, when true, makes failed attempts accumulate in ErrLog.
	KeepLog bool
	// LogHook is called for each failed attempt, but only when KeepLog is false.
	LogHook LogHook

	// SuccessReqNum and SuccessRetryNum record the goroutine index and the
	// attempt number of the result that was returned to the caller. They are
	// left at zero when that result was a final failure.
	SuccessReqNum   int
	SuccessRetryNum int

	// wg tracks in-flight request goroutines; see Wait.
	wg *sync.WaitGroup

	// The embedded mutex guards ErrLog, the lazy creation of hc, and the
	// Success* fields.
	sync.Mutex
	ErrLog []ErrEntry
	// RetryOnHTTP429, when true, makes a 429 response retried like a 5xx.
	RetryOnHTTP429 bool
}
54
// ErrEntry is used to provide the LogString() data and is populated
// each time an error happens if KeepLog is set.
// ErrEntry.Retry is deprecated in favor of ErrEntry.Attempt
type ErrEntry struct {
	Time    time.Time // when the failed attempt was recorded
	Method  string    // client method that was dispatched, e.g. "Do" or "Get"
	URL     string    // request URL
	Verb    string    // HTTP verb, e.g. "GET"
	Request int       // index of the concurrent request goroutine
	Retry   int       // Attempt + 1; kept only for backward compatibility
	Attempt int       // 1-based attempt number within the goroutine
	Err     error     // error from the http call; nil when the attempt failed on status code alone
}
68
// result simplifies the channel communication for concurrent request handling
type result struct {
	resp  *http.Response // response of the attempt, may be nil on error
	err   error          // error of the attempt, nil on success
	req   int            // index of the goroutine that produced the result (set on success only)
	retry int            // attempt number that produced the result (set on success only)
}
76
// params represents all the params needed to run http client calls and pester errors
type params struct {
	method   string        // http.Client method to dispatch to: "Do", "Get", "Head", "Post" or "PostForm"
	verb     string        // HTTP verb; only "GET" is allowed to run concurrently
	req      *http.Request // request for "Do"; nil for the other methods
	url      string        // target URL (also used for logging)
	bodyType string        // content type for "Post"
	body     io.Reader     // body for "Post"
	data     url.Values    // form values for "PostForm"
}
87
// random is the package-wide source of randomness used to jitter backoffs.
// It is seeded once at start-up from the wall clock.
var random *rand.Rand

func init() {
	seed := time.Now().UnixNano()
	src := rand.NewSource(seed)
	random = rand.New(src)
}
93
94// New constructs a new DefaultClient with sensible default values
95func New() *Client {
96	return &Client{
97		Concurrency:    DefaultClient.Concurrency,
98		MaxRetries:     DefaultClient.MaxRetries,
99		Backoff:        DefaultClient.Backoff,
100		ErrLog:         DefaultClient.ErrLog,
101		wg:             &sync.WaitGroup{},
102		RetryOnHTTP429: false,
103	}
104}
105
106// NewExtendedClient allows you to pass in an http.Client that is previously set up
107// and extends it to have Pester's features of concurrency and retries.
108func NewExtendedClient(hc *http.Client) *Client {
109	c := New()
110	c.hc = hc
111	return c
112}
113
// LogHook is a callback that receives the ErrEntry of each failed attempt as
// it happens. It is only invoked when KeepLog is false; when KeepLog is true
// the entries are appended to ErrLog instead and the hook is never called.
type LogHook func(e ErrEntry)
117
// BackoffStrategy is used to determine how long a retry request should wait until attempted.
// It is called with the number of the attempt that just failed (starting at 1).
type BackoffStrategy func(retry int) time.Duration
120
121// DefaultClient provides sensible defaults
122var DefaultClient = &Client{Concurrency: 1, MaxRetries: 3, Backoff: DefaultBackoff, ErrLog: []ErrEntry{}}
123
// DefaultBackoff is a constant strategy: whatever the attempt number, the wait
// is one second.
func DefaultBackoff(_ int) time.Duration {
	const wait = time.Second
	return wait
}
128
// ExponentialBackoff returns ever increasing backoffs by a power of 2:
// 2^i seconds for attempt i.
//
// The exponent is clamped to [0, 33]. A negative i would otherwise turn into
// an enormous shift count (yielding a zero wait), and 2^34 seconds no longer
// fits into a time.Duration, which would wrap to a negative or zero wait.
func ExponentialBackoff(i int) time.Duration {
	// 2^33 s is the largest power of two (in seconds) representable in the
	// int64 nanoseconds of a time.Duration.
	const maxExponent = 33

	switch {
	case i < 0:
		i = 0
	case i > maxExponent:
		i = maxExponent
	}
	return time.Duration(1<<uint(i)) * time.Second
}
133
134// ExponentialJitterBackoff returns ever increasing backoffs by a power of 2
135// with +/- 0-33% to prevent sychronized reuqests.
136func ExponentialJitterBackoff(i int) time.Duration {
137	return jitter(int(1 << uint(i)))
138}
139
// LinearBackoff grows the wait by one second per attempt: attempt i waits
// i seconds.
func LinearBackoff(i int) time.Duration {
	seconds := time.Duration(i)
	return seconds * time.Second
}
144
145// LinearJitterBackoff returns increasing durations, each a second longer than the last
146// with +/- 0-33% to prevent sychronized reuqests.
147func LinearJitterBackoff(i int) time.Duration {
148	return jitter(i)
149}
150
151// jitter keeps the +/- 0-33% logic in one place
152func jitter(i int) time.Duration {
153	ms := i * 1000
154
155	maxJitter := ms / 3
156
157	// ms ± rand
158	ms += random.Intn(2*maxJitter) - maxJitter
159
160	// a jitter of 0 messes up the time.Tick chan
161	if ms <= 0 {
162		ms = 1
163	}
164
165	return time.Duration(ms) * time.Millisecond
166}
167
168// Wait blocks until all pester requests have returned
169// Probably not that useful outside of testing.
170func (c *Client) Wait() {
171	c.wg.Wait()
172}
173
174// pester provides all the logic of retries, concurrency, backoff, and logging
175func (c *Client) pester(p params) (*http.Response, error) {
176	resultCh := make(chan result)
177	multiplexCh := make(chan result)
178	finishCh := make(chan struct{})
179
180	// track all requests that go out so we can close the late listener routine that closes late incoming response bodies
181	totalSentRequests := &sync.WaitGroup{}
182	totalSentRequests.Add(1)
183	defer totalSentRequests.Done()
184	allRequestsBackCh := make(chan struct{})
185	go func() {
186		totalSentRequests.Wait()
187		close(allRequestsBackCh)
188	}()
189
190	// GET calls should be idempotent and can make use
191	// of concurrency. Other verbs can mutate and should not
192	// make use of the concurrency feature
193	concurrency := c.Concurrency
194	if p.verb != "GET" {
195		concurrency = 1
196	}
197
198	c.Lock()
199	if c.hc == nil {
200		c.hc = &http.Client{}
201		c.hc.Transport = c.Transport
202		c.hc.CheckRedirect = c.CheckRedirect
203		c.hc.Jar = c.Jar
204		c.hc.Timeout = c.Timeout
205	}
206	c.Unlock()
207
208	// re-create the http client so we can leverage the std lib
209	httpClient := http.Client{
210		Transport:     c.hc.Transport,
211		CheckRedirect: c.hc.CheckRedirect,
212		Jar:           c.hc.Jar,
213		Timeout:       c.hc.Timeout,
214	}
215
216	// if we have a request body, we need to save it for later
217	var originalRequestBody []byte
218	var originalBody []byte
219	var err error
220	if p.req != nil && p.req.Body != nil {
221		originalRequestBody, err = ioutil.ReadAll(p.req.Body)
222		if err != nil {
223			return nil, ErrReadingRequestBody
224		}
225		p.req.Body.Close()
226	}
227	if p.body != nil {
228		originalBody, err = ioutil.ReadAll(p.body)
229		if err != nil {
230			return nil, ErrReadingBody
231		}
232	}
233
234	AttemptLimit := c.MaxRetries
235	if AttemptLimit <= 0 {
236		AttemptLimit = 1
237	}
238
239	for req := 0; req < concurrency; req++ {
240		c.wg.Add(1)
241		totalSentRequests.Add(1)
242		go func(n int, p params) {
243			defer c.wg.Done()
244			defer totalSentRequests.Done()
245
246			var err error
247			for i := 1; i <= AttemptLimit; i++ {
248				c.wg.Add(1)
249				defer c.wg.Done()
250				select {
251				case <-finishCh:
252					return
253				default:
254				}
255
256				// rehydrate the body (it is drained each read)
257				if len(originalRequestBody) > 0 {
258					p.req.Body = ioutil.NopCloser(bytes.NewBuffer(originalRequestBody))
259				}
260				if len(originalBody) > 0 {
261					p.body = bytes.NewBuffer(originalBody)
262				}
263
264				var resp *http.Response
265				// route the calls
266				switch p.method {
267				case "Do":
268					resp, err = httpClient.Do(p.req)
269				case "Get":
270					resp, err = httpClient.Get(p.url)
271				case "Head":
272					resp, err = httpClient.Head(p.url)
273				case "Post":
274					resp, err = httpClient.Post(p.url, p.bodyType, p.body)
275				case "PostForm":
276					resp, err = httpClient.PostForm(p.url, p.data)
277				default:
278					err = ErrUnexpectedMethod
279				}
280
281				// Early return if we have a valid result
282				// Only retry (ie, continue the loop) on 5xx status codes and 429
283
284				if err == nil && resp.StatusCode < 500 && (resp.StatusCode != 429 || (resp.StatusCode == 429 && !c.RetryOnHTTP429)) {
285					multiplexCh <- result{resp: resp, err: err, req: n, retry: i}
286					return
287				}
288
289				c.log(ErrEntry{
290					Time:    time.Now(),
291					Method:  p.method,
292					Verb:    p.verb,
293					URL:     p.url,
294					Request: n,
295					Retry:   i + 1, // would remove, but would break backward compatibility
296					Attempt: i,
297					Err:     err,
298				})
299
300				// if it is the last iteration, grab the result (which is an error at this point)
301				if i == AttemptLimit {
302					multiplexCh <- result{resp: resp, err: err}
303					return
304				}
305
306				//If the request has been cancelled, skip retries
307				if p.req != nil {
308					ctx := p.req.Context()
309					select {
310					case <-ctx.Done():
311						multiplexCh <- result{resp: resp, err: ctx.Err()}
312						return
313					default:
314					}
315				}
316
317				// if we are retrying, we should close this response body to free the fd
318				if resp != nil {
319					resp.Body.Close()
320				}
321
322				// prevent a 0 from causing the tick to block, pass additional microsecond
323				<-time.After(c.Backoff(i) + 1*time.Microsecond)
324			}
325		}(req, p)
326	}
327
328	// spin off the go routine so it can continually listen in on late results and close the response bodies
329	go func() {
330		gotFirstResult := false
331		for {
332			select {
333			case res := <-multiplexCh:
334				if !gotFirstResult {
335					gotFirstResult = true
336					close(finishCh)
337					resultCh <- res
338				} else if res.resp != nil {
339					// we only return one result to the caller; close all other response bodies that come back
340					// drain the body before close as to not prevent keepalive. see https://gist.github.com/mholt/eba0f2cc96658be0f717
341					io.Copy(ioutil.Discard, res.resp.Body)
342					res.resp.Body.Close()
343				}
344			case <-allRequestsBackCh:
345				// don't leave this goroutine running
346				return
347			}
348		}
349	}()
350
351	res := <-resultCh
352	c.Lock()
353	defer c.Unlock()
354	c.SuccessReqNum = res.req
355	c.SuccessRetryNum = res.retry
356	return res.resp, res.err
357
358}
359
360// LogString provides a string representation of the errors the client has seen
361func (c *Client) LogString() string {
362	c.Lock()
363	defer c.Unlock()
364	var res string
365	for _, e := range c.ErrLog {
366		res += c.FormatError(e)
367	}
368	return res
369}
370
371// Format the Error to human readable string
372func (c *Client) FormatError(e ErrEntry) string {
373	return fmt.Sprintf("%d %s [%s] %s request-%d retry-%d error: %s\n",
374		e.Time.Unix(), e.Method, e.Verb, e.URL, e.Request, e.Retry, e.Err)
375}
376
377// LogErrCount is a helper method used primarily for test validation
378func (c *Client) LogErrCount() int {
379	c.Lock()
380	defer c.Unlock()
381	return len(c.ErrLog)
382}
383
384// EmbedHTTPClient allows you to extend an existing Pester client with an
385// underlying http.Client, such as https://godoc.org/golang.org/x/oauth2/google#DefaultClient
386func (c *Client) EmbedHTTPClient(hc *http.Client) {
387	c.hc = hc
388}
389
390func (c *Client) log(e ErrEntry) {
391	if c.KeepLog {
392		c.Lock()
393		defer c.Unlock()
394		c.ErrLog = append(c.ErrLog, e)
395	} else if c.LogHook != nil {
396		// NOTE: There is a possibility that Log Printing hook slows it down.
397		// but the consumer can always do the Job in a go-routine.
398		c.LogHook(e)
399	}
400}
401
402// Do provides the same functionality as http.Client.Do
403func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
404	return c.pester(params{method: "Do", req: req, verb: req.Method, url: req.URL.String()})
405}
406
407// Get provides the same functionality as http.Client.Get
408func (c *Client) Get(url string) (resp *http.Response, err error) {
409	return c.pester(params{method: "Get", url: url, verb: "GET"})
410}
411
412// Head provides the same functionality as http.Client.Head
413func (c *Client) Head(url string) (resp *http.Response, err error) {
414	return c.pester(params{method: "Head", url: url, verb: "HEAD"})
415}
416
417// Post provides the same functionality as http.Client.Post
418func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *http.Response, err error) {
419	return c.pester(params{method: "Post", url: url, bodyType: bodyType, body: body, verb: "POST"})
420}
421
422// PostForm provides the same functionality as http.Client.PostForm
423func (c *Client) PostForm(url string, data url.Values) (resp *http.Response, err error) {
424	return c.pester(params{method: "PostForm", url: url, data: data, verb: "POST"})
425}
426
427// set RetryOnHTTP429 for clients,
428func (c *Client) SetRetryOnHTTP429(flag bool) {
429	c.RetryOnHTTP429 = flag
430}
431
432////////////////////////////////////////
433// Provide self-constructing variants //
434////////////////////////////////////////
435
436// Do provides the same functionality as http.Client.Do and creates its own constructor
437func Do(req *http.Request) (resp *http.Response, err error) {
438	c := New()
439	return c.Do(req)
440}
441
442// Get provides the same functionality as http.Client.Get and creates its own constructor
443func Get(url string) (resp *http.Response, err error) {
444	c := New()
445	return c.Get(url)
446}
447
448// Head provides the same functionality as http.Client.Head and creates its own constructor
449func Head(url string) (resp *http.Response, err error) {
450	c := New()
451	return c.Head(url)
452}
453
454// Post provides the same functionality as http.Client.Post and creates its own constructor
455func Post(url string, bodyType string, body io.Reader) (resp *http.Response, err error) {
456	c := New()
457	return c.Post(url, bodyType, body)
458}
459
460// PostForm provides the same functionality as http.Client.PostForm and creates its own constructor
461func PostForm(url string, data url.Values) (resp *http.Response, err error) {
462	c := New()
463	return c.PostForm(url, data)
464}
465