1package taskrunner
2
3import (
4	"context"
5	"fmt"
6	"sync"
7	"time"
8
9	"github.com/hashicorp/consul/api"
10	log "github.com/hashicorp/go-hclog"
11	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
12	tinterfaces "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
13	"github.com/hashicorp/nomad/client/consul"
14	"github.com/hashicorp/nomad/client/taskenv"
15	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
16	"github.com/hashicorp/nomad/nomad/structs"
17)
18
19var _ interfaces.TaskPoststartHook = &scriptCheckHook{}
20var _ interfaces.TaskUpdateHook = &scriptCheckHook{}
21var _ interfaces.TaskStopHook = &scriptCheckHook{}
22
// defaultShutdownWait is the default upper bound on how long Stop waits for
// all running script checks to exit during shutdown.
const defaultShutdownWait time.Duration = time.Minute
25
26type scriptCheckHookConfig struct {
27	alloc        *structs.Allocation
28	task         *structs.Task
29	consul       consul.ConsulServiceAPI
30	logger       log.Logger
31	shutdownWait time.Duration
32}
33
34// scriptCheckHook implements a task runner hook for running script
35// checks in the context of a task
36type scriptCheckHook struct {
37	consul       consul.ConsulServiceAPI
38	alloc        *structs.Allocation
39	task         *structs.Task
40	logger       log.Logger
41	shutdownWait time.Duration // max time to wait for scripts to shutdown
42	shutdownCh   chan struct{} // closed when all scripts should shutdown
43
44	// The following fields can be changed by Update()
45	driverExec tinterfaces.ScriptExecutor
46	taskEnv    *taskenv.TaskEnv
47
48	// These maintain state and are populated by Poststart() or Update()
49	scripts        map[string]*scriptCheck
50	runningScripts map[string]*taskletHandle
51
52	// Since Update() may be called concurrently with any other hook all
53	// hook methods must be fully serialized
54	mu sync.Mutex
55}
56
57// newScriptCheckHook returns a hook without any scriptChecks.
58// They will get created only once their task environment is ready
59// in Poststart() or Update()
60func newScriptCheckHook(c scriptCheckHookConfig) *scriptCheckHook {
61	h := &scriptCheckHook{
62		consul:         c.consul,
63		alloc:          c.alloc,
64		task:           c.task,
65		scripts:        make(map[string]*scriptCheck),
66		runningScripts: make(map[string]*taskletHandle),
67		shutdownWait:   defaultShutdownWait,
68		shutdownCh:     make(chan struct{}),
69	}
70
71	if c.shutdownWait != 0 {
72		h.shutdownWait = c.shutdownWait // override for testing
73	}
74	h.logger = c.logger.Named(h.Name())
75	return h
76}
77
78func (h *scriptCheckHook) Name() string {
79	return "script_checks"
80}
81
82// Prestart implements interfaces.TaskPrestartHook. It stores the
83// initial structs.Task
84func (h *scriptCheckHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, _ *interfaces.TaskPrestartResponse) error {
85	h.mu.Lock()
86	defer h.mu.Unlock()
87	h.task = req.Task
88	return nil
89}
90
91// PostStart implements interfaces.TaskPoststartHook. It creates new
92// script checks with the current task context (driver and env), and
93// starts up the scripts.
94func (h *scriptCheckHook) Poststart(ctx context.Context, req *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
95	h.mu.Lock()
96	defer h.mu.Unlock()
97
98	if req.DriverExec == nil {
99		h.logger.Debug("driver doesn't support script checks")
100		return nil
101	}
102	h.driverExec = req.DriverExec
103	h.taskEnv = req.TaskEnv
104
105	return h.upsertChecks()
106}
107
108// Updated implements interfaces.TaskUpdateHook. It creates new
109// script checks with the current task context (driver and env and possibly
110// new structs.Task), and starts up the scripts.
111func (h *scriptCheckHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error {
112	h.mu.Lock()
113	defer h.mu.Unlock()
114
115	task := req.Alloc.LookupTask(h.task.Name)
116	if task == nil {
117		return fmt.Errorf("task %q not found in updated alloc", h.task.Name)
118	}
119	h.alloc = req.Alloc
120	h.task = task
121	h.taskEnv = req.TaskEnv
122
123	return h.upsertChecks()
124}
125
126func (h *scriptCheckHook) upsertChecks() error {
127	// Create new script checks struct with new task context
128	oldScriptChecks := h.scripts
129	h.scripts = h.newScriptChecks()
130
131	// Run new or replacement scripts
132	for id, script := range h.scripts {
133		// If it's already running, cancel and replace
134		if oldScript, running := h.runningScripts[id]; running {
135			oldScript.cancel()
136		}
137		// Start and store the handle
138		h.runningScripts[id] = script.run()
139	}
140
141	// Cancel scripts we no longer want
142	for id := range oldScriptChecks {
143		if _, ok := h.scripts[id]; !ok {
144			if oldScript, running := h.runningScripts[id]; running {
145				oldScript.cancel()
146			}
147		}
148	}
149	return nil
150}
151
152// Stop implements interfaces.TaskStopHook and blocks waiting for running
153// scripts to finish (or for the shutdownWait timeout to expire).
154func (h *scriptCheckHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
155	h.mu.Lock()
156	defer h.mu.Unlock()
157	close(h.shutdownCh)
158	deadline := time.After(h.shutdownWait)
159	err := fmt.Errorf("timed out waiting for script checks to exit")
160	for _, script := range h.runningScripts {
161		select {
162		case <-script.wait():
163		case <-ctx.Done():
164			// the caller is passing the background context, so
165			// we should never really see this outside of testing
166		case <-deadline:
167			// at this point the Consul client has been cleaned
168			// up so we don't want to hang onto this.
169			return err
170		}
171	}
172	return nil
173}
174
175func (h *scriptCheckHook) newScriptChecks() map[string]*scriptCheck {
176	scriptChecks := make(map[string]*scriptCheck)
177	for _, service := range h.task.Services {
178		for _, check := range service.Checks {
179			if check.Type != structs.ServiceCheckScript {
180				continue
181			}
182			serviceID := agentconsul.MakeAllocServiceID(
183				h.alloc.ID, h.task.Name, service)
184			sc := newScriptCheck(&scriptCheckConfig{
185				allocID:    h.alloc.ID,
186				taskName:   h.task.Name,
187				check:      check,
188				serviceID:  serviceID,
189				agent:      h.consul,
190				driverExec: h.driverExec,
191				taskEnv:    h.taskEnv,
192				logger:     h.logger,
193				shutdownCh: h.shutdownCh,
194			})
195			if sc != nil {
196				scriptChecks[sc.id] = sc
197			}
198		}
199	}
200
201	// Walk back through the task group to see if there are script checks
202	// associated with the task. If so, we'll create scriptCheck tasklets
203	// for them. The group-level service and any check restart behaviors it
204	// needs are entirely encapsulated within the group service hook which
205	// watches Consul for status changes.
206	tg := h.alloc.Job.LookupTaskGroup(h.alloc.TaskGroup)
207	for _, service := range tg.Services {
208		for _, check := range service.Checks {
209			if check.Type != structs.ServiceCheckScript {
210				continue
211			}
212			if check.TaskName != h.task.Name {
213				continue
214			}
215			groupTaskName := "group-" + tg.Name
216			serviceID := agentconsul.MakeAllocServiceID(
217				h.alloc.ID, groupTaskName, service)
218			sc := newScriptCheck(&scriptCheckConfig{
219				allocID:    h.alloc.ID,
220				taskName:   groupTaskName,
221				check:      check,
222				serviceID:  serviceID,
223				agent:      h.consul,
224				driverExec: h.driverExec,
225				taskEnv:    h.taskEnv,
226				logger:     h.logger,
227				shutdownCh: h.shutdownCh,
228				isGroup:    true,
229			})
230			if sc != nil {
231				scriptChecks[sc.id] = sc
232			}
233		}
234	}
235	return scriptChecks
236}
237
// heartbeater is the single Consul agent operation script checks depend on:
// pushing a check's latest output and status through UpdateTTL.
type heartbeater interface {
	UpdateTTL(checkID, output, status string) error
}
243
244// scriptCheck runs script checks via a interfaces.ScriptExecutor and updates the
245// appropriate check's TTL when the script succeeds.
246type scriptCheck struct {
247	id          string
248	agent       heartbeater
249	check       *structs.ServiceCheck
250	lastCheckOk bool // true if the last check was ok; otherwise false
251	tasklet
252}
253
254// scriptCheckConfig is a parameter struct for newScriptCheck
255type scriptCheckConfig struct {
256	allocID    string
257	taskName   string
258	serviceID  string
259	check      *structs.ServiceCheck
260	agent      heartbeater
261	driverExec tinterfaces.ScriptExecutor
262	taskEnv    *taskenv.TaskEnv
263	logger     log.Logger
264	shutdownCh chan struct{}
265	isGroup    bool
266}
267
// newScriptCheck constructs a scriptCheck for a single service check. Every
// field is set here: the command and args are interpolated with
// config.taskEnv, and so is a copy of the check definition, which is used to
// derive the Consul check ID.
//
// It returns nil when config.taskEnv or config.driverExec is nil, because
// the check can be neither interpolated nor executed without them.
func newScriptCheck(config *scriptCheckConfig) *scriptCheck {

	// Guard against not having a valid taskEnv. This can be the case if the
	// PreKilling or Exited hook is run before Poststart.
	if config.taskEnv == nil || config.driverExec == nil {
		return nil
	}

	// orig keeps the uninterpolated definition; sc.check is a copy that is
	// interpolated field by field below.
	orig := config.check
	sc := &scriptCheck{
		agent:       config.agent,
		check:       config.check.Copy(),
		lastCheckOk: true, // start logging on first failure
	}

	// we can't use the promoted fields of tasklet in the struct literal
	sc.Command = config.taskEnv.ReplaceEnv(config.check.Command)
	sc.Args = config.taskEnv.ParseAndReplace(config.check.Args)
	sc.Interval = config.check.Interval
	sc.Timeout = config.check.Timeout
	sc.exec = config.driverExec
	sc.callback = newScriptCheckCallback(sc)
	sc.logger = config.logger
	sc.shutdownCh = config.shutdownCh

	// the hash of the interior structs.ServiceCheck is used by the
	// Consul client to get the ID to register for the check. So we
	// update it here so that we have the same ID for UpdateTTL.

	// TODO(tgross): this block is similar to one in service_hook
	// and we can pull that out to a function so we know we're
	// interpolating the same everywhere
	sc.check.Name = config.taskEnv.ReplaceEnv(orig.Name)
	sc.check.Type = config.taskEnv.ReplaceEnv(orig.Type)
	sc.check.Command = sc.Command
	sc.check.Args = sc.Args
	sc.check.Path = config.taskEnv.ReplaceEnv(orig.Path)
	sc.check.Protocol = config.taskEnv.ReplaceEnv(orig.Protocol)
	sc.check.PortLabel = config.taskEnv.ReplaceEnv(orig.PortLabel)
	sc.check.InitialStatus = config.taskEnv.ReplaceEnv(orig.InitialStatus)
	sc.check.Method = config.taskEnv.ReplaceEnv(orig.Method)
	sc.check.GRPCService = config.taskEnv.ReplaceEnv(orig.GRPCService)
	// Interpolate both header names and header values into a new map.
	if len(orig.Header) > 0 {
		header := make(map[string][]string, len(orig.Header))
		for k, vs := range orig.Header {
			newVals := make([]string, len(vs))
			for i, v := range vs {
				newVals[i] = config.taskEnv.ReplaceEnv(v)
			}
			header[config.taskEnv.ReplaceEnv(k)] = newVals
		}
		sc.check.Header = header
	}
	if config.isGroup {
		// TODO(tgross):
		// group services don't have access to a task environment
		// at creation, so their checks get registered before the
		// check can be interpolated here. if we don't use the
		// original checkID, they can't be updated.
		sc.id = agentconsul.MakeCheckID(config.serviceID, orig)
	} else {
		// Task services: the ID is derived from the interpolated check.
		sc.id = agentconsul.MakeCheckID(config.serviceID, sc.check)
	}
	return sc
}
337
338// Copy does a *shallow* copy of script checks.
339func (sc *scriptCheck) Copy() *scriptCheck {
340	newSc := sc
341	return newSc
342}
343
344// closes over the script check and returns the taskletCallback for
345// when the script check executes.
346func newScriptCheckCallback(s *scriptCheck) taskletCallback {
347
348	return func(ctx context.Context, params execResult) {
349		output := params.output
350		code := params.code
351		err := params.err
352
353		state := api.HealthCritical
354		switch code {
355		case 0:
356			state = api.HealthPassing
357		case 1:
358			state = api.HealthWarning
359		}
360
361		var outputMsg string
362		if err != nil {
363			state = api.HealthCritical
364			outputMsg = err.Error()
365		} else {
366			outputMsg = string(output)
367		}
368
369		// heartbeat the check to Consul
370		err = s.updateTTL(ctx, outputMsg, state)
371		select {
372		case <-ctx.Done():
373			// check has been removed; don't report errors
374			return
375		default:
376		}
377
378		if err != nil {
379			if s.lastCheckOk {
380				s.lastCheckOk = false
381				s.logger.Warn("updating check failed", "error", err)
382			} else {
383				s.logger.Debug("updating check still failing", "error", err)
384			}
385
386		} else if !s.lastCheckOk {
387			// Succeeded for the first time or after failing; log
388			s.lastCheckOk = true
389			s.logger.Info("updating check succeeded")
390		}
391	}
392}
393
// Retry bounds for scriptCheck.updateTTL. The first wait is
// updateTTLBackoffBaseline. Retrying stops once the next wait would exceed
// updateTTLBackoffLimit.
const (
	updateTTLBackoffBaseline time.Duration = time.Second
	updateTTLBackoffLimit    time.Duration = 3 * time.Second
)
398
399// updateTTL updates the state to Consul, performing an expontential backoff
400// in the case where the check isn't registered in Consul to avoid a race between
401// service registration and the first check.
402func (s *scriptCheck) updateTTL(ctx context.Context, msg, state string) error {
403	for attempts := 0; ; attempts++ {
404		err := s.agent.UpdateTTL(s.id, msg, state)
405		if err == nil {
406			return nil
407		}
408
409		// Handle the retry case
410		backoff := (1 << (2 * uint64(attempts))) * updateTTLBackoffBaseline
411		if backoff > updateTTLBackoffLimit {
412			return err
413		}
414
415		// Wait till retrying
416		select {
417		case <-ctx.Done():
418			return err
419		case <-time.After(backoff):
420		}
421	}
422}
423