1package main
2
3import (
4	"crypto/tls"
5	"crypto/x509"
6	"errors"
7	"flag"
8	"fmt"
9	"io"
10	"io/ioutil"
11	"net"
12	"net/http"
13	"os"
14	"os/signal"
15	"strings"
16	"time"
17
18	"github.com/tsenart/vegeta/v12/internal/resolver"
19	vegeta "github.com/tsenart/vegeta/v12/lib"
20)
21
22func attackCmd() command {
23	fs := flag.NewFlagSet("vegeta attack", flag.ExitOnError)
24	opts := &attackOpts{
25		headers:      headers{http.Header{}},
26		proxyHeaders: headers{http.Header{}},
27		laddr:        localAddr{&vegeta.DefaultLocalAddr},
28		rate:         vegeta.Rate{Freq: 50, Per: time.Second},
29		maxBody:      vegeta.DefaultMaxBody,
30	}
31	fs.StringVar(&opts.name, "name", "", "Attack name")
32	fs.StringVar(&opts.targetsf, "targets", "stdin", "Targets file")
33	fs.StringVar(&opts.format, "format", vegeta.HTTPTargetFormat,
34		fmt.Sprintf("Targets format [%s]", strings.Join(vegeta.TargetFormats, ", ")))
35	fs.StringVar(&opts.outputf, "output", "stdout", "Output file")
36	fs.StringVar(&opts.bodyf, "body", "", "Requests body file")
37	fs.BoolVar(&opts.chunked, "chunked", false, "Send body with chunked transfer encoding")
38	fs.StringVar(&opts.certf, "cert", "", "TLS client PEM encoded certificate file")
39	fs.StringVar(&opts.keyf, "key", "", "TLS client PEM encoded private key file")
40	fs.Var(&opts.rootCerts, "root-certs", "TLS root certificate files (comma separated list)")
41	fs.BoolVar(&opts.http2, "http2", true, "Send HTTP/2 requests when supported by the server")
42	fs.BoolVar(&opts.h2c, "h2c", false, "Send HTTP/2 requests without TLS encryption")
43	fs.BoolVar(&opts.insecure, "insecure", false, "Ignore invalid server TLS certificates")
44	fs.BoolVar(&opts.lazy, "lazy", false, "Read targets lazily")
45	fs.DurationVar(&opts.duration, "duration", 0, "Duration of the test [0 = forever]")
46	fs.DurationVar(&opts.timeout, "timeout", vegeta.DefaultTimeout, "Requests timeout")
47	fs.Uint64Var(&opts.workers, "workers", vegeta.DefaultWorkers, "Initial number of workers")
48	fs.Uint64Var(&opts.maxWorkers, "max-workers", vegeta.DefaultMaxWorkers, "Maximum number of workers")
49	fs.IntVar(&opts.connections, "connections", vegeta.DefaultConnections, "Max open idle connections per target host")
50	fs.IntVar(&opts.maxConnections, "max-connections", vegeta.DefaultMaxConnections, "Max connections per target host")
51	fs.IntVar(&opts.redirects, "redirects", vegeta.DefaultRedirects, "Number of redirects to follow. -1 will not follow but marks as success")
52	fs.Var(&maxBodyFlag{&opts.maxBody}, "max-body", "Maximum number of bytes to capture from response bodies. [-1 = no limit]")
53	fs.Var(&rateFlag{&opts.rate}, "rate", "Number of requests per time unit [0 = infinity]")
54	fs.Var(&opts.headers, "header", "Request header")
55	fs.Var(&opts.proxyHeaders, "proxy-header", "Proxy CONNECT header")
56	fs.Var(&opts.laddr, "laddr", "Local IP address")
57	fs.BoolVar(&opts.keepalive, "keepalive", true, "Use persistent connections")
58	fs.StringVar(&opts.unixSocket, "unix-socket", "", "Connect over a unix socket. This overrides the host address in target URLs")
59	systemSpecificFlags(fs, opts)
60
61	return command{fs, func(args []string) error {
62		fs.Parse(args)
63		return attack(opts)
64	}}
65}
66
// errZeroRate reports a rate whose frequency or time unit is not positive.
// It is not referenced in this file; presumably the -rate flag parser
// returns it.
var errZeroRate = errors.New("rate frequency and time unit must be bigger than zero")

