1package vegeta 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "math" 10 "net" 11 "net/http" 12 "net/url" 13 "strconv" 14 "sync" 15 "time" 16 17 "golang.org/x/net/http2" 18) 19 20// Attacker is an attack executor which wraps an http.Client 21type Attacker struct { 22 dialer *net.Dialer 23 client http.Client 24 stopch chan struct{} 25 workers uint64 26 maxWorkers uint64 27 maxBody int64 28 redirects int 29 seqmu sync.Mutex 30 seq uint64 31 began time.Time 32 chunked bool 33} 34 35const ( 36 // DefaultRedirects is the default number of times an Attacker follows 37 // redirects. 38 DefaultRedirects = 10 39 // DefaultTimeout is the default amount of time an Attacker waits for a request 40 // before it times out. 41 DefaultTimeout = 30 * time.Second 42 // DefaultConnections is the default amount of max open idle connections per 43 // target host. 44 DefaultConnections = 10000 45 // DefaultMaxConnections is the default amount of connections per target 46 // host. 47 DefaultMaxConnections = 0 48 // DefaultWorkers is the default initial number of workers used to carry an attack. 49 DefaultWorkers = 10 50 // DefaultMaxWorkers is the default maximum number of workers used to carry an attack. 51 DefaultMaxWorkers = math.MaxUint64 52 // DefaultMaxBody is the default max number of bytes to be read from response bodies. 53 // Defaults to no limit. 54 DefaultMaxBody = int64(-1) 55 // NoFollow is the value when redirects are not followed but marked successful 56 NoFollow = -1 57) 58 59var ( 60 // DefaultLocalAddr is the default local IP address an Attacker uses. 61 DefaultLocalAddr = net.IPAddr{IP: net.IPv4zero} 62 // DefaultTLSConfig is the default tls.Config an Attacker uses. 63 DefaultTLSConfig = &tls.Config{InsecureSkipVerify: true} 64) 65 66// NewAttacker returns a new Attacker with default options which are overridden 67// by the optionally provided opts. 
func NewAttacker(opts ...func(*Attacker)) *Attacker {
	a := &Attacker{
		stopch:     make(chan struct{}),
		workers:    DefaultWorkers,
		maxWorkers: DefaultMaxWorkers,
		maxBody:    DefaultMaxBody,
		began:      time.Now(),
	}

	a.dialer = &net.Dialer{
		LocalAddr: &net.TCPAddr{IP: DefaultLocalAddr.IP, Zone: DefaultLocalAddr.Zone},
		KeepAlive: 30 * time.Second,
	}

	a.client = http.Client{
		Timeout: DefaultTimeout,
		Transport: &http.Transport{
			Proxy: http.ProxyFromEnvironment,
			Dial:  a.dialer.Dial,
			// Give every Attacker its own copy of the package default: options
			// such as HTTP2(true) hand this config to http2.ConfigureTransport,
			// which edits its NextProtos. Sharing the global pointer would leak
			// those edits into DefaultTLSConfig and race across Attackers.
			TLSClientConfig:     DefaultTLSConfig.Clone(),
			MaxIdleConnsPerHost: DefaultConnections,
			MaxConnsPerHost:     DefaultMaxConnections,
		},
	}

	for _, opt := range opts {
		opt(a)
	}

	return a
}

// Workers returns a functional option which sets the initial number of workers
// an Attacker uses to hit its targets. More workers may be spawned dynamically
// to sustain the requested rate in the face of slow responses and errors.
func Workers(n uint64) func(*Attacker) {
	return func(a *Attacker) { a.workers = n }
}

// MaxWorkers returns a functional option which sets the maximum number of
// workers an Attacker can use to hit its targets.
func MaxWorkers(n uint64) func(*Attacker) {
	return func(a *Attacker) { a.maxWorkers = n }
}

// Connections returns a functional option which sets the number of maximum
// idle open connections per target host.
//
// The option has no effect when the client's Transport is not an
// *http.Transport (for example after Client or H2C replaced it), instead of
// panicking on the type assertion.
func Connections(n int) func(*Attacker) {
	return func(a *Attacker) {
		if tr, ok := a.client.Transport.(*http.Transport); ok {
			tr.MaxIdleConnsPerHost = n
		}
	}
}

// MaxConnections returns a functional option which sets the number of maximum
// connections per target host. Zero means no limit, as for
// http.Transport.MaxConnsPerHost.
func MaxConnections(n int) func(*Attacker) {
	return func(a *Attacker) {
		// Only an *http.Transport has MaxConnsPerHost; leave any other
		// transport untouched rather than panicking.
		if tr, ok := a.client.Transport.(*http.Transport); ok {
			tr.MaxConnsPerHost = n
		}
	}
}

// ChunkedBody returns a functional option which makes the attacker send the
// body of each request with the chunked transfer encoding.
func ChunkedBody(b bool) func(*Attacker) {
	return func(a *Attacker) { a.chunked = b }
}

