1package taskrunner
2
3import (
4	"context"
5	"fmt"
6	"sync"
7	"time"
8
9	"github.com/hashicorp/consul/api"
10	log "github.com/hashicorp/go-hclog"
11	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
12	tinterfaces "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
13	"github.com/hashicorp/nomad/client/consul"
14	"github.com/hashicorp/nomad/client/taskenv"
15	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
16	"github.com/hashicorp/nomad/nomad/structs"
17)
18
19var _ interfaces.TaskPoststartHook = &scriptCheckHook{}
20var _ interfaces.TaskUpdateHook = &scriptCheckHook{}
21var _ interfaces.TaskStopHook = &scriptCheckHook{}
22
// defaultShutdownWait is the default upper bound on how long Stop waits for
// all script checks to exit during shutdown.
const defaultShutdownWait = 1 * time.Minute
25
26type scriptCheckHookConfig struct {
27	alloc        *structs.Allocation
28	task         *structs.Task
29	consul       consul.ConsulServiceAPI
30	logger       log.Logger
31	shutdownWait time.Duration
32}
33
34// scriptCheckHook implements a task runner hook for running script
35// checks in the context of a task
36type scriptCheckHook struct {
37	consul          consul.ConsulServiceAPI
38	consulNamespace string
39	alloc           *structs.Allocation
40	task            *structs.Task
41	logger          log.Logger
42	shutdownWait    time.Duration // max time to wait for scripts to shutdown
43	shutdownCh      chan struct{} // closed when all scripts should shutdown
44
45	// The following fields can be changed by Update()
46	driverExec tinterfaces.ScriptExecutor
47	taskEnv    *taskenv.TaskEnv
48
49	// These maintain state and are populated by Poststart() or Update()
50	scripts        map[string]*scriptCheck
51	runningScripts map[string]*taskletHandle
52
53	// Since Update() may be called concurrently with any other hook all
54	// hook methods must be fully serialized
55	mu sync.Mutex
56}
57
58// newScriptCheckHook returns a hook without any scriptChecks.
59// They will get created only once their task environment is ready
60// in Poststart() or Update()
61func newScriptCheckHook(c scriptCheckHookConfig) *scriptCheckHook {
62	h := &scriptCheckHook{
63		consul:          c.consul,
64		consulNamespace: c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup).Consul.GetNamespace(),
65		alloc:           c.alloc,
66		task:            c.task,
67		scripts:         make(map[string]*scriptCheck),
68		runningScripts:  make(map[string]*taskletHandle),
69		shutdownWait:    defaultShutdownWait,
70		shutdownCh:      make(chan struct{}),
71	}
72
73	if c.shutdownWait != 0 {
74		h.shutdownWait = c.shutdownWait // override for testing
75	}
76	h.logger = c.logger.Named(h.Name())
77	return h
78}
79
80func (h *scriptCheckHook) Name() string {
81	return "script_checks"
82}
83
84// Prestart implements interfaces.TaskPrestartHook. It stores the
85// initial structs.Task
86func (h *scriptCheckHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, _ *interfaces.TaskPrestartResponse) error {
87	h.mu.Lock()
88	defer h.mu.Unlock()
89	h.task = req.Task
90	return nil
91}
92
93// PostStart implements interfaces.TaskPoststartHook. It creates new
94// script checks with the current task context (driver and env), and
95// starts up the scripts.
96func (h *scriptCheckHook) Poststart(ctx context.Context, req *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
97	h.mu.Lock()
98	defer h.mu.Unlock()
99
100	if req.DriverExec == nil {
101		h.logger.Debug("driver doesn't support script checks")
102		return nil
103	}
104	h.driverExec = req.DriverExec
105	h.taskEnv = req.TaskEnv
106
107	return h.upsertChecks()
108}
109
110// Updated implements interfaces.TaskUpdateHook. It creates new
111// script checks with the current task context (driver and env and possibly
112// new structs.Task), and starts up the scripts.
113func (h *scriptCheckHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error {
114	h.mu.Lock()
115	defer h.mu.Unlock()
116
117	task := req.Alloc.LookupTask(h.task.Name)
118	if task == nil {
119		return fmt.Errorf("task %q not found in updated alloc", h.task.Name)
120	}
121	h.alloc = req.Alloc
122	h.task = task
123	h.taskEnv = req.TaskEnv
124
125	return h.upsertChecks()
126}
127
128func (h *scriptCheckHook) upsertChecks() error {
129	// Create new script checks struct with new task context
130	oldScriptChecks := h.scripts
131	h.scripts = h.newScriptChecks()
132
133	// Run new or replacement scripts
134	for id, script := range h.scripts {
135		// If it's already running, cancel and replace
136		if oldScript, running := h.runningScripts[id]; running {
137			oldScript.cancel()
138		}
139		// Start and store the handle
140		h.runningScripts[id] = script.run()
141	}
142
143	// Cancel scripts we no longer want
144	for id := range oldScriptChecks {
145		if _, ok := h.scripts[id]; !ok {
146			if oldScript, running := h.runningScripts[id]; running {
147				oldScript.cancel()
148			}
149		}
150	}
151	return nil
152}
153
154// Stop implements interfaces.TaskStopHook and blocks waiting for running
155// scripts to finish (or for the shutdownWait timeout to expire).
156func (h *scriptCheckHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
157	h.mu.Lock()
158	defer h.mu.Unlock()
159	close(h.shutdownCh)
160	deadline := time.After(h.shutdownWait)
161	err := fmt.Errorf("timed out waiting for script checks to exit")
162	for _, script := range h.runningScripts {
163		select {
164		case <-script.wait():
165		case <-ctx.Done():
166			// the caller is passing the background context, so
167			// we should never really see this outside of testing
168		case <-deadline:
169			// at this point the Consul client has been cleaned
170			// up so we don't want to hang onto this.
171			return err
172		}
173	}
174	return nil
175}
176
177func (h *scriptCheckHook) newScriptChecks() map[string]*scriptCheck {
178	scriptChecks := make(map[string]*scriptCheck)
179	interpolatedTaskServices := taskenv.InterpolateServices(h.taskEnv, h.task.Services)
180	for _, service := range interpolatedTaskServices {
181		for _, check := range service.Checks {
182			if check.Type != structs.ServiceCheckScript {
183				continue
184			}
185			serviceID := agentconsul.MakeAllocServiceID(
186				h.alloc.ID, h.task.Name, service)
187			sc := newScriptCheck(&scriptCheckConfig{
188				consulNamespace: h.consulNamespace,
189				allocID:         h.alloc.ID,
190				taskName:        h.task.Name,
191				check:           check,
192				serviceID:       serviceID,
193				ttlUpdater:      h.consul,
194				driverExec:      h.driverExec,
195				taskEnv:         h.taskEnv,
196				logger:          h.logger,
197				shutdownCh:      h.shutdownCh,
198			})
199			if sc != nil {
200				scriptChecks[sc.id] = sc
201			}
202		}
203	}
204
205	// Walk back through the task group to see if there are script checks
206	// associated with the task. If so, we'll create scriptCheck tasklets
207	// for them. The group-level service and any check restart behaviors it
208	// needs are entirely encapsulated within the group service hook which
209	// watches Consul for status changes.
210	//
211	// The script check is associated with a group task if the service.task or
212	// service.check.task matches the task name. The service.check.task takes
213	// precedence.
214	tg := h.alloc.Job.LookupTaskGroup(h.alloc.TaskGroup)
215	interpolatedGroupServices := taskenv.InterpolateServices(h.taskEnv, tg.Services)
216	for _, service := range interpolatedGroupServices {
217		for _, check := range service.Checks {
218			if check.Type != structs.ServiceCheckScript {
219				continue
220			}
221			if !h.associated(h.task.Name, service.TaskName, check.TaskName) {
222				continue
223			}
224			groupTaskName := "group-" + tg.Name
225			serviceID := agentconsul.MakeAllocServiceID(
226				h.alloc.ID, groupTaskName, service)
227			sc := newScriptCheck(&scriptCheckConfig{
228				consulNamespace: h.consulNamespace,
229				allocID:         h.alloc.ID,
230				taskName:        groupTaskName,
231				check:           check,
232				serviceID:       serviceID,
233				ttlUpdater:      h.consul,
234				driverExec:      h.driverExec,
235				taskEnv:         h.taskEnv,
236				logger:          h.logger,
237				shutdownCh:      h.shutdownCh,
238				isGroup:         true,
239			})
240			if sc != nil {
241				scriptChecks[sc.id] = sc
242			}
243		}
244	}
245	return scriptChecks
246}
247
248// associated returns true if the script check is associated with the task. This
249// would be the case if the check.task is the same as task, or if the service.task
250// is the same as the task _and_ check.task is not configured (i.e. the check
251// inherits the task of the service).
252func (*scriptCheckHook) associated(task, serviceTask, checkTask string) bool {
253	if checkTask == task {
254		return true
255	}
256	if serviceTask == task && checkTask == "" {
257		return true
258	}
259	return false
260}
261
// TTLUpdater is the slice of Consul agent functionality a script check needs
// in order to heartbeat: reporting the latest output and status for a check
// ID within a namespace.
type TTLUpdater interface {
	UpdateTTL(id string, namespace string, output string, status string) error
}
267
268// scriptCheck runs script checks via a interfaces.ScriptExecutor and updates the
269// appropriate check's TTL when the script succeeds.
270type scriptCheck struct {
271	id              string
272	consulNamespace string
273	ttlUpdater      TTLUpdater
274	check           *structs.ServiceCheck
275	lastCheckOk     bool // true if the last check was ok; otherwise false
276	tasklet
277}
278
279// scriptCheckConfig is a parameter struct for newScriptCheck
280type scriptCheckConfig struct {
281	allocID         string
282	taskName        string
283	serviceID       string
284	consulNamespace string
285	check           *structs.ServiceCheck
286	ttlUpdater      TTLUpdater
287	driverExec      tinterfaces.ScriptExecutor
288	taskEnv         *taskenv.TaskEnv
289	logger          log.Logger
290	shutdownCh      chan struct{}
291	isGroup         bool
292}
293
294// newScriptCheck constructs a scriptCheck. we're only going to
295// configure the immutable fields of scriptCheck here, with the
296// rest being configured during the Poststart hook so that we have
297// the rest of the task execution environment
298func newScriptCheck(config *scriptCheckConfig) *scriptCheck {
299
300	// Guard against not having a valid taskEnv. This can be the case if the
301	// PreKilling or Exited hook is run before Poststart.
302	if config.taskEnv == nil || config.driverExec == nil {
303		return nil
304	}
305
306	orig := config.check
307	sc := &scriptCheck{
308		ttlUpdater:  config.ttlUpdater,
309		check:       config.check.Copy(),
310		lastCheckOk: true, // start logging on first failure
311	}
312
313	// we can't use the promoted fields of tasklet in the struct literal
314	sc.Command = config.taskEnv.ReplaceEnv(config.check.Command)
315	sc.Args = config.taskEnv.ParseAndReplace(config.check.Args)
316	sc.Interval = config.check.Interval
317	sc.Timeout = config.check.Timeout
318	sc.exec = config.driverExec
319	sc.callback = newScriptCheckCallback(sc)
320	sc.logger = config.logger
321	sc.shutdownCh = config.shutdownCh
322	sc.check.Command = sc.Command
323	sc.check.Args = sc.Args
324
325	if config.isGroup {
326		// group services don't have access to a task environment
327		// at creation, so their checks get registered before the
328		// check can be interpolated here. if we don't use the
329		// original checkID, they can't be updated.
330		sc.id = agentconsul.MakeCheckID(config.serviceID, orig)
331	} else {
332		sc.id = agentconsul.MakeCheckID(config.serviceID, sc.check)
333	}
334	sc.consulNamespace = config.consulNamespace
335	return sc
336}
337
338// Copy does a *shallow* copy of script checks.
339func (sc *scriptCheck) Copy() *scriptCheck {
340	newSc := sc
341	return newSc
342}
343
344// closes over the script check and returns the taskletCallback for
345// when the script check executes.
346func newScriptCheckCallback(s *scriptCheck) taskletCallback {
347
348	return func(ctx context.Context, params execResult) {
349		output := params.output
350		code := params.code
351		err := params.err
352
353		state := api.HealthCritical
354		switch code {
355		case 0:
356			state = api.HealthPassing
357		case 1:
358			state = api.HealthWarning
359		}
360
361		var outputMsg string
362		if err != nil {
363			state = api.HealthCritical
364			outputMsg = err.Error()
365		} else {
366			outputMsg = string(output)
367		}
368
369		// heartbeat the check to Consul
370		err = s.updateTTL(ctx, outputMsg, state)
371		select {
372		case <-ctx.Done():
373			// check has been removed; don't report errors
374			return
375		default:
376		}
377
378		if err != nil {
379			if s.lastCheckOk {
380				s.lastCheckOk = false
381				s.logger.Warn("updating check failed", "error", err)
382			} else {
383				s.logger.Debug("updating check still failing", "error", err)
384			}
385
386		} else if !s.lastCheckOk {
387			// Succeeded for the first time or after failing; log
388			s.lastCheckOk = true
389			s.logger.Info("updating check succeeded")
390		}
391	}
392}
393
// Backoff parameters for retrying a failed TTL update: the delay starts at
// the baseline and retrying stops once the next delay would exceed the limit.
const (
	updateTTLBackoffBaseline = time.Second
	updateTTLBackoffLimit    = 3 * time.Second
)
398
399// updateTTL updates the state to Consul, performing an exponential backoff
400// in the case where the check isn't registered in Consul to avoid a race between
401// service registration and the first check.
402func (s *scriptCheck) updateTTL(ctx context.Context, msg, state string) error {
403	for attempts := 0; ; attempts++ {
404		err := s.ttlUpdater.UpdateTTL(s.id, s.consulNamespace, msg, state)
405		if err == nil {
406			return nil
407		}
408
409		// Handle the retry case
410		backoff := (1 << (2 * uint64(attempts))) * updateTTLBackoffBaseline
411		if backoff > updateTTLBackoffLimit {
412			return err
413		}
414
415		// Wait till retrying
416		select {
417		case <-ctx.Done():
418			return err
419		case <-time.After(backoff):
420		}
421	}
422}
423