// errBadCert is returned by tlsConfig when a root certificate file yields no
// PEM certificate that can be appended to the root pool.
var errBadCert = errors.New("bad certificate")
71
72// attackOpts aggregates the attack function command options
73type attackOpts struct {
74	name           string
75	targetsf       string
76	format         string
77	outputf        string
78	bodyf          string
79	certf          string
80	keyf           string
81	rootCerts      csl
82	http2          bool
83	h2c            bool
84	insecure       bool
85	lazy           bool
86	chunked        bool
87	duration       time.Duration
88	timeout        time.Duration
89	rate           vegeta.Rate
90	workers        uint64
91	maxWorkers     uint64
92	connections    int
93	maxConnections int
94	redirects      int
95	maxBody        int64
96	headers        headers
97	proxyHeaders   headers
98	laddr          localAddr
99	keepalive      bool
100	resolvers      csl
101	unixSocket     string
102}
103
104// attack validates the attack arguments, sets up the
105// required resources, launches the attack and writes the results
106func attack(opts *attackOpts) (err error) {
107	if opts.maxWorkers == vegeta.DefaultMaxWorkers && opts.rate.Freq == 0 {
108		return fmt.Errorf("-rate=0 requires setting -max-workers")
109	}
110
111	if len(opts.resolvers) > 0 {
112		res, err := resolver.NewResolver(opts.resolvers)
113		if err != nil {
114			return err
115		}
116		net.DefaultResolver = res
117	}
118
119	files := map[string]io.Reader{}
120	for _, filename := range []string{opts.targetsf, opts.bodyf} {
121		if filename == "" {
122			continue
123		}
124		f, err := file(filename, false)
125		if err != nil {
126			return fmt.Errorf("error opening %s: %s", filename, err)
127		}
128		defer f.Close()
129		files[filename] = f
130	}
131
132	var body []byte
133	if bodyf, ok := files[opts.bodyf]; ok {
134		if body, err = ioutil.ReadAll(bodyf); err != nil {
135			return fmt.Errorf("error reading %s: %s", opts.bodyf, err)
136		}
137	}
138
139	var (
140		tr       vegeta.Targeter
141		src      = files[opts.targetsf]
142		hdr      = opts.headers.Header
143		proxyHdr = opts.proxyHeaders.Header
144	)
145
146	switch opts.format {
147	case vegeta.JSONTargetFormat:
148		tr = vegeta.NewJSONTargeter(src, body, hdr)
149	case vegeta.HTTPTargetFormat:
150		tr = vegeta.NewHTTPTargeter(src, body, hdr)
151	default:
152		return fmt.Errorf("format %q isn't one of [%s]",
153			opts.format, strings.Join(vegeta.TargetFormats, ", "))
154	}
155
156	if !opts.lazy {
157		targets, err := vegeta.ReadAllTargets(tr)
158		if err != nil {
159			return err
160		}
161		tr = vegeta.NewStaticTargeter(targets...)
162	}
163
164	out, err := file(opts.outputf, true)
165	if err != nil {
166		return fmt.Errorf("error opening %s: %s", opts.outputf, err)
167	}
168	defer out.Close()
169
170	tlsc, err := tlsConfig(opts.insecure, opts.certf, opts.keyf, opts.rootCerts)
171	if err != nil {
172		return err
173	}
174
175	atk := vegeta.NewAttacker(
176		vegeta.Redirects(opts.redirects),
177		vegeta.Timeout(opts.timeout),
178		vegeta.LocalAddr(*opts.laddr.IPAddr),
179		vegeta.TLSConfig(tlsc),
180		vegeta.Workers(opts.workers),
181		vegeta.MaxWorkers(opts.maxWorkers),
182		vegeta.KeepAlive(opts.keepalive),
183		vegeta.Connections(opts.connections),
184		vegeta.MaxConnections(opts.maxConnections),
185		vegeta.HTTP2(opts.http2),
186		vegeta.H2C(opts.h2c),
187		vegeta.MaxBody(opts.maxBody),
188		vegeta.UnixSocket(opts.unixSocket),
189		vegeta.ProxyHeader(proxyHdr),
190		vegeta.ChunkedBody(opts.chunked),
191	)
192
193	res := atk.Attack(tr, opts.rate, opts.duration, opts.name)
194	enc := vegeta.NewEncoder(out)
195	sig := make(chan os.Signal, 1)
196	signal.Notify(sig, os.Interrupt)
197
198	for {
199		select {
200		case <-sig:
201			atk.Stop()
202			return nil
203		case r, ok := <-res:
204			if !ok {
205				return nil
206			}
207			if err = enc.Encode(r); err != nil {
208				return err
209			}
210		}
211	}
212}
213
214// tlsConfig builds a *tls.Config from the given options.
215func tlsConfig(insecure bool, certf, keyf string, rootCerts []string) (*tls.Config, error) {
216	var err error
217	files := map[string][]byte{}
218	filenames := append([]string{certf, keyf}, rootCerts...)
219	for _, f := range filenames {
220		if f != "" {
221			if files[f], err = ioutil.ReadFile(f); err != nil {
222				return nil, err
223			}
224		}
225	}
226
227	c := tls.Config{InsecureSkipVerify: insecure}
228	if cert, ok := files[certf]; ok {
229		key, ok := files[keyf]
230		if !ok {
231			key = cert
232		}
233
234		certificate, err := tls.X509KeyPair(cert, key)
235		if err != nil {
236			return nil, err
237		}
238
239		c.Certificates = append(c.Certificates, certificate)
240		c.BuildNameToCertificate()
241	}
242
243	if len(rootCerts) > 0 {
244		c.RootCAs = x509.NewCertPool()
245		for _, f := range rootCerts {
246			if !c.RootCAs.AppendCertsFromPEM(files[f]) {
247				return nil, errBadCert
248			}
249		}
250	}
251
252	return &c, nil
253}
254