1package taskrunner
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io/ioutil"
8	"net/http"
9	"net/http/httptest"
10	"os"
11	"path/filepath"
12	"strings"
13	"testing"
14	"time"
15
16	"github.com/golang/snappy"
17	"github.com/hashicorp/nomad/client/allocdir"
18	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
19	"github.com/hashicorp/nomad/client/config"
20	"github.com/hashicorp/nomad/client/consul"
21	consulapi "github.com/hashicorp/nomad/client/consul"
22	"github.com/hashicorp/nomad/client/devicemanager"
23	"github.com/hashicorp/nomad/client/pluginmanager/drivermanager"
24	cstate "github.com/hashicorp/nomad/client/state"
25	ctestutil "github.com/hashicorp/nomad/client/testutil"
26	"github.com/hashicorp/nomad/client/vaultclient"
27	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
28	mockdriver "github.com/hashicorp/nomad/drivers/mock"
29	"github.com/hashicorp/nomad/drivers/rawexec"
30	"github.com/hashicorp/nomad/helper/testlog"
31	"github.com/hashicorp/nomad/helper/uuid"
32	"github.com/hashicorp/nomad/nomad/mock"
33	"github.com/hashicorp/nomad/nomad/structs"
34	"github.com/hashicorp/nomad/plugins/device"
35	"github.com/hashicorp/nomad/plugins/drivers"
36	"github.com/hashicorp/nomad/testutil"
37	"github.com/kr/pretty"
38	"github.com/stretchr/testify/assert"
39	"github.com/stretchr/testify/require"
40)
41
// MockTaskStateUpdater is a test double used as the Config.StateUpdater in
// these tests. It records state-update notifications on a channel instead of
// driving a real allocation runner.
type MockTaskStateUpdater struct {
	// ch receives a value each time TaskStateUpdated fires; the
	// constructor buffers it with capacity 1 so notifications coalesce
	// rather than block the runner.
	ch chan struct{}
}
45
46func NewMockTaskStateUpdater() *MockTaskStateUpdater {
47	return &MockTaskStateUpdater{
48		ch: make(chan struct{}, 1),
49	}
50}
51
52func (m *MockTaskStateUpdater) TaskStateUpdated() {
53	select {
54	case m.ch <- struct{}{}:
55	default:
56	}
57}
58
59// testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
60// plus a cleanup func.
// testTaskRunnerConfig returns a taskrunner.Config for the given alloc+task
// plus a cleanup func.
//
// It locates the named task within the alloc's job (failing the test if the
// name is missing or duplicated), builds the allocation directory on disk,
// and wires the Config with mock/no-op implementations of the client
// dependencies (Consul, Vault, state DB, device and driver managers). The
// returned cleanup func destroys the alloc dir and then runs the client
// config cleanup.
func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName string) (*Config, func()) {
	logger := testlog.HCLogger(t)
	clientConf, cleanup := config.TestClientConfig(t)

	// Find the task. Every task group is scanned so a duplicate task name
	// anywhere in the job is detected, not just within one group.
	var thisTask *structs.Task
	for _, tg := range alloc.Job.TaskGroups {
		for _, task := range tg.Tasks {
			if task.Name == taskName {
				if thisTask != nil {
					// Run the client-config cleanup before failing so the
					// test does not leak its temp resources.
					cleanup()
					t.Fatalf("multiple tasks named %q; cannot use this helper", taskName)
				}
				thisTask = task
			}
		}
	}
	if thisTask == nil {
		cleanup()
		t.Fatalf("could not find task %q", taskName)
	}

	// Create the alloc dir + task dir
	allocPath := filepath.Join(clientConf.AllocDir, alloc.ID)
	allocDir := allocdir.NewAllocDir(logger, allocPath)
	if err := allocDir.Build(); err != nil {
		cleanup()
		t.Fatalf("error building alloc dir: %v", err)
	}
	taskDir := allocDir.NewTaskDir(taskName)

	// Destroy the alloc dir first, then release the client config.
	trCleanup := func() {
		if err := allocDir.Destroy(); err != nil {
			// Log-only: cleanup failures should not fail the test.
			t.Logf("error destroying alloc dir: %v", err)
		}
		cleanup()
	}

	// Create a closed channel to mock TaskHookCoordinator.startConditionForTask.
	// Closed channel indicates this task is not blocked on prestart hooks.
	closedCh := make(chan struct{})
	close(closedCh)

	conf := &Config{
		Alloc:                alloc,
		ClientConfig:         clientConf,
		Task:                 thisTask,
		TaskDir:              taskDir,
		Logger:               clientConf.Logger,
		Consul:               consulapi.NewMockConsulServiceClient(t, logger),
		ConsulSI:             consulapi.NewMockServiceIdentitiesClient(),
		Vault:                vaultclient.NewMockVaultClient(),
		StateDB:              cstate.NoopDB{},
		StateUpdater:         NewMockTaskStateUpdater(),
		DeviceManager:        devicemanager.NoopMockManager(),
		DriverManager:        drivermanager.TestDriverManager(t),
		ServersContactedCh:   make(chan struct{}),
		StartConditionMetCtx: closedCh,
	}
	return conf, trCleanup
}
122
123// runTestTaskRunner runs a TaskRunner and returns its configuration as well as
124// a cleanup function that ensures the runner is stopped and cleaned up. Tests
125// which need to change the Config *must* use testTaskRunnerConfig instead.
126func runTestTaskRunner(t *testing.T, alloc *structs.Allocation, taskName string) (*TaskRunner, *Config, func()) {
127	config, cleanup := testTaskRunnerConfig(t, alloc, taskName)
128
129	tr, err := NewTaskRunner(config)
130	require.NoError(t, err)
131	go tr.Run()
132
133	return tr, config, func() {
134		tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
135		cleanup()
136	}
137}
138
139func TestTaskRunner_BuildTaskConfig_CPU_Memory(t *testing.T) {
140	t.Parallel()
141
142	cases := []struct {
143		name                  string
144		cpu                   int64
145		memoryMB              int64
146		memoryMaxMB           int64
147		expectedLinuxMemoryMB int64
148	}{
149		{
150			name:                  "plain no max",
151			cpu:                   100,
152			memoryMB:              100,
153			memoryMaxMB:           0,
154			expectedLinuxMemoryMB: 100,
155		},
156		{
157			name:                  "plain with max=reserve",
158			cpu:                   100,
159			memoryMB:              100,
160			memoryMaxMB:           100,
161			expectedLinuxMemoryMB: 100,
162		},
163		{
164			name:                  "plain with max>reserve",
165			cpu:                   100,
166			memoryMB:              100,
167			memoryMaxMB:           200,
168			expectedLinuxMemoryMB: 200,
169		},
170	}
171
172	for _, c := range cases {
173		t.Run(c.name, func(t *testing.T) {
174			alloc := mock.BatchAlloc()
175			alloc.Job.TaskGroups[0].Count = 1
176			task := alloc.Job.TaskGroups[0].Tasks[0]
177			task.Driver = "mock_driver"
178			task.Config = map[string]interface{}{
179				"run_for": "2s",
180			}
181			res := alloc.AllocatedResources.Tasks[task.Name]
182			res.Cpu.CpuShares = c.cpu
183			res.Memory.MemoryMB = c.memoryMB
184			res.Memory.MemoryMaxMB = c.memoryMaxMB
185
186			conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
187			conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
188			defer cleanup()
189
190			// Run the first TaskRunner
191			tr, err := NewTaskRunner(conf)
192			require.NoError(t, err)
193
194			tc := tr.buildTaskConfig()
195			require.Equal(t, c.cpu, tc.Resources.LinuxResources.CPUShares)
196			require.Equal(t, c.expectedLinuxMemoryMB*1024*1024, tc.Resources.LinuxResources.MemoryLimitBytes)
197
198			require.Equal(t, c.cpu, tc.Resources.NomadResources.Cpu.CpuShares)
199			require.Equal(t, c.memoryMB, tc.Resources.NomadResources.Memory.MemoryMB)
200			require.Equal(t, c.memoryMaxMB, tc.Resources.NomadResources.Memory.MemoryMaxMB)
201		})
202	}
203}
204
205// TestTaskRunner_Stop_ExitCode asserts that the exit code is captured on a task, even if it's stopped
206func TestTaskRunner_Stop_ExitCode(t *testing.T) {
207	ctestutil.ExecCompatible(t)
208	t.Parallel()
209
210	alloc := mock.BatchAlloc()
211	alloc.Job.TaskGroups[0].Count = 1
212	task := alloc.Job.TaskGroups[0].Tasks[0]
213	task.KillSignal = "SIGTERM"
214	task.Driver = "raw_exec"
215	task.Config = map[string]interface{}{
216		"command": "/bin/sleep",
217		"args":    []string{"1000"},
218	}
219
220	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
221	defer cleanup()
222
223	// Run the first TaskRunner
224	tr, err := NewTaskRunner(conf)
225	require.NoError(t, err)
226	go tr.Run()
227
228	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
229
230	// Wait for it to be running
231	testWaitForTaskToStart(t, tr)
232
233	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
234	defer cancel()
235
236	err = tr.Kill(ctx, structs.NewTaskEvent("shutdown"))
237	require.NoError(t, err)
238
239	var exitEvent *structs.TaskEvent
240	state := tr.TaskState()
241	for _, e := range state.Events {
242		if e.Type == structs.TaskTerminated {
243			exitEvent = e
244			break
245		}
246	}
247	require.NotNilf(t, exitEvent, "exit event not found: %v", state.Events)
248
249	require.Equal(t, 143, exitEvent.ExitCode)
250	require.Equal(t, 15, exitEvent.Signal)
251
252}
253
254// TestTaskRunner_Restore_Running asserts restoring a running task does not
255// rerun the task.
256func TestTaskRunner_Restore_Running(t *testing.T) {
257	t.Parallel()
258	require := require.New(t)
259
260	alloc := mock.BatchAlloc()
261	alloc.Job.TaskGroups[0].Count = 1
262	task := alloc.Job.TaskGroups[0].Tasks[0]
263	task.Driver = "mock_driver"
264	task.Config = map[string]interface{}{
265		"run_for": "2s",
266	}
267	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
268	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
269	defer cleanup()
270
271	// Run the first TaskRunner
272	origTR, err := NewTaskRunner(conf)
273	require.NoError(err)
274	go origTR.Run()
275	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
276
277	// Wait for it to be running
278	testWaitForTaskToStart(t, origTR)
279
280	// Cause TR to exit without shutting down task
281	origTR.Shutdown()
282
283	// Start a new TaskRunner and make sure it does not rerun the task
284	newTR, err := NewTaskRunner(conf)
285	require.NoError(err)
286
287	// Do the Restore
288	require.NoError(newTR.Restore())
289
290	go newTR.Run()
291	defer newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
292
293	// Wait for new task runner to exit when the process does
294	<-newTR.WaitCh()
295
296	// Assert that the process was only started once
297	started := 0
298	state := newTR.TaskState()
299	require.Equal(structs.TaskStateDead, state.State)
300	for _, ev := range state.Events {
301		if ev.Type == structs.TaskStarted {
302			started++
303		}
304	}
305	assert.Equal(t, 1, started)
306}
307
308// setupRestoreFailureTest starts a service, shuts down the task runner, and
309// kills the task before restarting a new TaskRunner. The new TaskRunner is
310// returned once it is running and waiting in pending along with a cleanup
311// func.
312func setupRestoreFailureTest(t *testing.T, alloc *structs.Allocation) (*TaskRunner, *Config, func()) {
313	t.Parallel()
314
315	task := alloc.Job.TaskGroups[0].Tasks[0]
316	task.Driver = "raw_exec"
317	task.Config = map[string]interface{}{
318		"command": "sleep",
319		"args":    []string{"30"},
320	}
321	conf, cleanup1 := testTaskRunnerConfig(t, alloc, task.Name)
322	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs
323
324	// Run the first TaskRunner
325	origTR, err := NewTaskRunner(conf)
326	require.NoError(t, err)
327	go origTR.Run()
328	cleanup2 := func() {
329		origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
330		cleanup1()
331	}
332
333	// Wait for it to be running
334	testWaitForTaskToStart(t, origTR)
335
336	handle := origTR.getDriverHandle()
337	require.NotNil(t, handle)
338	taskID := handle.taskID
339
340	// Cause TR to exit without shutting down task
341	origTR.Shutdown()
342
343	// Get the driver
344	driverPlugin, err := conf.DriverManager.Dispense(rawexec.PluginID.Name)
345	require.NoError(t, err)
346	rawexecDriver := driverPlugin.(*rawexec.Driver)
347
348	// Assert the task is still running despite TR having exited
349	taskStatus, err := rawexecDriver.InspectTask(taskID)
350	require.NoError(t, err)
351	require.Equal(t, drivers.TaskStateRunning, taskStatus.State)
352
353	// Kill the task so it fails to recover when restore is called
354	require.NoError(t, rawexecDriver.DestroyTask(taskID, true))
355	_, err = rawexecDriver.InspectTask(taskID)
356	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
357
358	// Create a new TaskRunner and Restore the task
359	conf.ServersContactedCh = make(chan struct{})
360	newTR, err := NewTaskRunner(conf)
361	require.NoError(t, err)
362
363	// Assert the TR will wait on servers because reattachment failed
364	require.NoError(t, newTR.Restore())
365	require.True(t, newTR.waitOnServers)
366
367	// Start new TR
368	go newTR.Run()
369	cleanup3 := func() {
370		newTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
371		cleanup2()
372		cleanup1()
373	}
374
375	// Assert task has not been restarted
376	_, err = rawexecDriver.InspectTask(taskID)
377	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())
378	ts := newTR.TaskState()
379	require.Equal(t, structs.TaskStatePending, ts.State)
380
381	return newTR, conf, cleanup3
382}
383
384// TestTaskRunner_Restore_Restart asserts restoring a dead task blocks until
385// MarkAlive is called. #1795
386func TestTaskRunner_Restore_Restart(t *testing.T) {
387	newTR, conf, cleanup := setupRestoreFailureTest(t, mock.Alloc())
388	defer cleanup()
389
390	// Fake contacting the server by closing the chan
391	close(conf.ServersContactedCh)
392
393	testutil.WaitForResult(func() (bool, error) {
394		ts := newTR.TaskState().State
395		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
396	}, func(err error) {
397		require.NoError(t, err)
398	})
399}
400
401// TestTaskRunner_Restore_Kill asserts restoring a dead task blocks until
402// the task is killed. #1795
403func TestTaskRunner_Restore_Kill(t *testing.T) {
404	newTR, _, cleanup := setupRestoreFailureTest(t, mock.Alloc())
405	defer cleanup()
406
407	// Sending the task a terminal update shouldn't kill it or unblock it
408	alloc := newTR.Alloc().Copy()
409	alloc.DesiredStatus = structs.AllocDesiredStatusStop
410	newTR.Update(alloc)
411
412	require.Equal(t, structs.TaskStatePending, newTR.TaskState().State)
413
414	// AllocRunner will immediately kill tasks after sending a terminal
415	// update.
416	newTR.Kill(context.Background(), structs.NewTaskEvent(structs.TaskKilling))
417
418	select {
419	case <-newTR.WaitCh():
420		// It died as expected!
421	case <-time.After(10 * time.Second):
422		require.Fail(t, "timeout waiting for task to die")
423	}
424}
425
426// TestTaskRunner_Restore_Update asserts restoring a dead task blocks until
427// Update is called. #1795
428func TestTaskRunner_Restore_Update(t *testing.T) {
429	newTR, conf, cleanup := setupRestoreFailureTest(t, mock.Alloc())
430	defer cleanup()
431
432	// Fake Client.runAllocs behavior by calling Update then closing chan
433	alloc := newTR.Alloc().Copy()
434	newTR.Update(alloc)
435
436	// Update alone should not unblock the test
437	require.Equal(t, structs.TaskStatePending, newTR.TaskState().State)
438
439	// Fake Client.runAllocs behavior of closing chan after Update
440	close(conf.ServersContactedCh)
441
442	testutil.WaitForResult(func() (bool, error) {
443		ts := newTR.TaskState().State
444		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
445	}, func(err error) {
446		require.NoError(t, err)
447	})
448}
449
450// TestTaskRunner_Restore_System asserts restoring a dead system task does not
451// block.
// TestTaskRunner_Restore_System asserts restoring a dead system task does not
// block.
//
// It runs a raw_exec task, shuts the runner down without stopping the task,
// destroys the task out-of-band so reattachment must fail, then restores a
// fresh runner and verifies it proceeds without waiting on server contact
// because the job is a system job.
func TestTaskRunner_Restore_System(t *testing.T) {
	t.Parallel()

	alloc := mock.Alloc()
	// System job type is what exempts the runner from waiting on servers.
	alloc.Job.Type = structs.JobTypeSystem
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.Driver = "raw_exec"
	task.Config = map[string]interface{}{
		"command": "sleep",
		"args":    []string{"30"},
	}
	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
	defer cleanup()
	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between runs

	// Run the first TaskRunner
	origTR, err := NewTaskRunner(conf)
	require.NoError(t, err)
	go origTR.Run()
	defer origTR.Kill(context.Background(), structs.NewTaskEvent("cleanup"))

	// Wait for it to be running
	testWaitForTaskToStart(t, origTR)

	// Capture the driver task ID before shutting the runner down.
	handle := origTR.getDriverHandle()
	require.NotNil(t, handle)
	taskID := handle.taskID

	// Cause TR to exit without shutting down task
	origTR.Shutdown()

	// Get the driver
	driverPlugin, err := conf.DriverManager.Dispense(rawexec.PluginID.Name)
	require.NoError(t, err)
	rawexecDriver := driverPlugin.(*rawexec.Driver)

	// Assert the task is still running despite TR having exited
	taskStatus, err := rawexecDriver.InspectTask(taskID)
	require.NoError(t, err)
	require.Equal(t, drivers.TaskStateRunning, taskStatus.State)

	// Kill the task so it fails to recover when restore is called
	require.NoError(t, rawexecDriver.DestroyTask(taskID, true))
	_, err = rawexecDriver.InspectTask(taskID)
	require.EqualError(t, err, drivers.ErrTaskNotFound.Error())

	// Create a new TaskRunner and Restore the task
	conf.ServersContactedCh = make(chan struct{})
	newTR, err := NewTaskRunner(conf)
	require.NoError(t, err)

	// Assert the TR will not wait on servers even though reattachment
	// failed because it is a system task.
	require.NoError(t, newTR.Restore())
	require.False(t, newTR.waitOnServers)

	// Nothing should have closed the chan
	select {
	case <-conf.ServersContactedCh:
		require.Fail(t, "serversContactedCh was closed but should not have been")
	default:
	}

	// Without any server contact the restored system task must still
	// restart and reach running.
	testutil.WaitForResult(func() (bool, error) {
		ts := newTR.TaskState().State
		return ts == structs.TaskStateRunning, fmt.Errorf("expected task to be running but found %q", ts)
	}, func(err error) {
		require.NoError(t, err)
	})
}
522
523// TestTaskRunner_TaskEnv_Interpolated asserts driver configurations are
524// interpolated.
525func TestTaskRunner_TaskEnv_Interpolated(t *testing.T) {
526	t.Parallel()
527	require := require.New(t)
528
529	alloc := mock.BatchAlloc()
530	alloc.Job.TaskGroups[0].Meta = map[string]string{
531		"common_user": "somebody",
532	}
533	task := alloc.Job.TaskGroups[0].Tasks[0]
534	task.Meta = map[string]string{
535		"foo": "bar",
536	}
537
538	// Use interpolation from both node attributes and meta vars
539	task.Config = map[string]interface{}{
540		"run_for":       "1ms",
541		"stdout_string": `${node.region} ${NOMAD_META_foo} ${NOMAD_META_common_user}`,
542	}
543
544	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
545	defer cleanup()
546
547	// Wait for task to complete
548	select {
549	case <-tr.WaitCh():
550	case <-time.After(3 * time.Second):
551		require.Fail("timeout waiting for task to exit")
552	}
553
554	// Get the mock driver plugin
555	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
556	require.NoError(err)
557	mockDriver := driverPlugin.(*mockdriver.Driver)
558
559	// Assert its config has been properly interpolated
560	driverCfg, mockCfg := mockDriver.GetTaskConfig()
561	require.NotNil(driverCfg)
562	require.NotNil(mockCfg)
563	assert.Equal(t, "global bar somebody", mockCfg.StdoutString)
564}
565
566// TestTaskRunner_TaskEnv_Chroot asserts chroot drivers use chroot paths and
567// not host paths.
568func TestTaskRunner_TaskEnv_Chroot(t *testing.T) {
569	ctestutil.ExecCompatible(t)
570	t.Parallel()
571	require := require.New(t)
572
573	alloc := mock.BatchAlloc()
574	task := alloc.Job.TaskGroups[0].Tasks[0]
575	task.Driver = "exec"
576	task.Config = map[string]interface{}{
577		"command": "bash",
578		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
579			"echo $NOMAD_TASK_DIR; " +
580			"echo $NOMAD_SECRETS_DIR; " +
581			"echo $PATH; ",
582		},
583	}
584
585	// Expect chroot paths and host $PATH
586	exp := fmt.Sprintf(`/alloc
587/local
588/secrets
589%s
590`, os.Getenv("PATH"))
591
592	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
593	defer cleanup()
594
595	// Remove /sbin and /usr from chroot
596	conf.ClientConfig.ChrootEnv = map[string]string{
597		"/bin":            "/bin",
598		"/etc":            "/etc",
599		"/lib":            "/lib",
600		"/lib32":          "/lib32",
601		"/lib64":          "/lib64",
602		"/run/resolvconf": "/run/resolvconf",
603	}
604
605	tr, err := NewTaskRunner(conf)
606	require.NoError(err)
607	go tr.Run()
608	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
609
610	// Wait for task to exit
611	select {
612	case <-tr.WaitCh():
613	case <-time.After(15 * time.Second):
614		require.Fail("timeout waiting for task to exit")
615	}
616
617	// Read stdout
618	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
619	stdout, err := ioutil.ReadFile(p)
620	require.NoError(err)
621	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
622}
623
624// TestTaskRunner_TaskEnv_Image asserts image drivers use chroot paths and
625// not host paths. Host env vars should also be excluded.
626func TestTaskRunner_TaskEnv_Image(t *testing.T) {
627	ctestutil.DockerCompatible(t)
628	t.Parallel()
629	require := require.New(t)
630
631	alloc := mock.BatchAlloc()
632	task := alloc.Job.TaskGroups[0].Tasks[0]
633	task.Driver = "docker"
634	task.Config = map[string]interface{}{
635		"image":        "redis:3.2-alpine",
636		"network_mode": "none",
637		"command":      "sh",
638		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
639			"echo $NOMAD_TASK_DIR; " +
640			"echo $NOMAD_SECRETS_DIR; " +
641			"echo $PATH",
642		},
643	}
644
645	// Expect chroot paths and image specific PATH
646	exp := `/alloc
647/local
648/secrets
649/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
650`
651
652	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
653	defer cleanup()
654
655	// Wait for task to exit
656	select {
657	case <-tr.WaitCh():
658	case <-time.After(15 * time.Second):
659		require.Fail("timeout waiting for task to exit")
660	}
661
662	// Read stdout
663	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
664	stdout, err := ioutil.ReadFile(p)
665	require.NoError(err)
666	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
667}
668
669// TestTaskRunner_TaskEnv_None asserts raw_exec uses host paths and env vars.
670func TestTaskRunner_TaskEnv_None(t *testing.T) {
671	t.Parallel()
672	require := require.New(t)
673
674	alloc := mock.BatchAlloc()
675	task := alloc.Job.TaskGroups[0].Tasks[0]
676	task.Driver = "raw_exec"
677	task.Config = map[string]interface{}{
678		"command": "sh",
679		"args": []string{"-c", "echo $NOMAD_ALLOC_DIR; " +
680			"echo $NOMAD_TASK_DIR; " +
681			"echo $NOMAD_SECRETS_DIR; " +
682			"echo $PATH",
683		},
684	}
685
686	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
687	defer cleanup()
688
689	// Expect host paths
690	root := filepath.Join(conf.ClientConfig.AllocDir, alloc.ID)
691	taskDir := filepath.Join(root, task.Name)
692	exp := fmt.Sprintf(`%s/alloc
693%s/local
694%s/secrets
695%s
696`, root, taskDir, taskDir, os.Getenv("PATH"))
697
698	// Wait for task to exit
699	select {
700	case <-tr.WaitCh():
701	case <-time.After(15 * time.Second):
702		require.Fail("timeout waiting for task to exit")
703	}
704
705	// Read stdout
706	p := filepath.Join(conf.TaskDir.LogDir, task.Name+".stdout.0")
707	stdout, err := ioutil.ReadFile(p)
708	require.NoError(err)
709	require.Equalf(exp, string(stdout), "expected: %s\n\nactual: %s\n", exp, stdout)
710}
711
712// Test that devices get sent to the driver
713func TestTaskRunner_DevicePropogation(t *testing.T) {
714	t.Parallel()
715	require := require.New(t)
716
717	// Create a mock alloc that has a gpu
718	alloc := mock.BatchAlloc()
719	alloc.Job.TaskGroups[0].Count = 1
720	task := alloc.Job.TaskGroups[0].Tasks[0]
721	task.Driver = "mock_driver"
722	task.Config = map[string]interface{}{
723		"run_for": "100ms",
724	}
725	tRes := alloc.AllocatedResources.Tasks[task.Name]
726	tRes.Devices = append(tRes.Devices, &structs.AllocatedDeviceResource{Type: "mock"})
727
728	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
729	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between task runners
730	defer cleanup()
731
732	// Setup the devicemanager
733	dm, ok := conf.DeviceManager.(*devicemanager.MockManager)
734	require.True(ok)
735
736	dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
737		res := &device.ContainerReservation{
738			Envs: map[string]string{
739				"ABC": "123",
740			},
741			Mounts: []*device.Mount{
742				{
743					ReadOnly: true,
744					TaskPath: "foo",
745					HostPath: "bar",
746				},
747			},
748			Devices: []*device.DeviceSpec{
749				{
750					TaskPath:    "foo",
751					HostPath:    "bar",
752					CgroupPerms: "123",
753				},
754			},
755		}
756		return res, nil
757	}
758
759	// Run the TaskRunner
760	tr, err := NewTaskRunner(conf)
761	require.NoError(err)
762	go tr.Run()
763	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
764
765	// Wait for task to complete
766	select {
767	case <-tr.WaitCh():
768	case <-time.After(3 * time.Second):
769	}
770
771	// Get the mock driver plugin
772	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
773	require.NoError(err)
774	mockDriver := driverPlugin.(*mockdriver.Driver)
775
776	// Assert its config has been properly interpolated
777	driverCfg, _ := mockDriver.GetTaskConfig()
778	require.NotNil(driverCfg)
779	require.Len(driverCfg.Devices, 1)
780	require.Equal(driverCfg.Devices[0].Permissions, "123")
781	require.Len(driverCfg.Mounts, 1)
782	require.Equal(driverCfg.Mounts[0].TaskPath, "foo")
783	require.Contains(driverCfg.Env, "ABC")
784}
785
// mockEnvHook is a test hook that sets an env var and done=true. Tests assert
// it is invoked at most once by checking the called counter.
type mockEnvHook struct {
	// called counts Prestart invocations.
	called int
}
791
792func (*mockEnvHook) Name() string {
793	return "mock_env_hook"
794}
795
796func (h *mockEnvHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
797	h.called++
798
799	resp.Done = true
800	resp.Env = map[string]string{
801		"mock_hook": "1",
802	}
803
804	return nil
805}
806
807// TestTaskRunner_Restore_HookEnv asserts that re-running prestart hooks with
808// hook environments set restores the environment without re-running done
809// hooks.
810func TestTaskRunner_Restore_HookEnv(t *testing.T) {
811	t.Parallel()
812	require := require.New(t)
813
814	alloc := mock.BatchAlloc()
815	task := alloc.Job.TaskGroups[0].Tasks[0]
816	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
817	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
818	defer cleanup()
819
820	tr, err := NewTaskRunner(conf)
821	require.NoError(err)
822
823	// Override the default hooks to only run the mock hook
824	mockHook := &mockEnvHook{}
825	tr.runnerHooks = []interfaces.TaskHook{mockHook}
826
827	// Manually run prestart hooks
828	require.NoError(tr.prestart())
829
830	// Assert env was called
831	require.Equal(1, mockHook.called)
832
833	// Re-running prestart hooks should *not* call done mock hook
834	require.NoError(tr.prestart())
835
836	// Assert env was called
837	require.Equal(1, mockHook.called)
838
839	// Assert the env is still set
840	env := tr.envBuilder.Build().All()
841	require.Contains(env, "mock_hook")
842	require.Equal("1", env["mock_hook"])
843}
844
845// This test asserts that we can recover from an "external" plugin exiting by
846// retrieving a new instance of the driver and recovering the task.
847func TestTaskRunner_RecoverFromDriverExiting(t *testing.T) {
848	t.Parallel()
849	require := require.New(t)
850
851	// Create an allocation using the mock driver that exits simulating the
852	// driver crashing. We can then test that the task runner recovers from this
853	alloc := mock.BatchAlloc()
854	task := alloc.Job.TaskGroups[0].Tasks[0]
855	task.Driver = "mock_driver"
856	task.Config = map[string]interface{}{
857		"plugin_exit_after": "1s",
858		"run_for":           "5s",
859	}
860
861	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
862	conf.StateDB = cstate.NewMemDB(conf.Logger) // "persist" state between prestart calls
863	defer cleanup()
864
865	tr, err := NewTaskRunner(conf)
866	require.NoError(err)
867
868	start := time.Now()
869	go tr.Run()
870	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
871
872	// Wait for the task to be running
873	testWaitForTaskToStart(t, tr)
874
875	// Get the task ID
876	tr.stateLock.RLock()
877	l := tr.localState.TaskHandle
878	require.NotNil(l)
879	require.NotNil(l.Config)
880	require.NotEmpty(l.Config.ID)
881	id := l.Config.ID
882	tr.stateLock.RUnlock()
883
884	// Get the mock driver plugin
885	driverPlugin, err := conf.DriverManager.Dispense(mockdriver.PluginID.Name)
886	require.NoError(err)
887	mockDriver := driverPlugin.(*mockdriver.Driver)
888
889	// Wait for the task to start
890	testutil.WaitForResult(func() (bool, error) {
891		// Get the handle and check that it was recovered
892		handle := mockDriver.GetHandle(id)
893		if handle == nil {
894			return false, fmt.Errorf("nil handle")
895		}
896		if !handle.Recovered {
897			return false, fmt.Errorf("handle not recovered")
898		}
899		return true, nil
900	}, func(err error) {
901		t.Fatal(err.Error())
902	})
903
904	// Wait for task to complete
905	select {
906	case <-tr.WaitCh():
907	case <-time.After(10 * time.Second):
908	}
909
910	// Ensure that we actually let the task complete
911	require.True(time.Now().Sub(start) > 5*time.Second)
912
913	// Check it finished successfully
914	state := tr.TaskState()
915	require.True(state.Successful())
916}
917
918// TestTaskRunner_ShutdownDelay asserts services are removed from Consul
919// ${shutdown_delay} seconds before killing the process.
920func TestTaskRunner_ShutdownDelay(t *testing.T) {
921	t.Parallel()
922
923	alloc := mock.Alloc()
924	task := alloc.Job.TaskGroups[0].Tasks[0]
925	task.Services[0].Tags = []string{"tag1"}
926	task.Services = task.Services[:1] // only need 1 for this test
927	task.Driver = "mock_driver"
928	task.Config = map[string]interface{}{
929		"run_for": "1000s",
930	}
931
932	// No shutdown escape hatch for this delay, so don't set it too high
933	task.ShutdownDelay = 1000 * time.Duration(testutil.TestMultiplier()) * time.Millisecond
934
935	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
936	defer cleanup()
937
938	mockConsul := conf.Consul.(*consul.MockConsulServiceClient)
939
940	// Wait for the task to start
941	testWaitForTaskToStart(t, tr)
942
943	testutil.WaitForResult(func() (bool, error) {
944		ops := mockConsul.GetOps()
945		if n := len(ops); n != 1 {
946			return false, fmt.Errorf("expected 1 consul operation. Found %d", n)
947		}
948		return ops[0].Op == "add", fmt.Errorf("consul operation was not a registration: %#v", ops[0])
949	}, func(err error) {
950		t.Fatalf("err: %v", err)
951	})
952
953	// Asynchronously kill task
954	killSent := time.Now()
955	killed := make(chan struct{})
956	go func() {
957		defer close(killed)
958		assert.NoError(t, tr.Kill(context.Background(), structs.NewTaskEvent("test")))
959	}()
960
961	// Wait for *2* deregistration calls (due to needing to remove both
962	// canary tag variants)
963WAIT:
964	for {
965		ops := mockConsul.GetOps()
966		switch n := len(ops); n {
967		case 1, 2:
968			// Waiting for both deregistration calls
969		case 3:
970			require.Equalf(t, "remove", ops[1].Op, "expected deregistration but found: %#v", ops[1])
971			require.Equalf(t, "remove", ops[2].Op, "expected deregistration but found: %#v", ops[2])
972			break WAIT
973		default:
974			// ?!
975			t.Fatalf("unexpected number of consul operations: %d\n%s", n, pretty.Sprint(ops))
976
977		}
978
979		select {
980		case <-killed:
981			t.Fatal("killed while service still registered")
982		case <-time.After(10 * time.Millisecond):
983		}
984	}
985
986	// Wait for actual exit
987	select {
988	case <-tr.WaitCh():
989	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
990		t.Fatalf("timeout")
991	}
992
993	<-killed
994	killDur := time.Now().Sub(killSent)
995	if killDur < task.ShutdownDelay {
996		t.Fatalf("task killed before shutdown_delay (killed_after: %s; shutdown_delay: %s",
997			killDur, task.ShutdownDelay,
998		)
999	}
1000}
1001
1002// TestTaskRunner_Dispatch_Payload asserts that a dispatch job runs and the
1003// payload was written to disk.
1004func TestTaskRunner_Dispatch_Payload(t *testing.T) {
1005	t.Parallel()
1006
1007	alloc := mock.BatchAlloc()
1008	task := alloc.Job.TaskGroups[0].Tasks[0]
1009	task.Driver = "mock_driver"
1010	task.Config = map[string]interface{}{
1011		"run_for": "1s",
1012	}
1013
1014	fileName := "test"
1015	task.DispatchPayload = &structs.DispatchPayloadConfig{
1016		File: fileName,
1017	}
1018	alloc.Job.ParameterizedJob = &structs.ParameterizedJobConfig{}
1019
1020	// Add a payload (they're snappy encoded bytes)
1021	expected := []byte("hello world")
1022	compressed := snappy.Encode(nil, expected)
1023	alloc.Job.Payload = compressed
1024
1025	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
1026	defer cleanup()
1027
1028	// Wait for it to finish
1029	testutil.WaitForResult(func() (bool, error) {
1030		ts := tr.TaskState()
1031		return ts.State == structs.TaskStateDead, fmt.Errorf("%v", ts.State)
1032	}, func(err error) {
1033		require.NoError(t, err)
1034	})
1035
1036	// Should have exited successfully
1037	ts := tr.TaskState()
1038	require.False(t, ts.Failed)
1039	require.Zero(t, ts.Restarts)
1040
1041	// Check that the file was written to disk properly
1042	payloadPath := filepath.Join(tr.taskDir.LocalDir, fileName)
1043	data, err := ioutil.ReadFile(payloadPath)
1044	require.NoError(t, err)
1045	require.Equal(t, expected, data)
1046}
1047
1048// TestTaskRunner_SignalFailure asserts that signal errors are properly
1049// propagated from the driver to TaskRunner.
1050func TestTaskRunner_SignalFailure(t *testing.T) {
1051	t.Parallel()
1052
1053	alloc := mock.Alloc()
1054	task := alloc.Job.TaskGroups[0].Tasks[0]
1055	task.Driver = "mock_driver"
1056	errMsg := "test forcing failure"
1057	task.Config = map[string]interface{}{
1058		"run_for":      "10m",
1059		"signal_error": errMsg,
1060	}
1061
1062	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
1063	defer cleanup()
1064
1065	testWaitForTaskToStart(t, tr)
1066
1067	require.EqualError(t, tr.Signal(&structs.TaskEvent{}, "SIGINT"), errMsg)
1068}
1069
1070// TestTaskRunner_RestartTask asserts that restarting a task works and emits a
1071// Restarting event.
1072func TestTaskRunner_RestartTask(t *testing.T) {
1073	t.Parallel()
1074
1075	alloc := mock.Alloc()
1076	task := alloc.Job.TaskGroups[0].Tasks[0]
1077	task.Driver = "mock_driver"
1078	task.Config = map[string]interface{}{
1079		"run_for": "10m",
1080	}
1081
1082	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
1083	defer cleanup()
1084
1085	testWaitForTaskToStart(t, tr)
1086
1087	// Restart task. Send a RestartSignal event like check watcher. Restart
1088	// handler emits the Restarting event.
1089	event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason("test")
1090	const fail = false
1091	tr.Restart(context.Background(), event.Copy(), fail)
1092
1093	// Wait for it to restart and be running again
1094	testutil.WaitForResult(func() (bool, error) {
1095		ts := tr.TaskState()
1096		if ts.Restarts != 1 {
1097			return false, fmt.Errorf("expected 1 restart but found %d\nevents: %s",
1098				ts.Restarts, pretty.Sprint(ts.Events))
1099		}
1100		if ts.State != structs.TaskStateRunning {
1101			return false, fmt.Errorf("expected running but received %s", ts.State)
1102		}
1103		return true, nil
1104	}, func(err error) {
1105		require.NoError(t, err)
1106	})
1107
1108	// Assert the expected Restarting event was emitted
1109	found := false
1110	events := tr.TaskState().Events
1111	for _, e := range events {
1112		if e.Type == structs.TaskRestartSignal {
1113			found = true
1114			require.Equal(t, event.Time, e.Time)
1115			require.Equal(t, event.RestartReason, e.RestartReason)
1116			require.Contains(t, e.DisplayMessage, event.RestartReason)
1117		}
1118	}
1119	require.True(t, found, "restarting task event not found", pretty.Sprint(events))
1120}
1121
1122// TestTaskRunner_CheckWatcher_Restart asserts that when enabled an unhealthy
1123// Consul check will cause a task to restart following restart policy rules.
1124func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
1125	t.Parallel()
1126
1127	alloc := mock.Alloc()
1128
1129	// Make the restart policy fail within this test
1130	tg := alloc.Job.TaskGroups[0]
1131	tg.RestartPolicy.Attempts = 2
1132	tg.RestartPolicy.Interval = 1 * time.Minute
1133	tg.RestartPolicy.Delay = 10 * time.Millisecond
1134	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
1135
1136	task := tg.Tasks[0]
1137	task.Driver = "mock_driver"
1138	task.Config = map[string]interface{}{
1139		"run_for": "10m",
1140	}
1141
1142	// Make the task register a check that fails
1143	task.Services[0].Checks[0] = &structs.ServiceCheck{
1144		Name:     "test-restarts",
1145		Type:     structs.ServiceCheckTCP,
1146		Interval: 50 * time.Millisecond,
1147		CheckRestart: &structs.CheckRestart{
1148			Limit: 2,
1149			Grace: 100 * time.Millisecond,
1150		},
1151	}
1152
1153	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1154	defer cleanup()
1155
1156	// Replace mock Consul ServiceClient, with the real ServiceClient
1157	// backed by a mock consul whose checks are always unhealthy.
1158	consulAgent := agentconsul.NewMockAgent()
1159	consulAgent.SetStatus("critical")
1160	namespacesClient := agentconsul.NewNamespacesClient(agentconsul.NewMockNamespaces(nil))
1161	consulClient := agentconsul.NewServiceClient(consulAgent, namespacesClient, conf.Logger, true)
1162	go consulClient.Run()
1163	defer consulClient.Shutdown()
1164
1165	conf.Consul = consulClient
1166
1167	tr, err := NewTaskRunner(conf)
1168	require.NoError(t, err)
1169
1170	expectedEvents := []string{
1171		"Received",
1172		"Task Setup",
1173		"Started",
1174		"Restart Signaled",
1175		"Terminated",
1176		"Restarting",
1177		"Started",
1178		"Restart Signaled",
1179		"Terminated",
1180		"Restarting",
1181		"Started",
1182		"Restart Signaled",
1183		"Terminated",
1184		"Not Restarting",
1185	}
1186
1187	// Bump maxEvents so task events aren't dropped
1188	tr.maxEvents = 100
1189
1190	go tr.Run()
1191	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1192
1193	// Wait until the task exits. Don't simply wait for it to run as it may
1194	// get restarted and terminated before the test is able to observe it
1195	// running.
1196	select {
1197	case <-tr.WaitCh():
1198	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1199		require.Fail(t, "timeout")
1200	}
1201
1202	state := tr.TaskState()
1203	actualEvents := make([]string, len(state.Events))
1204	for i, e := range state.Events {
1205		actualEvents[i] = string(e.Type)
1206	}
1207	require.Equal(t, actualEvents, expectedEvents)
1208	require.Equal(t, structs.TaskStateDead, state.State)
1209	require.True(t, state.Failed, pretty.Sprint(state))
1210}
1211
// mockEnvoyBootstrapHook is a stateless stand-in for the real envoy
// bootstrap hook; see useMockEnvoyBootstrapHook for how it is installed.
type mockEnvoyBootstrapHook struct {
	// nothing
}
1215
1216func (_ *mockEnvoyBootstrapHook) Name() string {
1217	return "mock_envoy_bootstrap"
1218}
1219
1220func (_ *mockEnvoyBootstrapHook) Prestart(_ context.Context, _ *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
1221	resp.Done = true
1222	return nil
1223}
1224
1225// The envoy bootstrap hook tries to connect to consul and run the envoy
1226// bootstrap command, so turn it off when testing connect jobs that are not
1227// using envoy.
1228func useMockEnvoyBootstrapHook(tr *TaskRunner) {
1229	mock := new(mockEnvoyBootstrapHook)
1230	for i, hook := range tr.runnerHooks {
1231		if _, ok := hook.(*envoyBootstrapHook); ok {
1232			tr.runnerHooks[i] = mock
1233		}
1234	}
1235}
1236
// TestTaskRunner_BlockForSIDSToken asserts tasks do not start until a Consul
// Service Identity token is derived.
func TestTaskRunner_BlockForSIDSToken(t *testing.T) {
	t.Parallel()
	r := require.New(t)

	// Connect-enabled batch job whose task exits immediately once started.
	alloc := mock.BatchConnectAlloc()
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.Config = map[string]interface{}{
		"run_for": "0s",
	}

	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
	defer cleanup()

	// set a consul token on the Nomad client's consul config, because that is
	// what gates the action of requesting SI token(s)
	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()

	// control when we get a Consul SI token: the derive function blocks
	// until the test closes waitCh
	token := uuid.Generate()
	waitCh := make(chan struct{})
	deriveFn := func(*structs.Allocation, []string) (map[string]string, error) {
		<-waitCh
		return map[string]string{task.Name: token}, nil
	}
	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
	siClient.DeriveTokenFn = deriveFn

	// start the task runner
	tr, err := NewTaskRunner(trConfig)
	r.NoError(err)
	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap hook

	go tr.Run()

	// assert task runner blocks on SI token; 100ms is a heuristic window in
	// which a wrongly-started 0s task would normally have exited
	select {
	case <-tr.WaitCh():
		r.Fail("task_runner exited before si unblocked")
	case <-time.After(100 * time.Millisecond):
	}

	// assert task state is still pending
	r.Equal(structs.TaskStatePending, tr.TaskState().State)

	// unblock service identity token
	close(waitCh)

	// task runner should exit now that it has been unblocked and it is a batch
	// job with a zero sleep time
	select {
	case <-tr.WaitCh():
	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
		r.Fail("timed out waiting for batch task to exist")
	}

	// assert task exited successfully
	finalState := tr.TaskState()
	r.Equal(structs.TaskStateDead, finalState.State)
	r.False(finalState.Failed)

	// assert the token is on disk, written to the task's secrets dir
	tokenPath := filepath.Join(trConfig.TaskDir.SecretsDir, sidsTokenFile)
	data, err := ioutil.ReadFile(tokenPath)
	r.NoError(err)
	r.Equal(token, string(data))
}
1306
1307func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) {
1308	t.Parallel()
1309	r := require.New(t)
1310
1311	alloc := mock.BatchConnectAlloc()
1312	task := alloc.Job.TaskGroups[0].Tasks[0]
1313	task.Config = map[string]interface{}{
1314		"run_for": "0s",
1315	}
1316
1317	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1318	defer cleanup()
1319
1320	// set a consul token on the Nomad client's consul config, because that is
1321	// what gates the action of requesting SI token(s)
1322	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()
1323
1324	// control when we get a Consul SI token (recoverable failure on first call)
1325	token := uuid.Generate()
1326	deriveCount := 0
1327	deriveFn := func(*structs.Allocation, []string) (map[string]string, error) {
1328		if deriveCount > 0 {
1329
1330			return map[string]string{task.Name: token}, nil
1331		}
1332		deriveCount++
1333		return nil, structs.NewRecoverableError(errors.New("try again later"), true)
1334	}
1335	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
1336	siClient.DeriveTokenFn = deriveFn
1337
1338	// start the task runner
1339	tr, err := NewTaskRunner(trConfig)
1340	r.NoError(err)
1341	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1342	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap
1343	go tr.Run()
1344
1345	// assert task runner blocks on SI token
1346	select {
1347	case <-tr.WaitCh():
1348	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1349		r.Fail("timed out waiting for task runner")
1350	}
1351
1352	// assert task exited successfully
1353	finalState := tr.TaskState()
1354	r.Equal(structs.TaskStateDead, finalState.State)
1355	r.False(finalState.Failed)
1356
1357	// assert the token is on disk
1358	tokenPath := filepath.Join(trConfig.TaskDir.SecretsDir, sidsTokenFile)
1359	data, err := ioutil.ReadFile(tokenPath)
1360	r.NoError(err)
1361	r.Equal(token, string(data))
1362}
1363
1364// TestTaskRunner_DeriveSIToken_Unrecoverable asserts that an unrecoverable error
1365// from deriving a service identity token will fail a task.
1366func TestTaskRunner_DeriveSIToken_Unrecoverable(t *testing.T) {
1367	t.Parallel()
1368	r := require.New(t)
1369
1370	alloc := mock.BatchConnectAlloc()
1371	tg := alloc.Job.TaskGroups[0]
1372	tg.RestartPolicy.Attempts = 0
1373	tg.RestartPolicy.Interval = 0
1374	tg.RestartPolicy.Delay = 0
1375	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
1376	task := tg.Tasks[0]
1377	task.Config = map[string]interface{}{
1378		"run_for": "0s",
1379	}
1380
1381	trConfig, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1382	defer cleanup()
1383
1384	// set a consul token on the Nomad client's consul config, because that is
1385	// what gates the action of requesting SI token(s)
1386	trConfig.ClientConfig.ConsulConfig.Token = uuid.Generate()
1387
1388	// SI token derivation suffers a non-retryable error
1389	siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient)
1390	siClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, errors.New("non-recoverable"))
1391
1392	tr, err := NewTaskRunner(trConfig)
1393	r.NoError(err)
1394
1395	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1396	useMockEnvoyBootstrapHook(tr) // mock the envoy bootstrap hook
1397	go tr.Run()
1398
1399	// Wait for the task to die
1400	select {
1401	case <-tr.WaitCh():
1402	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1403		require.Fail(t, "timed out waiting for task runner to fail")
1404	}
1405
1406	// assert we have died and failed
1407	finalState := tr.TaskState()
1408	r.Equal(structs.TaskStateDead, finalState.State)
1409	r.True(finalState.Failed)
1410	r.Equal(5, len(finalState.Events))
1411	/*
1412	 + event: Task received by client
1413	 + event: Building Task Directory
1414	 + event: consul: failed to derive SI token: non-recoverable
1415	 + event: consul_sids: context canceled
1416	 + event: Policy allows no restarts
1417	*/
1418	r.Equal("true", finalState.Events[2].Details["fails_task"])
1419}
1420
// TestTaskRunner_BlockForVaultToken asserts tasks do not start until a vault token
// is derived.
func TestTaskRunner_BlockForVaultToken(t *testing.T) {
	t.Parallel()

	// Batch job whose task exits immediately once allowed to start.
	alloc := mock.BatchAlloc()
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.Config = map[string]interface{}{
		"run_for": "0s",
	}
	// Requesting a vault policy makes startup wait on token derivation.
	task.Vault = &structs.Vault{Policies: []string{"default"}}

	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
	defer cleanup()

	// Control when we get a Vault token: the derive function blocks until
	// the test closes waitCh.
	token := "1234"
	waitCh := make(chan struct{})
	handler := func(*structs.Allocation, []string) (map[string]string, error) {
		<-waitCh
		return map[string]string{task.Name: token}, nil
	}
	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
	vaultClient.DeriveTokenFn = handler

	tr, err := NewTaskRunner(conf)
	require.NoError(t, err)
	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
	go tr.Run()

	// Assert TR blocks on vault token (does *not* exit)
	select {
	case <-tr.WaitCh():
		require.Fail(t, "tr exited before vault unblocked")
	case <-time.After(1 * time.Second):
	}

	// Assert task state is still Pending
	require.Equal(t, structs.TaskStatePending, tr.TaskState().State)

	// Unblock vault token
	close(waitCh)

	// TR should exit now that it's unblocked by vault as its a batch job
	// with 0 sleeping.
	select {
	case <-tr.WaitCh():
	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
		require.Fail(t, "timed out waiting for batch task to exit")
	}

	// Assert task exited successfully
	finalState := tr.TaskState()
	require.Equal(t, structs.TaskStateDead, finalState.State)
	require.False(t, finalState.Failed)

	// Check that the token is on disk, written to the task's secrets dir
	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
	data, err := ioutil.ReadFile(tokenPath)
	require.NoError(t, err)
	require.Equal(t, token, string(data))

	// Check the token was revoked; revocation happens asynchronously after
	// the task dies, so poll for it.
	testutil.WaitForResult(func() (bool, error) {
		if len(vaultClient.StoppedTokens()) != 1 {
			return false, fmt.Errorf("Expected a stopped token %q but found: %v", token, vaultClient.StoppedTokens())
		}

		if a := vaultClient.StoppedTokens()[0]; a != token {
			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
		}
		return true, nil
	}, func(err error) {
		require.Fail(t, err.Error())
	})
}
1497
1498// TestTaskRunner_DeriveToken_Retry asserts that if a recoverable error is
1499// returned when deriving a vault token a task will continue to block while
1500// it's retried.
1501func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
1502	t.Parallel()
1503	alloc := mock.BatchAlloc()
1504	task := alloc.Job.TaskGroups[0].Tasks[0]
1505	task.Vault = &structs.Vault{Policies: []string{"default"}}
1506
1507	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1508	defer cleanup()
1509
1510	// Fail on the first attempt to derive a vault token
1511	token := "1234"
1512	count := 0
1513	handler := func(*structs.Allocation, []string) (map[string]string, error) {
1514		if count > 0 {
1515			return map[string]string{task.Name: token}, nil
1516		}
1517
1518		count++
1519		return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
1520	}
1521	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
1522	vaultClient.DeriveTokenFn = handler
1523
1524	tr, err := NewTaskRunner(conf)
1525	require.NoError(t, err)
1526	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1527	go tr.Run()
1528
1529	// Wait for TR to exit and check its state
1530	select {
1531	case <-tr.WaitCh():
1532	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1533		require.Fail(t, "timed out waiting for task runner to exit")
1534	}
1535
1536	state := tr.TaskState()
1537	require.Equal(t, structs.TaskStateDead, state.State)
1538	require.False(t, state.Failed)
1539
1540	require.Equal(t, 1, count)
1541
1542	// Check that the token is on disk
1543	tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
1544	data, err := ioutil.ReadFile(tokenPath)
1545	require.NoError(t, err)
1546	require.Equal(t, token, string(data))
1547
1548	// Check the token was revoked
1549	testutil.WaitForResult(func() (bool, error) {
1550		if len(vaultClient.StoppedTokens()) != 1 {
1551			return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
1552		}
1553
1554		if a := vaultClient.StoppedTokens()[0]; a != token {
1555			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
1556		}
1557		return true, nil
1558	}, func(err error) {
1559		require.Fail(t, err.Error())
1560	})
1561}
1562
1563// TestTaskRunner_DeriveToken_Unrecoverable asserts that an unrecoverable error
1564// from deriving a vault token will fail a task.
1565func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
1566	t.Parallel()
1567
1568	// Use a batch job with no restarts
1569	alloc := mock.BatchAlloc()
1570	tg := alloc.Job.TaskGroups[0]
1571	tg.RestartPolicy.Attempts = 0
1572	tg.RestartPolicy.Interval = 0
1573	tg.RestartPolicy.Delay = 0
1574	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
1575	task := tg.Tasks[0]
1576	task.Config = map[string]interface{}{
1577		"run_for": "0s",
1578	}
1579	task.Vault = &structs.Vault{Policies: []string{"default"}}
1580
1581	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1582	defer cleanup()
1583
1584	// Error the token derivation
1585	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
1586	vaultClient.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
1587
1588	tr, err := NewTaskRunner(conf)
1589	require.NoError(t, err)
1590	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1591	go tr.Run()
1592
1593	// Wait for the task to die
1594	select {
1595	case <-tr.WaitCh():
1596	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1597		require.Fail(t, "timed out waiting for task runner to fail")
1598	}
1599
1600	// Task should be dead and last event should have failed task
1601	state := tr.TaskState()
1602	require.Equal(t, structs.TaskStateDead, state.State)
1603	require.True(t, state.Failed)
1604	require.Len(t, state.Events, 3)
1605	require.True(t, state.Events[2].FailsTask)
1606}
1607
1608// TestTaskRunner_Download_ChrootExec asserts that downloaded artifacts may be
1609// executed in a chroot.
1610func TestTaskRunner_Download_ChrootExec(t *testing.T) {
1611	t.Parallel()
1612	ctestutil.ExecCompatible(t)
1613
1614	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
1615	defer ts.Close()
1616
1617	// Create a task that downloads a script and executes it.
1618	alloc := mock.BatchAlloc()
1619	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
1620	task := alloc.Job.TaskGroups[0].Tasks[0]
1621	task.RestartPolicy = &structs.RestartPolicy{}
1622	task.Driver = "exec"
1623	task.Config = map[string]interface{}{
1624		"command": "noop.sh",
1625	}
1626	task.Artifacts = []*structs.TaskArtifact{
1627		{
1628			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
1629			GetterMode:   "file",
1630			RelativeDest: "noop.sh",
1631		},
1632	}
1633
1634	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
1635	defer cleanup()
1636
1637	// Wait for task to run and exit
1638	select {
1639	case <-tr.WaitCh():
1640	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1641		require.Fail(t, "timed out waiting for task runner to exit")
1642	}
1643
1644	state := tr.TaskState()
1645	require.Equal(t, structs.TaskStateDead, state.State)
1646	require.False(t, state.Failed)
1647}
1648
// TestTaskRunner_Download_RawExec asserts that downloaded artifacts may be
// executed in a driver without filesystem isolation.
func TestTaskRunner_Download_RawExec(t *testing.T) {
	t.Parallel()

	// Serve this package's directory so the task can fetch testdata/noop.sh.
	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
	defer ts.Close()

	// Create a task that downloads a script and executes it.
	alloc := mock.BatchAlloc()
	alloc.Job.TaskGroups[0].RestartPolicy = &structs.RestartPolicy{}
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.RestartPolicy = &structs.RestartPolicy{}
	task.Driver = "raw_exec"
	task.Config = map[string]interface{}{
		"command": "noop.sh",
	}
	task.Artifacts = []*structs.TaskArtifact{
		{
			GetterSource: fmt.Sprintf("%s/testdata/noop.sh", ts.URL),
			GetterMode:   "file",
			RelativeDest: "noop.sh",
		},
	}

	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
	defer cleanup()

	// Wait for task to run and exit
	select {
	case <-tr.WaitCh():
	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
		require.Fail(t, "timed out waiting for task runner to exit")
	}

	// The downloaded script must have executed successfully.
	state := tr.TaskState()
	require.Equal(t, structs.TaskStateDead, state.State)
	require.False(t, state.Failed)
}
1688
1689// TestTaskRunner_Download_List asserts that multiple artificats are downloaded
1690// before a task is run.
1691func TestTaskRunner_Download_List(t *testing.T) {
1692	t.Parallel()
1693	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("."))))
1694	defer ts.Close()
1695
1696	// Create an allocation that has a task with a list of artifacts.
1697	alloc := mock.BatchAlloc()
1698	task := alloc.Job.TaskGroups[0].Tasks[0]
1699	f1 := "task_runner_test.go"
1700	f2 := "task_runner.go"
1701	artifact1 := structs.TaskArtifact{
1702		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1),
1703	}
1704	artifact2 := structs.TaskArtifact{
1705		GetterSource: fmt.Sprintf("%s/%s", ts.URL, f2),
1706	}
1707	task.Artifacts = []*structs.TaskArtifact{&artifact1, &artifact2}
1708
1709	tr, conf, cleanup := runTestTaskRunner(t, alloc, task.Name)
1710	defer cleanup()
1711
1712	// Wait for task to run and exit
1713	select {
1714	case <-tr.WaitCh():
1715	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1716		require.Fail(t, "timed out waiting for task runner to exit")
1717	}
1718
1719	state := tr.TaskState()
1720	require.Equal(t, structs.TaskStateDead, state.State)
1721	require.False(t, state.Failed)
1722
1723	require.Len(t, state.Events, 5)
1724	assert.Equal(t, structs.TaskReceived, state.Events[0].Type)
1725	assert.Equal(t, structs.TaskSetup, state.Events[1].Type)
1726	assert.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
1727	assert.Equal(t, structs.TaskStarted, state.Events[3].Type)
1728	assert.Equal(t, structs.TaskTerminated, state.Events[4].Type)
1729
1730	// Check that both files exist.
1731	_, err := os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
1732	require.NoErrorf(t, err, "%v not downloaded", f1)
1733
1734	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f2))
1735	require.NoErrorf(t, err, "%v not downloaded", f2)
1736}
1737
1738// TestTaskRunner_Download_Retries asserts that failed artifact downloads are
1739// retried according to the task's restart policy.
1740func TestTaskRunner_Download_Retries(t *testing.T) {
1741	t.Parallel()
1742
1743	// Create an allocation that has a task with bad artifacts.
1744	alloc := mock.BatchAlloc()
1745	task := alloc.Job.TaskGroups[0].Tasks[0]
1746	artifact := structs.TaskArtifact{
1747		GetterSource: "http://127.0.0.1:0/foo/bar/baz",
1748	}
1749	task.Artifacts = []*structs.TaskArtifact{&artifact}
1750
1751	// Make the restart policy retry once
1752	rp := &structs.RestartPolicy{
1753		Attempts: 1,
1754		Interval: 10 * time.Minute,
1755		Delay:    1 * time.Second,
1756		Mode:     structs.RestartPolicyModeFail,
1757	}
1758	alloc.Job.TaskGroups[0].RestartPolicy = rp
1759	alloc.Job.TaskGroups[0].Tasks[0].RestartPolicy = rp
1760
1761	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
1762	defer cleanup()
1763
1764	select {
1765	case <-tr.WaitCh():
1766	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
1767		require.Fail(t, "timed out waiting for task to exit")
1768	}
1769
1770	state := tr.TaskState()
1771	require.Equal(t, structs.TaskStateDead, state.State)
1772	require.True(t, state.Failed)
1773	require.Len(t, state.Events, 8, pretty.Sprint(state.Events))
1774	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
1775	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
1776	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[2].Type)
1777	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[3].Type)
1778	require.Equal(t, structs.TaskRestarting, state.Events[4].Type)
1779	require.Equal(t, structs.TaskDownloadingArtifacts, state.Events[5].Type)
1780	require.Equal(t, structs.TaskArtifactDownloadFailed, state.Events[6].Type)
1781	require.Equal(t, structs.TaskNotRestarting, state.Events[7].Type)
1782}
1783
1784// TestTaskRunner_DriverNetwork asserts that a driver's network is properly
1785// used in services and checks.
1786func TestTaskRunner_DriverNetwork(t *testing.T) {
1787	t.Parallel()
1788
1789	alloc := mock.Alloc()
1790	task := alloc.Job.TaskGroups[0].Tasks[0]
1791	task.Driver = "mock_driver"
1792	task.Config = map[string]interface{}{
1793		"run_for":         "100s",
1794		"driver_ip":       "10.1.2.3",
1795		"driver_port_map": "http:80",
1796	}
1797
1798	// Create services and checks with custom address modes to exercise
1799	// address detection logic
1800	task.Services = []*structs.Service{
1801		{
1802			Name:        "host-service",
1803			PortLabel:   "http",
1804			AddressMode: "host",
1805			Checks: []*structs.ServiceCheck{
1806				{
1807					Name:        "driver-check",
1808					Type:        "tcp",
1809					PortLabel:   "1234",
1810					AddressMode: "driver",
1811				},
1812			},
1813		},
1814		{
1815			Name:        "driver-service",
1816			PortLabel:   "5678",
1817			AddressMode: "driver",
1818			Checks: []*structs.ServiceCheck{
1819				{
1820					Name:      "host-check",
1821					Type:      "tcp",
1822					PortLabel: "http",
1823				},
1824				{
1825					Name:        "driver-label-check",
1826					Type:        "tcp",
1827					PortLabel:   "http",
1828					AddressMode: "driver",
1829				},
1830			},
1831		},
1832	}
1833
1834	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1835	defer cleanup()
1836
1837	// Use a mock agent to test for services
1838	consulAgent := agentconsul.NewMockAgent()
1839	namespacesClient := agentconsul.NewNamespacesClient(agentconsul.NewMockNamespaces(nil))
1840	consulClient := agentconsul.NewServiceClient(consulAgent, namespacesClient, conf.Logger, true)
1841	defer consulClient.Shutdown()
1842	go consulClient.Run()
1843
1844	conf.Consul = consulClient
1845
1846	tr, err := NewTaskRunner(conf)
1847	require.NoError(t, err)
1848	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1849	go tr.Run()
1850
1851	// Wait for the task to start
1852	testWaitForTaskToStart(t, tr)
1853
1854	testutil.WaitForResult(func() (bool, error) {
1855		services, _ := consulAgent.ServicesWithFilterOpts("", nil)
1856		if n := len(services); n != 2 {
1857			return false, fmt.Errorf("expected 2 services, but found %d", n)
1858		}
1859		for _, s := range services {
1860			switch s.Service {
1861			case "host-service":
1862				if expected := "192.168.0.100"; s.Address != expected {
1863					return false, fmt.Errorf("expected host-service to have IP=%s but found %s",
1864						expected, s.Address)
1865				}
1866			case "driver-service":
1867				if expected := "10.1.2.3"; s.Address != expected {
1868					return false, fmt.Errorf("expected driver-service to have IP=%s but found %s",
1869						expected, s.Address)
1870				}
1871				if expected := 5678; s.Port != expected {
1872					return false, fmt.Errorf("expected driver-service to have port=%d but found %d",
1873						expected, s.Port)
1874				}
1875			default:
1876				return false, fmt.Errorf("unexpected service: %q", s.Service)
1877			}
1878
1879		}
1880
1881		checks := consulAgent.CheckRegs()
1882		if n := len(checks); n != 3 {
1883			return false, fmt.Errorf("expected 3 checks, but found %d", n)
1884		}
1885		for _, check := range checks {
1886			switch check.Name {
1887			case "driver-check":
1888				if expected := "10.1.2.3:1234"; check.TCP != expected {
1889					return false, fmt.Errorf("expected driver-check to have address %q but found %q", expected, check.TCP)
1890				}
1891			case "driver-label-check":
1892				if expected := "10.1.2.3:80"; check.TCP != expected {
1893					return false, fmt.Errorf("expected driver-label-check to have address %q but found %q", expected, check.TCP)
1894				}
1895			case "host-check":
1896				if expected := "192.168.0.100:"; !strings.HasPrefix(check.TCP, expected) {
1897					return false, fmt.Errorf("expected host-check to have address start with %q but found %q", expected, check.TCP)
1898				}
1899			default:
1900				return false, fmt.Errorf("unexpected check: %q", check.Name)
1901			}
1902		}
1903
1904		return true, nil
1905	}, func(err error) {
1906		services, _ := consulAgent.ServicesWithFilterOpts("", nil)
1907		for _, s := range services {
1908			t.Logf(pretty.Sprint("Service: ", s))
1909		}
1910		for _, c := range consulAgent.CheckRegs() {
1911			t.Logf(pretty.Sprint("Check:   ", c))
1912		}
1913		require.NoError(t, err)
1914	})
1915}
1916
1917// TestTaskRunner_RestartSignalTask_NotRunning asserts resilience to failures
1918// when a restart or signal is triggered and the task is not running.
1919func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
1920	t.Parallel()
1921
1922	alloc := mock.BatchAlloc()
1923	task := alloc.Job.TaskGroups[0].Tasks[0]
1924	task.Driver = "mock_driver"
1925	task.Config = map[string]interface{}{
1926		"run_for": "0s",
1927	}
1928
1929	// Use vault to block the start
1930	task.Vault = &structs.Vault{Policies: []string{"default"}}
1931
1932	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
1933	defer cleanup()
1934
1935	// Control when we get a Vault token
1936	waitCh := make(chan struct{}, 1)
1937	defer close(waitCh)
1938	handler := func(*structs.Allocation, []string) (map[string]string, error) {
1939		<-waitCh
1940		return map[string]string{task.Name: "1234"}, nil
1941	}
1942	vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
1943	vaultClient.DeriveTokenFn = handler
1944
1945	tr, err := NewTaskRunner(conf)
1946	require.NoError(t, err)
1947	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
1948	go tr.Run()
1949
1950	select {
1951	case <-tr.WaitCh():
1952		require.Fail(t, "unexpected exit")
1953	case <-time.After(1 * time.Second):
1954	}
1955
1956	// Send a signal and restart
1957	err = tr.Signal(structs.NewTaskEvent("don't panic"), "QUIT")
1958	require.EqualError(t, err, ErrTaskNotRunning.Error())
1959
1960	// Send a restart
1961	err = tr.Restart(context.Background(), structs.NewTaskEvent("don't panic"), false)
1962	require.EqualError(t, err, ErrTaskNotRunning.Error())
1963
1964	// Unblock and let it finish
1965	waitCh <- struct{}{}
1966
1967	select {
1968	case <-tr.WaitCh():
1969	case <-time.After(10 * time.Second):
1970		require.Fail(t, "timed out waiting for task to complete")
1971	}
1972
1973	// Assert the task ran and never restarted
1974	state := tr.TaskState()
1975	require.Equal(t, structs.TaskStateDead, state.State)
1976	require.False(t, state.Failed)
1977	require.Len(t, state.Events, 4, pretty.Sprint(state.Events))
1978	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
1979	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
1980	require.Equal(t, structs.TaskStarted, state.Events[2].Type)
1981	require.Equal(t, structs.TaskTerminated, state.Events[3].Type)
1982}
1983
1984// TestTaskRunner_Run_RecoverableStartError asserts tasks are restarted if they
1985// return a recoverable error from StartTask.
1986func TestTaskRunner_Run_RecoverableStartError(t *testing.T) {
1987	t.Parallel()
1988
1989	alloc := mock.BatchAlloc()
1990	task := alloc.Job.TaskGroups[0].Tasks[0]
1991	task.Config = map[string]interface{}{
1992		"start_error":             "driver failure",
1993		"start_error_recoverable": true,
1994	}
1995
1996	// Make the restart policy retry once
1997	rp := &structs.RestartPolicy{
1998		Attempts: 1,
1999		Interval: 10 * time.Minute,
2000		Delay:    0,
2001		Mode:     structs.RestartPolicyModeFail,
2002	}
2003	alloc.Job.TaskGroups[0].RestartPolicy = rp
2004	alloc.Job.TaskGroups[0].Tasks[0].RestartPolicy = rp
2005
2006	tr, _, cleanup := runTestTaskRunner(t, alloc, task.Name)
2007	defer cleanup()
2008
2009	select {
2010	case <-tr.WaitCh():
2011	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
2012		require.Fail(t, "timed out waiting for task to exit")
2013	}
2014
2015	state := tr.TaskState()
2016	require.Equal(t, structs.TaskStateDead, state.State)
2017	require.True(t, state.Failed)
2018	require.Len(t, state.Events, 6, pretty.Sprint(state.Events))
2019	require.Equal(t, structs.TaskReceived, state.Events[0].Type)
2020	require.Equal(t, structs.TaskSetup, state.Events[1].Type)
2021	require.Equal(t, structs.TaskDriverFailure, state.Events[2].Type)
2022	require.Equal(t, structs.TaskRestarting, state.Events[3].Type)
2023	require.Equal(t, structs.TaskDriverFailure, state.Events[4].Type)
2024	require.Equal(t, structs.TaskNotRestarting, state.Events[5].Type)
2025}
2026
2027// TestTaskRunner_Template_Artifact asserts that tasks can use artifacts as templates.
2028func TestTaskRunner_Template_Artifact(t *testing.T) {
2029	t.Parallel()
2030
2031	ts := httptest.NewServer(http.FileServer(http.Dir(".")))
2032	defer ts.Close()
2033
2034	alloc := mock.BatchAlloc()
2035	task := alloc.Job.TaskGroups[0].Tasks[0]
2036	f1 := "task_runner.go"
2037	f2 := "test"
2038	task.Artifacts = []*structs.TaskArtifact{
2039		{GetterSource: fmt.Sprintf("%s/%s", ts.URL, f1)},
2040	}
2041	task.Templates = []*structs.Template{
2042		{
2043			SourcePath: f1,
2044			DestPath:   "local/test",
2045			ChangeMode: structs.TemplateChangeModeNoop,
2046		},
2047	}
2048
2049	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
2050	defer cleanup()
2051
2052	tr, err := NewTaskRunner(conf)
2053	require.NoError(t, err)
2054	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
2055	go tr.Run()
2056
2057	// Wait for task to run and exit
2058	select {
2059	case <-tr.WaitCh():
2060	case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())):
2061		require.Fail(t, "timed out waiting for task runner to exit")
2062	}
2063
2064	state := tr.TaskState()
2065	require.Equal(t, structs.TaskStateDead, state.State)
2066	require.True(t, state.Successful())
2067	require.False(t, state.Failed)
2068
2069	artifactsDownloaded := false
2070	for _, e := range state.Events {
2071		if e.Type == structs.TaskDownloadingArtifacts {
2072			artifactsDownloaded = true
2073		}
2074	}
2075	assert.True(t, artifactsDownloaded, "expected artifacts downloaded events")
2076
2077	// Check that both files exist.
2078	_, err = os.Stat(filepath.Join(conf.TaskDir.Dir, f1))
2079	require.NoErrorf(t, err, "%v not downloaded", f1)
2080
2081	_, err = os.Stat(filepath.Join(conf.TaskDir.LocalDir, f2))
2082	require.NoErrorf(t, err, "%v not rendered", f2)
2083}
2084
2085// TestTaskRunner_Template_BlockingPreStart asserts that a template
2086// that fails to render in PreStart can gracefully be shutdown by
2087// either killCtx or shutdownCtx
2088func TestTaskRunner_Template_BlockingPreStart(t *testing.T) {
2089	t.Parallel()
2090
2091	alloc := mock.BatchAlloc()
2092	task := alloc.Job.TaskGroups[0].Tasks[0]
2093	task.Templates = []*structs.Template{
2094		{
2095			EmbeddedTmpl: `{{ with secret "foo/secret" }}{{ .Data.certificate }}{{ end }}`,
2096			DestPath:     "local/test",
2097			ChangeMode:   structs.TemplateChangeModeNoop,
2098		},
2099	}
2100
2101	task.Vault = &structs.Vault{Policies: []string{"default"}}
2102
2103	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
2104	defer cleanup()
2105
2106	tr, err := NewTaskRunner(conf)
2107	require.NoError(t, err)
2108	go tr.Run()
2109	defer tr.Shutdown()
2110
2111	testutil.WaitForResult(func() (bool, error) {
2112		ts := tr.TaskState()
2113
2114		if len(ts.Events) == 0 {
2115			return false, fmt.Errorf("no events yet")
2116		}
2117
2118		for _, e := range ts.Events {
2119			if e.Type == "Template" && strings.Contains(e.DisplayMessage, "vault.read(foo/secret)") {
2120				return true, nil
2121			}
2122		}
2123
2124		return false, fmt.Errorf("no missing vault secret template event yet: %#v", ts.Events)
2125
2126	}, func(err error) {
2127		require.NoError(t, err)
2128	})
2129
2130	shutdown := func() <-chan bool {
2131		finished := make(chan bool)
2132		go func() {
2133			tr.Shutdown()
2134			finished <- true
2135		}()
2136
2137		return finished
2138	}
2139
2140	select {
2141	case <-shutdown():
2142		// it shut down like it should have
2143	case <-time.After(10 * time.Second):
2144		require.Fail(t, "timeout shutting down task")
2145	}
2146}
2147
2148// TestTaskRunner_Template_NewVaultToken asserts that a new vault token is
2149// created when rendering template and that it is revoked on alloc completion
2150func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
2151	t.Parallel()
2152
2153	alloc := mock.BatchAlloc()
2154	task := alloc.Job.TaskGroups[0].Tasks[0]
2155	task.Templates = []*structs.Template{
2156		{
2157			EmbeddedTmpl: `{{key "foo"}}`,
2158			DestPath:     "local/test",
2159			ChangeMode:   structs.TemplateChangeModeNoop,
2160		},
2161	}
2162	task.Vault = &structs.Vault{Policies: []string{"default"}}
2163
2164	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
2165	defer cleanup()
2166
2167	tr, err := NewTaskRunner(conf)
2168	require.NoError(t, err)
2169	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
2170	go tr.Run()
2171
2172	// Wait for a Vault token
2173	var token string
2174	testutil.WaitForResult(func() (bool, error) {
2175		token = tr.getVaultToken()
2176
2177		if token == "" {
2178			return false, fmt.Errorf("No Vault token")
2179		}
2180
2181		return true, nil
2182	}, func(err error) {
2183		require.NoError(t, err)
2184	})
2185
2186	vault := conf.Vault.(*vaultclient.MockVaultClient)
2187	renewalCh, ok := vault.RenewTokens()[token]
2188	require.True(t, ok, "no renewal channel for token")
2189
2190	renewalCh <- fmt.Errorf("Test killing")
2191	close(renewalCh)
2192
2193	var token2 string
2194	testutil.WaitForResult(func() (bool, error) {
2195		token2 = tr.getVaultToken()
2196
2197		if token2 == "" {
2198			return false, fmt.Errorf("No Vault token")
2199		}
2200
2201		if token2 == token {
2202			return false, fmt.Errorf("token wasn't recreated")
2203		}
2204
2205		return true, nil
2206	}, func(err error) {
2207		require.NoError(t, err)
2208	})
2209
2210	// Check the token was revoked
2211	testutil.WaitForResult(func() (bool, error) {
2212		if len(vault.StoppedTokens()) != 1 {
2213			return false, fmt.Errorf("Expected a stopped token: %v", vault.StoppedTokens())
2214		}
2215
2216		if a := vault.StoppedTokens()[0]; a != token {
2217			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
2218		}
2219
2220		return true, nil
2221	}, func(err error) {
2222		require.NoError(t, err)
2223	})
2224
2225}
2226
2227// TestTaskRunner_VaultManager_Restart asserts that the alloc is restarted when the alloc
2228// derived vault token expires, when task is configured with Restart change mode
// TestTaskRunner_VaultManager_Restart asserts that the alloc is restarted when the alloc
// derived vault token expires, when task is configured with Restart change mode
func TestTaskRunner_VaultManager_Restart(t *testing.T) {
	t.Parallel()

	alloc := mock.BatchAlloc()
	task := alloc.Job.TaskGroups[0].Tasks[0]
	task.Config = map[string]interface{}{
		"run_for": "10s",
	}
	// Restart change mode: a failed token renewal should restart the task.
	task.Vault = &structs.Vault{
		Policies:   []string{"default"},
		ChangeMode: structs.VaultChangeModeRestart,
	}

	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
	defer cleanup()

	tr, err := NewTaskRunner(conf)
	require.NoError(t, err)
	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
	go tr.Run()

	testWaitForTaskToStart(t, tr)

	// Snapshot the derived Vault token under its lock.
	tr.vaultTokenLock.Lock()
	token := tr.vaultToken
	tr.vaultTokenLock.Unlock()

	require.NotEmpty(t, token)

	vault := conf.Vault.(*vaultclient.MockVaultClient)
	renewalCh, ok := vault.RenewTokens()[token]
	require.True(t, ok, "no renewal channel for token")

	// Simulate token expiration by erroring and closing the renewal channel.
	renewalCh <- fmt.Errorf("Test killing")
	close(renewalCh)

	// Wait for both a restart-signal and a restarting event, and for the
	// task to have come back up (last event is "started").
	testutil.WaitForResult(func() (bool, error) {
		state := tr.TaskState()

		if len(state.Events) == 0 {
			return false, fmt.Errorf("no events yet")
		}

		foundRestartSignal, foundRestarting := false, false
		for _, e := range state.Events {
			switch e.Type {
			case structs.TaskRestartSignal:
				foundRestartSignal = true
			case structs.TaskRestarting:
				foundRestarting = true
			}
		}

		if !foundRestartSignal {
			return false, fmt.Errorf("no restart signal event yet: %#v", state.Events)
		}

		if !foundRestarting {
			return false, fmt.Errorf("no restarting event yet: %#v", state.Events)
		}

		lastEvent := state.Events[len(state.Events)-1]
		if lastEvent.Type != structs.TaskStarted {
			return false, fmt.Errorf("expected last event to be task starting but was %#v", lastEvent)
		}
		return true, nil
	}, func(err error) {
		require.NoError(t, err)
	})
}
2299
2300// TestTaskRunner_VaultManager_Signal asserts that the alloc is signalled when the alloc
2301// derived vault token expires, when task is configured with signal change mode
2302func TestTaskRunner_VaultManager_Signal(t *testing.T) {
2303	t.Parallel()
2304
2305	alloc := mock.BatchAlloc()
2306	task := alloc.Job.TaskGroups[0].Tasks[0]
2307	task.Config = map[string]interface{}{
2308		"run_for": "10s",
2309	}
2310	task.Vault = &structs.Vault{
2311		Policies:     []string{"default"},
2312		ChangeMode:   structs.VaultChangeModeSignal,
2313		ChangeSignal: "SIGUSR1",
2314	}
2315
2316	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
2317	defer cleanup()
2318
2319	tr, err := NewTaskRunner(conf)
2320	require.NoError(t, err)
2321	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
2322	go tr.Run()
2323
2324	testWaitForTaskToStart(t, tr)
2325
2326	tr.vaultTokenLock.Lock()
2327	token := tr.vaultToken
2328	tr.vaultTokenLock.Unlock()
2329
2330	require.NotEmpty(t, token)
2331
2332	vault := conf.Vault.(*vaultclient.MockVaultClient)
2333	renewalCh, ok := vault.RenewTokens()[token]
2334	require.True(t, ok, "no renewal channel for token")
2335
2336	renewalCh <- fmt.Errorf("Test killing")
2337	close(renewalCh)
2338
2339	testutil.WaitForResult(func() (bool, error) {
2340		state := tr.TaskState()
2341
2342		if len(state.Events) == 0 {
2343			return false, fmt.Errorf("no events yet")
2344		}
2345
2346		foundSignaling := false
2347		for _, e := range state.Events {
2348			if e.Type == structs.TaskSignaling {
2349				foundSignaling = true
2350			}
2351		}
2352
2353		if !foundSignaling {
2354			return false, fmt.Errorf("no signaling event yet: %#v", state.Events)
2355		}
2356
2357		return true, nil
2358	}, func(err error) {
2359		require.NoError(t, err)
2360	})
2361
2362}
2363
2364// TestTaskRunner_UnregisterConsul_Retries asserts a task is unregistered from
2365// Consul when waiting to be retried.
2366func TestTaskRunner_UnregisterConsul_Retries(t *testing.T) {
2367	t.Parallel()
2368
2369	alloc := mock.Alloc()
2370	// Make the restart policy try one ctx.update
2371	rp := &structs.RestartPolicy{
2372		Attempts: 1,
2373		Interval: 10 * time.Minute,
2374		Delay:    time.Nanosecond,
2375		Mode:     structs.RestartPolicyModeFail,
2376	}
2377	alloc.Job.TaskGroups[0].RestartPolicy = rp
2378	task := alloc.Job.TaskGroups[0].Tasks[0]
2379	task.RestartPolicy = rp
2380	task.Driver = "mock_driver"
2381	task.Config = map[string]interface{}{
2382		"exit_code": "1",
2383		"run_for":   "1ns",
2384	}
2385
2386	conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
2387	defer cleanup()
2388
2389	tr, err := NewTaskRunner(conf)
2390	require.NoError(t, err)
2391	defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
2392	tr.Run()
2393
2394	state := tr.TaskState()
2395	require.Equal(t, structs.TaskStateDead, state.State)
2396
2397	consul := conf.Consul.(*consulapi.MockConsulServiceClient)
2398	consulOps := consul.GetOps()
2399	require.Len(t, consulOps, 8)
2400
2401	// Initial add
2402	require.Equal(t, "add", consulOps[0].Op)
2403
2404	// Removing canary and non-canary entries on first exit
2405	require.Equal(t, "remove", consulOps[1].Op)
2406	require.Equal(t, "remove", consulOps[2].Op)
2407
2408	// Second add on retry
2409	require.Equal(t, "add", consulOps[3].Op)
2410
2411	// Removing canary and non-canary entries on retry
2412	require.Equal(t, "remove", consulOps[4].Op)
2413	require.Equal(t, "remove", consulOps[5].Op)
2414
2415	// Removing canary and non-canary entries on stop
2416	require.Equal(t, "remove", consulOps[6].Op)
2417	require.Equal(t, "remove", consulOps[7].Op)
2418}
2419
2420// testWaitForTaskToStart waits for the task to be running or fails the test
2421func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) {
2422	testutil.WaitForResult(func() (bool, error) {
2423		ts := tr.TaskState()
2424		return ts.State == structs.TaskStateRunning, fmt.Errorf("%v", ts.State)
2425	}, func(err error) {
2426		require.NoError(t, err)
2427	})
2428}
2429
2430// TestTaskRunner_BaseLabels tests that the base labels for the task metrics
2431// are set appropriately.
2432func TestTaskRunner_BaseLabels(t *testing.T) {
2433	t.Parallel()
2434	require := require.New(t)
2435
2436	alloc := mock.BatchAlloc()
2437	alloc.Namespace = "not-default"
2438	task := alloc.Job.TaskGroups[0].Tasks[0]
2439	task.Driver = "raw_exec"
2440	task.Config = map[string]interface{}{
2441		"command": "whoami",
2442	}
2443
2444	config, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
2445	defer cleanup()
2446
2447	tr, err := NewTaskRunner(config)
2448	require.NoError(err)
2449
2450	labels := map[string]string{}
2451	for _, e := range tr.baseLabels {
2452		labels[e.Name] = e.Value
2453	}
2454	require.Equal(alloc.Job.Name, labels["job"])
2455	require.Equal(alloc.TaskGroup, labels["task_group"])
2456	require.Equal(task.Name, labels["task"])
2457	require.Equal(alloc.ID, labels["alloc_id"])
2458	require.Equal(alloc.Namespace, labels["namespace"])
2459}
2460