// Redirects returns a functional option which sets the maximum number of
// redirects an Attacker will follow. With NoFollow the redirect response
// itself is returned instead of being followed; otherwise a request that
// needs more than n redirects fails with an error.
func Redirects(n int) func(*Attacker) {
	return func(a *Attacker) {
		a.redirects = n
		a.client.CheckRedirect = func(_ *http.Request, via []*http.Request) error {
			switch {
			case n == NoFollow:
				return http.ErrUseLastResponse
			case n < len(via):
				return fmt.Errorf("stopped after %d redirects", n)
			default:
				return nil
			}
		}
	}
}

// Proxy returns a functional option which sets the `Proxy` field on the
// http.Client's Transport. It has no effect when the Transport is not an
// *http.Transport.
func Proxy(proxy func(*http.Request) (*url.URL, error)) func(*Attacker) {
	return func(a *Attacker) {
		if tr, ok := a.client.Transport.(*http.Transport); ok {
			tr.Proxy = proxy
		}
	}
}

// Timeout returns a functional option which sets the maximum amount of time
// an Attacker will wait for a request to be responded to and completely read.
func Timeout(d time.Duration) func(*Attacker) {
	return func(a *Attacker) {
		a.client.Timeout = d
	}
}

// LocalAddr returns a functional option which sets the local address an
// Attacker will use with its requests. Only the IP and Zone of addr are used.
//
// The shared dialer is always updated; the transport's Dial hook is only
// re-pointed when the Transport is an *http.Transport.
func LocalAddr(addr net.IPAddr) func(*Attacker) {
	return func(a *Attacker) {
		a.dialer.LocalAddr = &net.TCPAddr{IP: addr.IP, Zone: addr.Zone}
		if tr, ok := a.client.Transport.(*http.Transport); ok {
			tr.Dial = a.dialer.Dial
		}
	}
}

// KeepAlive returns a functional option which toggles KeepAlive
// connections on the dialer and transport.
func KeepAlive(keepalive bool) func(*Attacker) {
	return func(a *Attacker) {
		if !keepalive {
			// A zero net.Dialer.KeepAlive means "use the default probe
			// interval"; only a negative value disables TCP keep-alives.
			a.dialer.KeepAlive = -1
		}

		tr, ok := a.client.Transport.(*http.Transport)
		if !ok {
			// Not an *http.Transport: nothing further to configure.
			return
		}
		tr.DisableKeepAlives = !keepalive
		if !keepalive {
			tr.Dial = a.dialer.Dial
		}
	}
}

// TLSConfig returns a functional option which sets the *tls.Config for a
// Attacker to use with its requests. It has no effect when the client's
// Transport is not an *http.Transport.
func TLSConfig(c *tls.Config) func(*Attacker) {
	return func(a *Attacker) {
		if tr, ok := a.client.Transport.(*http.Transport); ok {
			tr.TLSClientConfig = c
		}
	}
}

// HTTP2 returns a functional option which enables or disables HTTP/2 support
// on requests performed by an Attacker. It has no effect when the client's
// Transport is not an *http.Transport.
func HTTP2(enabled bool) func(*Attacker) {
	return func(a *Attacker) {
		tr, ok := a.client.Transport.(*http.Transport)
		if !ok {
			return
		}
		if enabled {
			// ConfigureTransport is documented to fail only when tr is already
			// HTTP/2-enabled, in which case there is nothing left to do.
			_ = http2.ConfigureTransport(tr)
			return
		}
		// An empty (non-nil) TLSNextProto map turns off HTTP/2 upgrades.
		tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
	}
}

// H2C returns a functional option which enables H2C (HTTP/2 over cleartext)
// support on requests performed by an Attacker. It has no effect when
// disabled or when the client's Transport is not an *http.Transport.
func H2C(enabled bool) func(*Attacker) {
	return func(a *Attacker) {
		if !enabled {
			return
		}
		tr, ok := a.client.Transport.(*http.Transport)
		if !ok {
			return
		}
		a.client.Transport = &http2.Transport{
			AllowHTTP: true,
			// h2c never negotiates TLS: this hook dials a plain connection
			// through the replaced transport's Dial.
			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
				return tr.Dial(network, addr)
			},
		}
	}
}

