1// Package check standardizes /health and /ready endpoints.
2// This allows you to easily know when your server is ready and healthy.
3package check
4
5import (
6	"context"
7	"encoding/json"
8	"fmt"
9	"net/http"
10	"sort"
11	"sync"
12)
13
// Status reports the outcome of a single check or of a whole group of checks.
type Status string

// The Status values produced by this package.
const (
	// StatusPass indicates a specific check has passed.
	StatusPass Status = "pass"
	// StatusFail indicates a specific check has failed.
	StatusFail Status = "fail"
)

// DefaultCheckName is the name of the default checker.
// NOTE(review): not referenced anywhere in this file; confirm it is still used.
const DefaultCheckName = "internal"
26
27// Check wraps a map of service names to status checkers.
28type Check struct {
29	healthChecks   []Checker
30	readyChecks    []Checker
31	healthOverride override
32	readyOverride  override
33
34	passthroughHandler http.Handler
35}
36
37// Checker indicates a service whose health can be checked.
38type Checker interface {
39	Check(ctx context.Context) Response
40}
41
42// NewCheck returns a Health with a default checker.
43func NewCheck() *Check {
44	ch := &Check{}
45	ch.healthOverride.disable()
46	ch.readyOverride.disable()
47	return ch
48}
49
50// AddHealthCheck adds the check to the list of ready checks.
51// If c is a NamedChecker, the name will be added.
52func (c *Check) AddHealthCheck(check Checker) {
53	if nc, ok := check.(NamedChecker); ok {
54		c.healthChecks = append(c.healthChecks, Named(nc.CheckName(), nc))
55	} else {
56		c.healthChecks = append(c.healthChecks, check)
57	}
58}
59
60// AddReadyCheck adds the check to the list of ready checks.
61// If c is a NamedChecker, the name will be added.
62func (c *Check) AddReadyCheck(check Checker) {
63	if nc, ok := check.(NamedChecker); ok {
64		c.readyChecks = append(c.readyChecks, Named(nc.CheckName(), nc))
65	} else {
66		c.readyChecks = append(c.readyChecks, check)
67	}
68}
69
70// CheckHealth evaluates c's set of health checks and returns a populated Response.
71func (c *Check) CheckHealth(ctx context.Context) Response {
72	response := Response{
73		Name:   "Health",
74		Status: StatusPass,
75		Checks: make(Responses, len(c.healthChecks)),
76	}
77
78	status, overriding := c.healthOverride.get()
79	if overriding {
80		response.Status = status
81		overrideResponse := Response{
82			Name:    "manual-override",
83			Message: "health manually overridden",
84		}
85		response.Checks = append(response.Checks, overrideResponse)
86	}
87	for i, ch := range c.healthChecks {
88		resp := ch.Check(ctx)
89		if resp.Status != StatusPass && !overriding {
90			response.Status = resp.Status
91		}
92		response.Checks[i] = resp
93	}
94	sort.Sort(response.Checks)
95	return response
96}
97
98// CheckReady evaluates c's set of ready checks and returns a populated Response.
99func (c *Check) CheckReady(ctx context.Context) Response {
100	response := Response{
101		Name:   "Ready",
102		Status: StatusPass,
103		Checks: make(Responses, len(c.readyChecks)),
104	}
105
106	status, overriding := c.readyOverride.get()
107	if overriding {
108		response.Status = status
109		overrideResponse := Response{
110			Name:    "manual-override",
111			Message: "ready manually overridden",
112		}
113		response.Checks = append(response.Checks, overrideResponse)
114	}
115	for i, c := range c.readyChecks {
116		resp := c.Check(ctx)
117		if resp.Status != StatusPass && !overriding {
118			response.Status = resp.Status
119		}
120		response.Checks[i] = resp
121	}
122	sort.Sort(response.Checks)
123	return response
124}
125
126// SetPassthrough allows you to set a handler to use if the request is not a ready or health check.
127// This can be useful if you intend to use this as a middleware.
128func (c *Check) SetPassthrough(h http.Handler) {
129	c.passthroughHandler = h
130}
131
132// ServeHTTP serves /ready and /health requests with the respective checks.
133func (c *Check) ServeHTTP(w http.ResponseWriter, r *http.Request) {
134	const (
135		pathReady  = "/ready"
136		pathHealth = "/health"
137		queryForce = "force"
138	)
139
140	path := r.URL.Path
141
142	// Allow requests not intended for checks to pass through.
143	if path != pathReady && path != pathHealth {
144		if c.passthroughHandler != nil {
145			c.passthroughHandler.ServeHTTP(w, r)
146			return
147		}
148
149		// We can't handle this request.
150		w.WriteHeader(http.StatusNotFound)
151		return
152	}
153
154	ctx := r.Context()
155	query := r.URL.Query()
156
157	switch path {
158	case pathReady:
159		switch query.Get(queryForce) {
160		case "true":
161			switch query.Get("ready") {
162			case "true":
163				c.readyOverride.enable(StatusPass)
164			case "false":
165				c.readyOverride.enable(StatusFail)
166			}
167		case "false":
168			c.readyOverride.disable()
169		}
170		writeResponse(w, c.CheckReady(ctx))
171	case pathHealth:
172		switch query.Get(queryForce) {
173		case "true":
174			switch query.Get("healthy") {
175			case "true":
176				c.healthOverride.enable(StatusPass)
177			case "false":
178				c.healthOverride.enable(StatusFail)
179			}
180		case "false":
181			c.healthOverride.disable()
182		}
183		writeResponse(w, c.CheckHealth(ctx))
184	}
185}
186
187// writeResponse writes a Response to the wire as JSON. The HTTP status code
188// accompanying the payload is the primary means for signaling the status of the
189// checks. The possible status codes are:
190//
191// - 200 OK: All checks pass.
192// - 503 Service Unavailable: Some checks are failing.
193// - 500 Internal Server Error: There was a problem serializing the Response.
194func writeResponse(w http.ResponseWriter, resp Response) {
195	status := http.StatusOK
196	if resp.Status == StatusFail {
197		status = http.StatusServiceUnavailable
198	}
199
200	msg, err := json.MarshalIndent(resp, "", "  ")
201	if err != nil {
202		msg = []byte(`{"message": "error marshaling response", "status": "fail"}`)
203		status = http.StatusInternalServerError
204	}
205	w.WriteHeader(status)
206	fmt.Fprintln(w, string(msg))
207}
208
209// override is a manual override for an entire group of checks.
210type override struct {
211	mtx    sync.Mutex
212	status Status
213	active bool
214}
215
216// get returns the Status of an override as well as whether or not an override
217// is currently active.
218func (m *override) get() (Status, bool) {
219	m.mtx.Lock()
220	defer m.mtx.Unlock()
221	return m.status, m.active
222}
223
224// disable disables the override.
225func (m *override) disable() {
226	m.mtx.Lock()
227	m.active = false
228	m.status = StatusFail
229	m.mtx.Unlock()
230}
231
232// enable turns on the override and establishes a specific Status for which to.
233func (m *override) enable(s Status) {
234	m.mtx.Lock()
235	m.active = true
236	m.status = s
237	m.mtx.Unlock()
238}
239