package checks

import (
	"crypto/tls"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	osexec "os/exec"
	"sync"
	"syscall"
	"time"

	"github.com/armon/circbuf"
	"github.com/hashicorp/consul/agent/exec"
	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/lib"
	"github.com/hashicorp/consul/types"
	"github.com/hashicorp/go-cleanhttp"
)

const (
	// MinInterval is the minimal interval between
	// two checks. Do not allow an interval below this value,
	// otherwise we risk fork bombing a system.
	MinInterval = time.Second

	// BufSize is the maximum size of the captured
	// check output. It prevents an enormous buffer
	// from being captured.
	BufSize = 4 * 1024 // 4KB

	// UserAgent is the value of the User-Agent header
	// for HTTP health checks.
	UserAgent = "Consul Health Check"
)

// RPC is an interface that an RPC client must implement. This is a helper
// interface that is implemented by the agent delegate for checks that need
// to make RPC calls.
type RPC interface {
	RPC(method string, args interface{}, reply interface{}) error
}

// CheckNotifier is the interface used by checks to notify when a check
// has a status update. The update should be idempotent.
type CheckNotifier interface {
	UpdateCheck(checkID types.CheckID, status, output string)
}
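
// logNotifier is an illustrative sketch of a CheckNotifier implementation
// and is not used by the agent itself: it simply logs every status update.
// Real notifiers should tolerate repeated updates with the same status,
// since the check types report on every interval.
type logNotifier struct {
	logger *log.Logger
}

// UpdateCheck satisfies CheckNotifier by logging the new status and output.
func (n *logNotifier) UpdateCheck(checkID types.CheckID, status, output string) {
	n.logger.Printf("[INFO] example: Check %q is now %s: %s", checkID, status, output)
}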

// CheckMonitor is used to periodically invoke a script to
// determine the health of a given check. It is compatible with
// Nagios plugins and expects the output in the same format.
type CheckMonitor struct {
	Notify     CheckNotifier
	CheckID    types.CheckID
	Script     string
	ScriptArgs []string
	Interval   time.Duration
	Timeout    time.Duration
	Logger     *log.Logger

	stop     bool
	stopCh   chan struct{}
	stopLock sync.Mutex
}
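
// newExampleMonitor is an illustrative sketch and is not used by the agent:
// it shows how a CheckMonitor is typically wired up. The check ID, script
// path, and durations below are hypothetical. Populate the fields, then
// call Start; the script runs every Interval until Stop is called.
func newExampleMonitor(notify CheckNotifier, logger *log.Logger) *CheckMonitor {
	m := &CheckMonitor{
		Notify:     notify,
		CheckID:    types.CheckID("example-disk"),
		ScriptArgs: []string{"/usr/local/bin/check_disk", "-w", "80", "-c", "90"},
		Interval:   30 * time.Second,
		Timeout:    10 * time.Second,
		Logger:     logger,
	}
	m.Start()
	return m
}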

// Start is used to start a check monitor.
// The monitor runs until Stop is called.
func (c *CheckMonitor) Start() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	c.stop = false
	c.stopCh = make(chan struct{})
	go c.run()
}

// Stop is used to stop a check monitor.
func (c *CheckMonitor) Stop() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	if !c.stop {
		c.stop = true
		close(c.stopCh)
	}
}

// run is invoked by a goroutine to run until Stop() is called
func (c *CheckMonitor) run() {
	// Get the randomized initial pause time
	initialPauseTime := lib.RandomStagger(c.Interval)
	next := time.After(initialPauseTime)
	for {
		select {
		case <-next:
			c.check()
			next = time.After(c.Interval)
		case <-c.stopCh:
			return
		}
	}
}

// check is invoked periodically to perform the script check
func (c *CheckMonitor) check() {
	// Create the command
	var cmd *osexec.Cmd
	var err error
	if len(c.ScriptArgs) > 0 {
		cmd, err = exec.Subprocess(c.ScriptArgs)
	} else {
		cmd, err = exec.Script(c.Script)
	}
	if err != nil {
		c.Logger.Printf("[ERR] agent: Check %q failed to setup: %s", c.CheckID, err)
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, err.Error())
		return
	}

	// Collect the output
	output, _ := circbuf.NewBuffer(BufSize)
	cmd.Stdout = output
	cmd.Stderr = output
	exec.SetSysProcAttr(cmd)

	truncateAndLogOutput := func() string {
		outputStr := string(output.Bytes())
		if output.TotalWritten() > output.Size() {
			outputStr = fmt.Sprintf("Captured %d of %d bytes\n...\n%s",
				output.Size(), output.TotalWritten(), outputStr)
		}
		c.Logger.Printf("[TRACE] agent: Check %q output: %s", c.CheckID, outputStr)
		return outputStr
	}

	// Start the check
	if err := cmd.Start(); err != nil {
		c.Logger.Printf("[ERR] agent: Check %q failed to invoke: %s", c.CheckID, err)
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, err.Error())
		return
	}

	// Wait for the check to complete
	waitCh := make(chan error, 1)
	go func() {
		waitCh <- cmd.Wait()
	}()

	timeout := 30 * time.Second
	if c.Timeout > 0 {
		timeout = c.Timeout
	}
	select {
	case <-time.After(timeout):
		if err := exec.KillCommandSubtree(cmd); err != nil {
			c.Logger.Printf("[WARN] agent: Check %q failed to kill after timeout: %s", c.CheckID, err)
		}

		msg := fmt.Sprintf("Timed out (%s) running check", timeout.String())
		c.Logger.Printf("[WARN] agent: Check %q: %s", c.CheckID, msg)

		outputStr := truncateAndLogOutput()
		if len(outputStr) > 0 {
			msg += "\n\n" + outputStr
		}
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, msg)

		// Now wait for the process to exit so we never start another
		// instance concurrently.
		<-waitCh
		return

	case err = <-waitCh:
		// The process returned before the timeout, proceed normally
	}

	// Check if the check passed
	outputStr := truncateAndLogOutput()
	if err == nil {
		c.Logger.Printf("[DEBUG] agent: Check %q is passing", c.CheckID)
		c.Notify.UpdateCheck(c.CheckID, api.HealthPassing, outputStr)
		return
	}

	// If the exit code is 1, set the check to warning
	exitErr, ok := err.(*osexec.ExitError)
	if ok {
		if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
			code := status.ExitStatus()
			if code == 1 {
				c.Logger.Printf("[WARN] agent: Check %q is now warning", c.CheckID)
				c.Notify.UpdateCheck(c.CheckID, api.HealthWarning, outputStr)
				return
			}
		}
	}

	// Set the health as critical
	c.Logger.Printf("[WARN] agent: Check %q is now critical", c.CheckID)
	c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, outputStr)
}

// CheckTTL is used to apply a TTL to a check's status. Clients set the
// status of the check, but if the TTL expires before the status is
// renewed, the check status is automatically set to critical.
type CheckTTL struct {
	Notify  CheckNotifier
	CheckID types.CheckID
	TTL     time.Duration
	Logger  *log.Logger

	timer *time.Timer

	lastOutput     string
	lastOutputLock sync.RWMutex

	stop     bool
	stopCh   chan struct{}
	stopLock sync.Mutex
}
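
// newExampleTTL is an illustrative sketch and is not used by the agent:
// the check ID, TTL, and output below are hypothetical. It shows the
// CheckTTL life cycle: Start arms the timer, SetStatus both reports a
// status and renews the TTL, and a missed renewal flips the check to
// critical.
func newExampleTTL(notify CheckNotifier, logger *log.Logger) *CheckTTL {
	t := &CheckTTL{
		Notify:  notify,
		CheckID: types.CheckID("example-ttl"),
		TTL:     15 * time.Second,
		Logger:  logger,
	}
	t.Start()
	t.SetStatus(api.HealthPassing, "initial heartbeat")
	return t
}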

// Start is used to start a check TTL. It runs until Stop is called.
func (c *CheckTTL) Start() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	c.stop = false
	c.stopCh = make(chan struct{})
	c.timer = time.NewTimer(c.TTL)
	go c.run()
}

// Stop is used to stop a check TTL.
func (c *CheckTTL) Stop() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	if !c.stop {
		c.timer.Stop()
		c.stop = true
		close(c.stopCh)
	}
}

// run is used to handle TTL expiration and to update the check status
func (c *CheckTTL) run() {
	for {
		select {
		case <-c.timer.C:
			c.Logger.Printf("[WARN] agent: Check %q missed TTL, is now critical",
				c.CheckID)
			c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, c.getExpiredOutput())

		case <-c.stopCh:
			return
		}
	}
}

// getExpiredOutput formats the output for the case when the TTL is expired.
func (c *CheckTTL) getExpiredOutput() string {
	c.lastOutputLock.RLock()
	defer c.lastOutputLock.RUnlock()

	const prefix = "TTL expired"
	if c.lastOutput == "" {
		return prefix
	}

	return fmt.Sprintf("%s (last output before timeout follows): %s", prefix, c.lastOutput)
}

// SetStatus is used to update the status of the check and to renew the
// TTL. If the TTL has expired, it is restarted.
func (c *CheckTTL) SetStatus(status, output string) {
	c.Logger.Printf("[DEBUG] agent: Check %q status is now %s", c.CheckID, status)
	c.Notify.UpdateCheck(c.CheckID, status, output)

	// Store the last output so we can retain it if the TTL expires.
	c.lastOutputLock.Lock()
	c.lastOutput = output
	c.lastOutputLock.Unlock()

	c.timer.Reset(c.TTL)
}

// CheckHTTP is used to periodically make an HTTP request to
// determine the health of a given check.
// The check is passing if the response code is 2XX.
// The check is warning if the response code is 429.
// The check is critical if the response code is anything else
// or if the request returns an error.
type CheckHTTP struct {
	Notify          CheckNotifier
	CheckID         types.CheckID
	HTTP            string
	Header          map[string][]string
	Method          string
	Interval        time.Duration
	Timeout         time.Duration
	Logger          *log.Logger
	TLSClientConfig *tls.Config

	httpClient *http.Client
	stop       bool
	stopCh     chan struct{}
	stopLock   sync.Mutex
}
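
// newExampleHTTP is an illustrative sketch and is not used by the agent:
// the URL, header, and durations below are hypothetical. It shows how a
// CheckHTTP is configured. Leaving httpClient nil lets Start build a
// client with keep-alives disabled and a timeout derived from Timeout
// and Interval as described above.
func newExampleHTTP(notify CheckNotifier, logger *log.Logger) *CheckHTTP {
	h := &CheckHTTP{
		Notify:   notify,
		CheckID:  types.CheckID("example-http"),
		HTTP:     "https://localhost:8443/health",
		Header:   map[string][]string{"Authorization": {"Bearer dev-token"}},
		Method:   "GET",
		Interval: 10 * time.Second,
		Timeout:  5 * time.Second,
		Logger:   logger,
	}
	h.Start()
	return h
}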

// Start is used to start an HTTP check.
// The check runs until Stop is called.
func (c *CheckHTTP) Start() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()

	if c.httpClient == nil {
		// Create the transport. We disable HTTP keep-alives to prevent
		// failing checks due to the keepalive interval.
		trans := cleanhttp.DefaultTransport()
		trans.DisableKeepAlives = true

		// Take on the supplied TLS client config.
		trans.TLSClientConfig = c.TLSClientConfig

		// Create the HTTP client.
		c.httpClient = &http.Client{
			Timeout:   10 * time.Second,
			Transport: trans,
		}

		// For long (>10s) interval checks the http timeout is 10s, otherwise the
		// timeout is the interval. This means that a check *should* return
		// before the next check begins.
		if c.Timeout > 0 && c.Timeout < c.Interval {
			c.httpClient.Timeout = c.Timeout
		} else if c.Interval < 10*time.Second {
			c.httpClient.Timeout = c.Interval
		}
	}

	c.stop = false
	c.stopCh = make(chan struct{})
	go c.run()
}

// Stop is used to stop an HTTP check.
func (c *CheckHTTP) Stop() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	if !c.stop {
		c.stop = true
		close(c.stopCh)
	}
}

// run is invoked by a goroutine to run until Stop() is called
func (c *CheckHTTP) run() {
	// Get the randomized initial pause time
	initialPauseTime := lib.RandomStagger(c.Interval)
	next := time.After(initialPauseTime)
	for {
		select {
		case <-next:
			c.check()
			next = time.After(c.Interval)
		case <-c.stopCh:
			return
		}
	}
}

// check is invoked periodically to perform the HTTP check
func (c *CheckHTTP) check() {
	method := c.Method
	if method == "" {
		method = "GET"
	}

	req, err := http.NewRequest(method, c.HTTP, nil)
	if err != nil {
		c.Logger.Printf("[WARN] agent: Check %q HTTP request failed: %s", c.CheckID, err)
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, err.Error())
		return
	}

	req.Header = http.Header(c.Header)

	// A nil header map can happen during testing but not in production.
	if req.Header == nil {
		req.Header = make(http.Header)
	}

	if host := req.Header.Get("Host"); host != "" {
		req.Host = host
	}

	if req.Header.Get("User-Agent") == "" {
		req.Header.Set("User-Agent", UserAgent)
	}
	if req.Header.Get("Accept") == "" {
		req.Header.Set("Accept", "text/plain, text/*, */*")
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		c.Logger.Printf("[WARN] agent: Check %q HTTP request failed: %s", c.CheckID, err)
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, err.Error())
		return
	}
	defer resp.Body.Close()

	// Read the response into a circular buffer to limit the size
	output, _ := circbuf.NewBuffer(BufSize)
	if _, err := io.Copy(output, resp.Body); err != nil {
		c.Logger.Printf("[WARN] agent: Check %q error while reading body: %s", c.CheckID, err)
	}

	// Format the response body
	result := fmt.Sprintf("HTTP %s %s: %s Output: %s", method, c.HTTP, resp.Status, output.String())

	if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
		// PASSING (2xx)
		c.Logger.Printf("[DEBUG] agent: Check %q is passing", c.CheckID)
		c.Notify.UpdateCheck(c.CheckID, api.HealthPassing, result)

	} else if resp.StatusCode == 429 {
		// WARNING
		// 429 Too Many Requests (RFC 6585)
		// The user has sent too many requests in a given amount of time.
		c.Logger.Printf("[WARN] agent: Check %q is now warning", c.CheckID)
		c.Notify.UpdateCheck(c.CheckID, api.HealthWarning, result)

	} else {
		// CRITICAL
		c.Logger.Printf("[WARN] agent: Check %q is now critical", c.CheckID)
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, result)
	}
}

// CheckTCP is used to periodically make a TCP connection to
// determine the health of a given check.
// The check is passing if the connection succeeds.
// The check is critical if the connection returns an error.
type CheckTCP struct {
	Notify   CheckNotifier
	CheckID  types.CheckID
	TCP      string
	Interval time.Duration
	Timeout  time.Duration
	Logger   *log.Logger

	dialer   *net.Dialer
	stop     bool
	stopCh   chan struct{}
	stopLock sync.Mutex
}
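
// newExampleTCP is an illustrative sketch and is not used by the agent:
// the check ID, address, and durations below are hypothetical. The check
// passes whenever a TCP connection to TCP can be established within the
// dialer timeout.
func newExampleTCP(notify CheckNotifier, logger *log.Logger) *CheckTCP {
	t := &CheckTCP{
		Notify:   notify,
		CheckID:  types.CheckID("example-tcp"),
		TCP:      "127.0.0.1:5432",
		Interval: 10 * time.Second,
		Timeout:  3 * time.Second,
		Logger:   logger,
	}
	t.Start()
	return t
}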

// Start is used to start a TCP check.
// The check runs until Stop is called.
func (c *CheckTCP) Start() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()

	if c.dialer == nil {
		// Create the socket dialer
		c.dialer = &net.Dialer{DualStack: true}

		// For long (>10s) interval checks the socket timeout is 10s, otherwise
		// the timeout is the interval. This means that a check *should* return
		// before the next check begins.
		if c.Timeout > 0 && c.Timeout < c.Interval {
			c.dialer.Timeout = c.Timeout
		} else if c.Interval < 10*time.Second {
			c.dialer.Timeout = c.Interval
		}
	}

	c.stop = false
	c.stopCh = make(chan struct{})
	go c.run()
}

// Stop is used to stop a TCP check.
func (c *CheckTCP) Stop() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	if !c.stop {
		c.stop = true
		close(c.stopCh)
	}
}

// run is invoked by a goroutine to run until Stop() is called
func (c *CheckTCP) run() {
	// Get the randomized initial pause time
	initialPauseTime := lib.RandomStagger(c.Interval)
	next := time.After(initialPauseTime)
	for {
		select {
		case <-next:
			c.check()
			next = time.After(c.Interval)
		case <-c.stopCh:
			return
		}
	}
}

// check is invoked periodically to perform the TCP check
func (c *CheckTCP) check() {
	conn, err := c.dialer.Dial(`tcp`, c.TCP)
	if err != nil {
		c.Logger.Printf("[WARN] agent: Check %q socket connection failed: %s", c.CheckID, err)
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, err.Error())
		return
	}
	conn.Close()
	c.Logger.Printf("[DEBUG] agent: Check %q is passing", c.CheckID)
	c.Notify.UpdateCheck(c.CheckID, api.HealthPassing, fmt.Sprintf("TCP connect %s: Success", c.TCP))
}

// CheckDocker is used to periodically invoke a script to
// determine the health of an application running inside a
// Docker container. We assume that the script is compatible
// with Nagios plugins and expects the output in the same format.
type CheckDocker struct {
	Notify            CheckNotifier
	CheckID           types.CheckID
	Script            string
	ScriptArgs        []string
	DockerContainerID string
	Shell             string
	Interval          time.Duration
	Logger            *log.Logger
	Client            *DockerClient

	stop chan struct{}
}
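
// newExampleDocker is an illustrative sketch and is not used by the agent:
// the check ID, container ID, and script path below are hypothetical, and
// the DockerClient is assumed to be constructed elsewhere. It shows how a
// CheckDocker is configured to exec a script inside a running container.
func newExampleDocker(notify CheckNotifier, client *DockerClient, logger *log.Logger) *CheckDocker {
	d := &CheckDocker{
		Notify:            notify,
		CheckID:           types.CheckID("example-docker"),
		Script:            "/usr/local/bin/health.sh",
		DockerContainerID: "f972c95ebf0e",
		Shell:             "/bin/sh",
		Interval:          30 * time.Second,
		Logger:            logger,
		Client:            client,
	}
	d.Start()
	return d
}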

func (c *CheckDocker) Start() {
	if c.stop != nil {
		panic("Docker check already started")
	}

	if c.Logger == nil {
		c.Logger = log.New(ioutil.Discard, "", 0)
	}

	if c.Shell == "" {
		c.Shell = os.Getenv("SHELL")
		if c.Shell == "" {
			c.Shell = "/bin/sh"
		}
	}
	c.stop = make(chan struct{})
	go c.run()
}

func (c *CheckDocker) Stop() {
	if c.stop == nil {
		panic("Stop called before start")
	}
	close(c.stop)
}

func (c *CheckDocker) run() {
	defer c.Client.Close()
	firstWait := lib.RandomStagger(c.Interval)
	next := time.After(firstWait)
	for {
		select {
		case <-next:
			c.check()
			next = time.After(c.Interval)
		case <-c.stop:
			return
		}
	}
}

func (c *CheckDocker) check() {
	var out string
	status, b, err := c.doCheck()
	if err != nil {
		c.Logger.Printf("[DEBUG] agent: Check %q: %s", c.CheckID, err)
		out = err.Error()
	} else {
		// out is already limited in size since the Docker client hands us
		// a bounded buffer, so we don't need to truncate it here, just
		// note that it was truncated.
		out = string(b.Bytes())
		if int(b.TotalWritten()) > len(out) {
			out = fmt.Sprintf("Captured %d of %d bytes\n...\n%s", len(out), b.TotalWritten(), out)
		}
		c.Logger.Printf("[TRACE] agent: Check %q output: %s", c.CheckID, out)
	}

	if status == api.HealthCritical {
		c.Logger.Printf("[WARN] agent: Check %q is now critical", c.CheckID)
	}

	c.Notify.UpdateCheck(c.CheckID, status, out)
}

func (c *CheckDocker) doCheck() (string, *circbuf.Buffer, error) {
	var cmd []string
	if len(c.ScriptArgs) > 0 {
		cmd = c.ScriptArgs
	} else {
		cmd = []string{c.Shell, "-c", c.Script}
	}

	execID, err := c.Client.CreateExec(c.DockerContainerID, cmd)
	if err != nil {
		return api.HealthCritical, nil, err
	}

	buf, err := c.Client.StartExec(c.DockerContainerID, execID)
	if err != nil {
		return api.HealthCritical, nil, err
	}

	exitCode, err := c.Client.InspectExec(c.DockerContainerID, execID)
	if err != nil {
		return api.HealthCritical, nil, err
	}

	switch exitCode {
	case 0:
		return api.HealthPassing, buf, nil
	case 1:
		c.Logger.Printf("[DEBUG] agent: Check %q failed with exit code: %d", c.CheckID, exitCode)
		return api.HealthWarning, buf, nil
	default:
		c.Logger.Printf("[DEBUG] agent: Check %q failed with exit code: %d", c.CheckID, exitCode)
		return api.HealthCritical, buf, nil
	}
}

// CheckGRPC is used to periodically send requests to a gRPC server
// application that implements the gRPC health-checking protocol.
// The check is passing if the returned status is SERVING.
// The check is critical if the connection fails or the returned status
// is not SERVING.
type CheckGRPC struct {
	Notify          CheckNotifier
	CheckID         types.CheckID
	GRPC            string
	Interval        time.Duration
	Timeout         time.Duration
	TLSClientConfig *tls.Config
	Logger          *log.Logger

	probe    *GrpcHealthProbe
	stop     bool
	stopCh   chan struct{}
	stopLock sync.Mutex
}
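
// newExampleGRPC is an illustrative sketch and is not used by the agent:
// the check ID, target address, and durations below are hypothetical.
// Start builds the gRPC health probe internally, so only the target,
// interval, and an optional TLS config need to be supplied.
func newExampleGRPC(notify CheckNotifier, logger *log.Logger) *CheckGRPC {
	g := &CheckGRPC{
		Notify:   notify,
		CheckID:  types.CheckID("example-grpc"),
		GRPC:     "127.0.0.1:8502/my.package.Service",
		Interval: 10 * time.Second,
		Timeout:  5 * time.Second,
		Logger:   logger,
	}
	g.Start()
	return g
}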

func (c *CheckGRPC) Start() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	timeout := 10 * time.Second
	if c.Timeout > 0 {
		timeout = c.Timeout
	}
	c.probe = NewGrpcHealthProbe(c.GRPC, timeout, c.TLSClientConfig)
	c.stop = false
	c.stopCh = make(chan struct{})
	go c.run()
}

func (c *CheckGRPC) run() {
	// Get the randomized initial pause time
	initialPauseTime := lib.RandomStagger(c.Interval)
	next := time.After(initialPauseTime)
	for {
		select {
		case <-next:
			c.check()
			next = time.After(c.Interval)
		case <-c.stopCh:
			return
		}
	}
}

func (c *CheckGRPC) check() {
	err := c.probe.Check()
	if err != nil {
		c.Logger.Printf("[DEBUG] agent: Check %q failed: %s", c.CheckID, err.Error())
		c.Notify.UpdateCheck(c.CheckID, api.HealthCritical, err.Error())
	} else {
		c.Logger.Printf("[DEBUG] agent: Check %q is passing", c.CheckID)
		c.Notify.UpdateCheck(c.CheckID, api.HealthPassing, fmt.Sprintf("gRPC check %s: success", c.GRPC))
	}
}

func (c *CheckGRPC) Stop() {
	c.stopLock.Lock()
	defer c.stopLock.Unlock()
	if !c.stop {
		c.stop = true
		close(c.stopCh)
	}
}