1package vegeta
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"math"
10	"net"
11	"net/http"
12	"net/url"
13	"strconv"
14	"sync"
15	"time"
16
17	"golang.org/x/net/http2"
18)
19
// Attacker is an attack executor which wraps an http.Client.
//
// It is built by NewAttacker and configured through the functional options
// passed to it.
type Attacker struct {
	dialer     *net.Dialer   // dials connections for the default transport; LocalAddr/KeepAlive mutate it
	client     http.Client   // performs every request issued by hit
	stopch     chan struct{} // closed by Stop to end the current attack
	workers    uint64        // number of workers an attack starts with
	maxWorkers uint64        // cap on the workers an attack may spawn
	maxBody    int64         // max response body bytes read; negative means no limit
	redirects  int           // value given to Redirects; enforced via client.CheckRedirect
	seqmu      sync.Mutex    // guards seq; Seq and Timestamp are assigned together under it
	seq        uint64        // sequence number assigned to the next hit
	began      time.Time     // creation time; base from which Result timestamps are derived
	chunked    bool          // send request bodies with chunked transfer encoding
}
34
// Redirect handling defaults.
const (
	// DefaultRedirects is the default number of redirects an Attacker
	// follows.
	DefaultRedirects = 10
	// NoFollow is the redirect setting under which redirects are not
	// followed, and the redirect response itself is treated as successful.
	NoFollow = -1
)

// Per-request defaults.
const (
	// DefaultTimeout is the default time an Attacker waits for a request to
	// complete before giving up on it.
	DefaultTimeout = 30 * time.Second
	// DefaultMaxBody is the default maximum number of bytes read from a
	// response body. The negative value means the body is read without limit.
	DefaultMaxBody = int64(-1)
)

// Per-host connection defaults.
const (
	// DefaultConnections is the default maximum number of idle connections
	// kept open per target host.
	DefaultConnections = 10000
	// DefaultMaxConnections is the default maximum number of connections per
	// target host.
	DefaultMaxConnections = 0
)

// Worker pool defaults.
const (
	// DefaultWorkers is the default number of workers an attack starts with.
	DefaultWorkers = 10
	// DefaultMaxWorkers is the default upper bound on the number of workers
	// used to carry an attack.
	DefaultMaxWorkers = math.MaxUint64
)
58
// DefaultLocalAddr is the local IP address an Attacker binds its outgoing
// connections to unless overridden with the LocalAddr option. It is the
// unspecified IPv4 address.
var DefaultLocalAddr = net.IPAddr{IP: net.IPv4zero}