// MaxBody returns a functional option which limits the max number of bytes
// read from response bodies. Set to -1 to disable any limits.
233func MaxBody(n int64) func(*Attacker) { 234 return func(a *Attacker) { a.maxBody = n } 235} 236 237// UnixSocket changes the dialer for the attacker to use the specified unix socket file 238func UnixSocket(socket string) func(*Attacker) { 239 return func(a *Attacker) { 240 if tr, ok := a.client.Transport.(*http.Transport); socket != "" && ok { 241 tr.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { 242 return net.Dial("unix", socket) 243 } 244 } 245 } 246} 247 248// Client returns a functional option that allows you to bring your own http.Client 249func Client(c *http.Client) func(*Attacker) { 250 return func(a *Attacker) { a.client = *c } 251} 252 253// ProxyHeader returns a functional option that allows you to add your own 254// Proxy CONNECT headers 255func ProxyHeader(h http.Header) func(*Attacker) { 256 return func(a *Attacker) { 257 if tr, ok := a.client.Transport.(*http.Transport); ok { 258 tr.ProxyConnectHeader = h 259 } 260 } 261} 262 263// Attack reads its Targets from the passed Targeter and attacks them at 264// the rate specified by the Pacer. When the duration is zero the attack 265// runs until Stop is called. Results are sent to the returned channel as soon 266// as they arrive and will have their Attack field set to the given name. 
func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <-chan *Result {
	var wg sync.WaitGroup

	results := make(chan *Result)
	ticks := make(chan struct{})

	// spawn starts one more worker that turns ticks into hits.
	spawn := func() {
		wg.Add(1)
		go a.attack(tr, name, &wg, ticks, results)
	}

	// Never start more workers up front than the configured ceiling.
	active := a.workers
	if active > a.maxWorkers {
		active = a.maxWorkers
	}
	for i := uint64(0); i < active; i++ {
		spawn()
	}

	go func() {
		// Shut down in order: stop handing out ticks, wait for in-flight
		// hits to deliver their results, then close the results channel.
		defer func() {
			close(ticks)
			wg.Wait()
			close(results)
		}()

		start := time.Now()
		var sent uint64
		for {
			elapsed := time.Since(start)
			if du > 0 && elapsed > du {
				return
			}

			wait, done := p.Pace(elapsed, sent)
			if done {
				return
			}
			time.Sleep(wait)

			// Below the ceiling, try a non-blocking hand-off first. If every
			// worker is busy, grow the pool by one and fall through to the
			// blocking hand-off below.
			if active < a.maxWorkers {
				select {
				case ticks <- struct{}{}:
					sent++
					continue
				case <-a.stopch:
					return
				default:
					active++
					spawn()
				}
			}

			select {
			case ticks <- struct{}{}:
				sent++
			case <-a.stopch:
				return
			}
		}
	}()

	return results
}

// Stop signals the current attack to stop. Calls made after the attack has
// already been stopped have no effect.
func (a *Attacker) Stop() {
	// hit calls Stop from many worker goroutines at once when the Targeter
	// fails, so the check-then-close below must be atomic. Otherwise two
	// callers can both observe stopch open, and the second close panics.
	// seqmu is only ever held for short, non-blocking sections and is never
	// held by hit while it calls Stop, so it can serialize this safely.
	a.seqmu.Lock()
	defer a.seqmu.Unlock()

	select {
	case <-a.stopch:
		// Already stopped.
	default:
		close(a.stopch)
	}
}

// attack is the body of one worker goroutine. It performs one hit per tick
// received and sends each Result on results until ticks is closed, then
// marks itself done on workers.
func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
	defer workers.Done()
	for range ticks {
		results <- a.hit(tr, name)
	}
}

// hit fetches the next Target from tr, sends it through the client and
// records the outcome. It always returns a Result: failures are reported
// through its Error field rather than dropped.
func (a *Attacker) hit(tr Targeter, name string) *Result {
	var (
		res = Result{Attack: name}
		tgt Target
		err error
	)

	// Seq and Timestamp are taken under the same lock, so ordering results by
	// Seq also orders them by Timestamp. Deriving the timestamp from began via
	// the monotonic elapsed time keeps wall-clock adjustments made during the
	// attack from skewing it.
	a.seqmu.Lock()
	res.Timestamp = a.began.Add(time.Since(a.began))
	res.Seq = a.seq
	a.seq++
	a.seqmu.Unlock()

	// Runs on every return path. A pointer to res is returned, so these
	// late writes are visible to the caller.
	defer func() {
		res.Latency = time.Since(res.Timestamp)
		if err != nil {
			res.Error = err.Error()
		}
	}()

	if err = tr(&tgt); err != nil {
		// Any Targeter error is treated as fatal for the whole attack.
		a.Stop()
		return &res
	}

	res.Method = tgt.Method
	res.URL = tgt.URL

	req, err := tgt.Request()
	if err != nil {
		return &res
	}

	if name != "" {
		req.Header.Set("X-Vegeta-Attack", name)
	}

	req.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10))

	if a.chunked {
		req.TransferEncoding = append(req.TransferEncoding, "chunked")
	}

	r, err := a.client.Do(req)
	if err != nil {
		return &res
	}
	defer r.Body.Close()

	body := io.Reader(r.Body)
	if a.maxBody >= 0 {
		body = io.LimitReader(r.Body, a.maxBody)
	}

	if res.Body, err = ioutil.ReadAll(body); err != nil {
		return &res
	}
	// Drain whatever the limit left unread so the connection can be reused.
	if _, err = io.Copy(ioutil.Discard, r.Body); err != nil {
		return &res
	}

	res.BytesIn = uint64(len(res.Body))

	if req.ContentLength != -1 {
		res.BytesOut = uint64(req.ContentLength)
	}

	// Statuses outside [200, 400) are recorded as errors.
	if res.Code = uint16(r.StatusCode); res.Code < 200 || res.Code >= 400 {
		res.Error = r.Status
	}

	res.Headers = r.Header

	return &res
}