// DefaultTLSConfig is the tls.Config an Attacker uses unless overridden with
// the TLSConfig option. Server certificate verification is disabled.
var DefaultTLSConfig = &tls.Config{InsecureSkipVerify: true}
65
66// NewAttacker returns a new Attacker with default options which are overridden
67// by the optionally provided opts.
68func NewAttacker(opts ...func(*Attacker)) *Attacker {
69	a := &Attacker{
70		stopch:     make(chan struct{}),
71		workers:    DefaultWorkers,
72		maxWorkers: DefaultMaxWorkers,
73		maxBody:    DefaultMaxBody,
74		began:      time.Now(),
75	}
76
77	a.dialer = &net.Dialer{
78		LocalAddr: &net.TCPAddr{IP: DefaultLocalAddr.IP, Zone: DefaultLocalAddr.Zone},
79		KeepAlive: 30 * time.Second,
80	}
81
82	a.client = http.Client{
83		Timeout: DefaultTimeout,
84		Transport: &http.Transport{
85			Proxy:               http.ProxyFromEnvironment,
86			Dial:                a.dialer.Dial,
87			TLSClientConfig:     DefaultTLSConfig,
88			MaxIdleConnsPerHost: DefaultConnections,
89			MaxConnsPerHost:     DefaultMaxConnections,
90		},
91	}
92
93	for _, opt := range opts {
94		opt(a)
95	}
96
97	return a
98}
99
100// Workers returns a functional option which sets the initial number of workers
101// an Attacker uses to hit its targets. More workers may be spawned dynamically
102// to sustain the requested rate in the face of slow responses and errors.
103func Workers(n uint64) func(*Attacker) {
104	return func(a *Attacker) { a.workers = n }
105}
106
107// MaxWorkers returns a functional option which sets the maximum number of workers
108// an Attacker can use to hit its targets.
109func MaxWorkers(n uint64) func(*Attacker) {
110	return func(a *Attacker) { a.maxWorkers = n }
111}
112
113// Connections returns a functional option which sets the number of maximum idle
114// open connections per target host.
115func Connections(n int) func(*Attacker) {
116	return func(a *Attacker) {
117		tr := a.client.Transport.(*http.Transport)
118		tr.MaxIdleConnsPerHost = n
119	}
120}
121
122// MaxConnections returns a functional option which sets the number of maximum
123// connections per target host.
124func MaxConnections(n int) func(*Attacker) {
125	return func(a *Attacker) {
126		tr := a.client.Transport.(*http.Transport)
127		tr.MaxConnsPerHost = n
128	}
129}
130
131// ChunkedBody returns a functional option which makes the attacker send the
132// body of each request with the chunked transfer encoding.
133func ChunkedBody(b bool) func(*Attacker) {
134	return func(a *Attacker) { a.chunked = b }
135}
136
137// Redirects returns a functional option which sets the maximum
138// number of redirects an Attacker will follow.
139func Redirects(n int) func(*Attacker) {
140	return func(a *Attacker) {
141		a.redirects = n
142		a.client.CheckRedirect = func(_ *http.Request, via []*http.Request) error {
143			switch {
144			case n == NoFollow:
145				return http.ErrUseLastResponse
146			case n < len(via):
147				return fmt.Errorf("stopped after %d redirects", n)
148			default:
149				return nil
150			}
151		}
152	}
153}
154
155// Proxy returns a functional option which sets the `Proxy` field on
156// the http.Client's Transport
157func Proxy(proxy func(*http.Request) (*url.URL, error)) func(*Attacker) {
158	return func(a *Attacker) {
159		tr := a.client.Transport.(*http.Transport)
160		tr.Proxy = proxy
161	}
162}
163
164// Timeout returns a functional option which sets the maximum amount of time
165// an Attacker will wait for a request to be responded to and completely read.
166func Timeout(d time.Duration) func(*Attacker) {
167	return func(a *Attacker) {
168		a.client.Timeout = d
169	}
170}
171
172// LocalAddr returns a functional option which sets the local address
173// an Attacker will use with its requests.
174func LocalAddr(addr net.IPAddr) func(*Attacker) {
175	return func(a *Attacker) {
176		tr := a.client.Transport.(*http.Transport)
177		a.dialer.LocalAddr = &net.TCPAddr{IP: addr.IP, Zone: addr.Zone}
178		tr.Dial = a.dialer.Dial
179	}
180}
181
182// KeepAlive returns a functional option which toggles KeepAlive
183// connections on the dialer and transport.
184func KeepAlive(keepalive bool) func(*Attacker) {
185	return func(a *Attacker) {
186		tr := a.client.Transport.(*http.Transport)
187		tr.DisableKeepAlives = !keepalive
188		if !keepalive {
189			a.dialer.KeepAlive = 0
190			tr.Dial = a.dialer.Dial
191		}
192	}
193}
194
195// TLSConfig returns a functional option which sets the *tls.Config for a
196// Attacker to use with its requests.
197func TLSConfig(c *tls.Config) func(*Attacker) {
198	return func(a *Attacker) {
199		tr := a.client.Transport.(*http.Transport)
200		tr.TLSClientConfig = c
201	}
202}
203
204// HTTP2 returns a functional option which enables or disables HTTP/2 support
205// on requests performed by an Attacker.
206func HTTP2(enabled bool) func(*Attacker) {
207	return func(a *Attacker) {
208		if tr := a.client.Transport.(*http.Transport); enabled {
209			http2.ConfigureTransport(tr)
210		} else {
211			tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
212		}
213	}
214}
215
216// H2C returns a functional option which enables H2C support on requests
217// performed by an Attacker
218func H2C(enabled bool) func(*Attacker) {
219	return func(a *Attacker) {
220		if tr := a.client.Transport.(*http.Transport); enabled {
221			a.client.Transport = &http2.Transport{
222				AllowHTTP: true,
223				DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
224					return tr.Dial(network, addr)
225				},
226			}
227		}
228	}
229}
230
231// MaxBody returns a functional option which limits the max number of bytes
232// read from response bodies. Set to -1 to disable any limits.
233func MaxBody(n int64) func(*Attacker) {
234	return func(a *Attacker) { a.maxBody = n }
235}
236
237// UnixSocket changes the dialer for the attacker to use the specified unix socket file
238func UnixSocket(socket string) func(*Attacker) {
239	return func(a *Attacker) {
240		if tr, ok := a.client.Transport.(*http.Transport); socket != "" && ok {
241			tr.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
242				return net.Dial("unix", socket)
243			}
244		}
245	}
246}
247
248// Client returns a functional option that allows you to bring your own http.Client
249func Client(c *http.Client) func(*Attacker) {
250	return func(a *Attacker) { a.client = *c }
251}
252
253// ProxyHeader returns a functional option that allows you to add your own
254// Proxy CONNECT headers
255func ProxyHeader(h http.Header) func(*Attacker) {
256	return func(a *Attacker) {
257		if tr, ok := a.client.Transport.(*http.Transport); ok {
258			tr.ProxyConnectHeader = h
259		}
260	}
261}
262
263// Attack reads its Targets from the passed Targeter and attacks them at
264// the rate specified by the Pacer. When the duration is zero the attack
265// runs until Stop is called. Results are sent to the returned channel as soon
266// as they arrive and will have their Attack field set to the given name.
267func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <-chan *Result {
268	var wg sync.WaitGroup
269
270	workers := a.workers
271	if workers > a.maxWorkers {
272		workers = a.maxWorkers
273	}
274
275	results := make(chan *Result)
276	ticks := make(chan struct{})
277	for i := uint64(0); i < workers; i++ {
278		wg.Add(1)
279		go a.attack(tr, name, &wg, ticks, results)
280	}
281
282	go func() {
283		defer close(results)
284		defer wg.Wait()
285		defer close(ticks)
286
287		began, count := time.Now(), uint64(0)
288		for {
289			elapsed := time.Since(began)
290			if du > 0 && elapsed > du {
291				return
292			}
293
294			wait, stop := p.Pace(elapsed, count)
295			if stop {
296				return
297			}
298
299			time.Sleep(wait)
300
301			if workers < a.maxWorkers {
302				select {
303				case ticks <- struct{}{}:
304					count++
305					continue
306				case <-a.stopch:
307					return
308				default:
309					// all workers are blocked. start one more and try again
310					workers++
311					wg.Add(1)
312					go a.attack(tr, name, &wg, ticks, results)
313				}
314			}
315
316			select {
317			case ticks <- struct{}{}:
318				count++
319			case <-a.stopch:
320				return
321			}
322		}
323	}()
324
325	return results
326}
327
328// Stop stops the current attack.
329func (a *Attacker) Stop() {
330	select {
331	case <-a.stopch:
332		return
333	default:
334		close(a.stopch)
335	}
336}
337
338func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
339	defer workers.Done()
340	for range ticks {
341		results <- a.hit(tr, name)
342	}
343}
344
345func (a *Attacker) hit(tr Targeter, name string) *Result {
346	var (
347		res = Result{Attack: name}
348		tgt Target
349		err error
350	)
351
352	a.seqmu.Lock()
353	res.Timestamp = a.began.Add(time.Since(a.began))
354	res.Seq = a.seq
355	a.seq++
356	a.seqmu.Unlock()
357
358	defer func() {
359		res.Latency = time.Since(res.Timestamp)
360		if err != nil {
361			res.Error = err.Error()
362		}
363	}()
364
365	if err = tr(&tgt); err != nil {
366		a.Stop()
367		return &res
368	}
369
370	res.Method = tgt.Method
371	res.URL = tgt.URL
372
373	req, err := tgt.Request()
374	if err != nil {
375		return &res
376	}
377
378	if name != "" {
379		req.Header.Set("X-Vegeta-Attack", name)
380	}
381
382	req.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10))
383
384	if a.chunked {
385		req.TransferEncoding = append(req.TransferEncoding, "chunked")
386	}
387
388	r, err := a.client.Do(req)
389	if err != nil {
390		return &res
391	}
392	defer r.Body.Close()
393
394	body := io.Reader(r.Body)
395	if a.maxBody >= 0 {
396		body = io.LimitReader(r.Body, a.maxBody)
397	}
398
399	if res.Body, err = ioutil.ReadAll(body); err != nil {
400		return &res
401	} else if _, err = io.Copy(ioutil.Discard, r.Body); err != nil {
402		return &res
403	}
404
405	res.BytesIn = uint64(len(res.Body))
406
407	if req.ContentLength != -1 {
408		res.BytesOut = uint64(req.ContentLength)
409	}
410
411	if res.Code = uint16(r.StatusCode); res.Code < 200 || res.Code >= 400 {
412		res.Error = r.Status
413	}
414
415	res.Headers = r.Header
416
417	return &res
418}
419