1package command
2
3import (
4	"archive/tar"
5	"encoding/json"
6	"fmt"
7	"io/ioutil"
8	"os"
9	"path/filepath"
10	"strings"
11	"testing"
12	"time"
13
14	"github.com/hashicorp/vault/api"
15	"github.com/mholt/archiver"
16	"github.com/mitchellh/cli"
17)
18
19func testDebugCommand(tb testing.TB) (*cli.MockUi, *DebugCommand) {
20	tb.Helper()
21
22	ui := cli.NewMockUi()
23	return ui, &DebugCommand{
24		BaseCommand: &BaseCommand{
25			UI: ui,
26		},
27	}
28}
29
30func TestDebugCommand_Run(t *testing.T) {
31	t.Parallel()
32
33	testDir, err := ioutil.TempDir("", "vault-debug")
34	if err != nil {
35		t.Fatal(err)
36	}
37	defer os.RemoveAll(testDir)
38
39	cases := []struct {
40		name string
41		args []string
42		out  string
43		code int
44	}{
45		{
46			"valid",
47			[]string{
48				"-duration=1s",
49				fmt.Sprintf("-output=%s/valid", testDir),
50			},
51			"",
52			0,
53		},
54		{
55			"too_many_args",
56			[]string{
57				"-duration=1s",
58				fmt.Sprintf("-output=%s/too_many_args", testDir),
59				"foo",
60			},
61			"Too many arguments",
62			1,
63		},
64		{
65			"invalid_target",
66			[]string{
67				"-duration=1s",
68				fmt.Sprintf("-output=%s/invalid_target", testDir),
69				"-target=foo",
70			},
71			"Ignoring invalid targets: foo",
72			0,
73		},
74	}
75
76	for _, tc := range cases {
77		tc := tc
78
79		t.Run(tc.name, func(t *testing.T) {
80			t.Parallel()
81
82			client, closer := testVaultServer(t)
83			defer closer()
84
85			ui, cmd := testDebugCommand(t)
86			cmd.client = client
87			cmd.skipTimingChecks = true
88
89			code := cmd.Run(tc.args)
90			if code != tc.code {
91				t.Errorf("expected %d to be %d", code, tc.code)
92			}
93
94			combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
95			if !strings.Contains(combined, tc.out) {
96				t.Fatalf("expected %q to contain %q", combined, tc.out)
97			}
98		})
99	}
100}
101
102func TestDebugCommand_Archive(t *testing.T) {
103	t.Parallel()
104
105	cases := []struct {
106		name        string
107		ext         string
108		expectError bool
109	}{
110		{
111			"no-ext",
112			"",
113			false,
114		},
115		{
116			"with-ext-tar-gz",
117			".tar.gz",
118			false,
119		},
120		{
121			"with-ext-tgz",
122			".tgz",
123			false,
124		},
125	}
126
127	for _, tc := range cases {
128		tc := tc
129
130		t.Run(tc.name, func(t *testing.T) {
131			t.Parallel()
132
133			// Create temp dirs for each test case since os.Stat and tgz.Walk
134			// (called down below) exhibits raciness otherwise.
135			testDir, err := ioutil.TempDir("", "vault-debug")
136			if err != nil {
137				t.Fatal(err)
138			}
139			defer os.RemoveAll(testDir)
140
141			client, closer := testVaultServer(t)
142			defer closer()
143
144			ui, cmd := testDebugCommand(t)
145			cmd.client = client
146			cmd.skipTimingChecks = true
147
148			// We use tc.name as the base path and apply the extension per
149			// test case.
150			basePath := tc.name
151			outputPath := filepath.Join(testDir, basePath+tc.ext)
152			args := []string{
153				"-duration=1s",
154				fmt.Sprintf("-output=%s", outputPath),
155				"-target=server-status",
156			}
157
158			code := cmd.Run(args)
159			if exp := 0; code != exp {
160				t.Log(ui.OutputWriter.String())
161				t.Log(ui.ErrorWriter.String())
162				t.Fatalf("expected %d to be %d", code, exp)
163			}
164			// If we expect an error we're done here
165			if tc.expectError {
166				return
167			}
168
169			expectedExt := tc.ext
170			if expectedExt == "" {
171				expectedExt = debugCompressionExt
172			}
173
174			bundlePath := filepath.Join(testDir, basePath+expectedExt)
175			_, err = os.Stat(bundlePath)
176			if os.IsNotExist(err) {
177				t.Log(ui.OutputWriter.String())
178				t.Fatal(err)
179			}
180
181			tgz := archiver.NewTarGz()
182			err = tgz.Walk(bundlePath, func(f archiver.File) error {
183				fh, ok := f.Header.(*tar.Header)
184				if !ok {
185					t.Fatalf("invalid file header: %#v", f.Header)
186				}
187
188				// Ignore base directory and index file
189				if fh.Name == basePath+"/" || fh.Name == filepath.Join(basePath, "index.json") {
190					return nil
191				}
192
193				if fh.Name != filepath.Join(basePath, "server_status.json") {
194					t.Fatalf("unxexpected file: %s", fh.Name)
195				}
196				return nil
197			})
198		})
199	}
200}
201
202func TestDebugCommand_CaptureTargets(t *testing.T) {
203	t.Parallel()
204
205	cases := []struct {
206		name          string
207		targets       []string
208		expectedFiles []string
209	}{
210		{
211			"config",
212			[]string{"config"},
213			[]string{"config.json"},
214		},
215		{
216			"host-info",
217			[]string{"host"},
218			[]string{"host_info.json"},
219		},
220		{
221			"metrics",
222			[]string{"metrics"},
223			[]string{"metrics.json"},
224		},
225		{
226			"replication-status",
227			[]string{"replication-status"},
228			[]string{"replication_status.json"},
229		},
230		{
231			"server-status",
232			[]string{"server-status"},
233			[]string{"server_status.json"},
234		},
235		{
236			"all-minus-pprof",
237			[]string{"config", "host", "metrics", "replication-status", "server-status"},
238			[]string{"config.json", "host_info.json", "metrics.json", "replication_status.json", "server_status.json"},
239		},
240	}
241
242	for _, tc := range cases {
243		tc := tc
244
245		t.Run(tc.name, func(t *testing.T) {
246			t.Parallel()
247
248			testDir, err := ioutil.TempDir("", "vault-debug")
249			if err != nil {
250				t.Fatal(err)
251			}
252			defer os.RemoveAll(testDir)
253
254			client, closer := testVaultServer(t)
255			defer closer()
256
257			ui, cmd := testDebugCommand(t)
258			cmd.client = client
259			cmd.skipTimingChecks = true
260
261			basePath := tc.name
262			args := []string{
263				"-duration=1s",
264				fmt.Sprintf("-output=%s/%s", testDir, basePath),
265			}
266			for _, target := range tc.targets {
267				args = append(args, fmt.Sprintf("-target=%s", target))
268			}
269
270			code := cmd.Run(args)
271			if exp := 0; code != exp {
272				t.Log(ui.ErrorWriter.String())
273				t.Fatalf("expected %d to be %d", code, exp)
274			}
275
276			bundlePath := filepath.Join(testDir, basePath+debugCompressionExt)
277			_, err = os.Open(bundlePath)
278			if err != nil {
279				t.Fatalf("failed to open archive: %s", err)
280			}
281
282			tgz := archiver.NewTarGz()
283			err = tgz.Walk(bundlePath, func(f archiver.File) error {
284				fh, ok := f.Header.(*tar.Header)
285				if !ok {
286					t.Fatalf("invalid file header: %#v", f.Header)
287				}
288
289				// Ignore base directory and index file
290				if fh.Name == basePath+"/" || fh.Name == filepath.Join(basePath, "index.json") {
291					return nil
292				}
293
294				for _, fileName := range tc.expectedFiles {
295					if fh.Name == filepath.Join(basePath, fileName) {
296						return nil
297					}
298				}
299
300				// If we reach here, it means that this is an unexpected file
301				return fmt.Errorf("unexpected file: %s", fh.Name)
302			})
303			if err != nil {
304				t.Fatal(err)
305			}
306		})
307	}
308}
309
310func TestDebugCommand_Pprof(t *testing.T) {
311	testDir, err := ioutil.TempDir("", "vault-debug")
312	if err != nil {
313		t.Fatal(err)
314	}
315	defer os.RemoveAll(testDir)
316
317	client, closer := testVaultServer(t)
318	defer closer()
319
320	ui, cmd := testDebugCommand(t)
321	cmd.client = client
322	cmd.skipTimingChecks = true
323
324	basePath := "pprof"
325	outputPath := filepath.Join(testDir, basePath)
326	// pprof requires a minimum interval of 1s, we set it to 2 to ensure it
327	// runs through and reduce flakiness on slower systems.
328	args := []string{
329		"-compress=false",
330		"-duration=2s",
331		"-interval=2s",
332		fmt.Sprintf("-output=%s", outputPath),
333		"-target=pprof",
334	}
335
336	code := cmd.Run(args)
337	if exp := 0; code != exp {
338		t.Log(ui.ErrorWriter.String())
339		t.Fatalf("expected %d to be %d", code, exp)
340	}
341
342	profiles := []string{"heap.prof", "goroutine.prof"}
343	pollingProfiles := []string{"profile.prof", "trace.out"}
344
345	// These are captures on the first (0th) and last (1st) frame
346	for _, v := range profiles {
347		files, _ := filepath.Glob(fmt.Sprintf("%s/*/%s", outputPath, v))
348		if len(files) != 2 {
349			t.Errorf("2 output files should exist for %s: got: %v", v, files)
350		}
351	}
352
353	// Since profile and trace are polling outputs, these only get captured
354	// on the first (0th) frame.
355	for _, v := range pollingProfiles {
356		files, _ := filepath.Glob(fmt.Sprintf("%s/*/%s", outputPath, v))
357		if len(files) != 1 {
358			t.Errorf("1 output file should exist for %s: got: %v", v, files)
359		}
360	}
361
362	t.Log(ui.OutputWriter.String())
363	t.Log(ui.ErrorWriter.String())
364}
365
366func TestDebugCommand_IndexFile(t *testing.T) {
367	t.Parallel()
368
369	testDir, err := ioutil.TempDir("", "vault-debug")
370	if err != nil {
371		t.Fatal(err)
372	}
373	defer os.RemoveAll(testDir)
374
375	client, closer := testVaultServer(t)
376	defer closer()
377
378	ui, cmd := testDebugCommand(t)
379	cmd.client = client
380	cmd.skipTimingChecks = true
381
382	basePath := "index-test"
383	outputPath := filepath.Join(testDir, basePath)
384	// pprof requires a minimum interval of 1s
385	args := []string{
386		"-compress=false",
387		"-duration=1s",
388		"-interval=1s",
389		"-metrics-interval=1s",
390		fmt.Sprintf("-output=%s", outputPath),
391	}
392
393	code := cmd.Run(args)
394	if exp := 0; code != exp {
395		t.Log(ui.ErrorWriter.String())
396		t.Fatalf("expected %d to be %d", code, exp)
397	}
398
399	content, err := ioutil.ReadFile(filepath.Join(outputPath, "index.json"))
400	if err != nil {
401		t.Fatal(err)
402	}
403
404	index := &debugIndex{}
405	if err := json.Unmarshal(content, index); err != nil {
406		t.Fatal(err)
407	}
408	if len(index.Output) == 0 {
409		t.Fatalf("expected valid index file: got: %v", index)
410	}
411}
412
413func TestDebugCommand_TimingChecks(t *testing.T) {
414	t.Parallel()
415
416	testDir, err := ioutil.TempDir("", "vault-debug")
417	if err != nil {
418		t.Fatal(err)
419	}
420	defer os.RemoveAll(testDir)
421
422	cases := []struct {
423		name            string
424		duration        string
425		interval        string
426		metricsInterval string
427	}{
428		{
429			"short-values-all",
430			"10ms",
431			"10ms",
432			"10ms",
433		},
434		{
435			"short-duration",
436			"10ms",
437			"",
438			"",
439		},
440		{
441			"short-interval",
442			debugMinInterval.String(),
443			"10ms",
444			"",
445		},
446		{
447			"short-metrics-interval",
448			debugMinInterval.String(),
449			"",
450			"10ms",
451		},
452	}
453
454	for _, tc := range cases {
455		tc := tc
456
457		t.Run(tc.name, func(t *testing.T) {
458			t.Parallel()
459
460			client, closer := testVaultServer(t)
461			defer closer()
462
463			// If we are past the minimum duration + some grace, trigger shutdown
464			// to prevent hanging
465			grace := 10 * time.Second
466			shutdownCh := make(chan struct{})
467			go func() {
468				time.AfterFunc(grace, func() {
469					close(shutdownCh)
470				})
471			}()
472
473			ui, cmd := testDebugCommand(t)
474			cmd.client = client
475			cmd.ShutdownCh = shutdownCh
476
477			basePath := tc.name
478			outputPath := filepath.Join(testDir, basePath)
479			// pprof requires a minimum interval of 1s
480			args := []string{
481				"-target=server-status",
482				fmt.Sprintf("-output=%s", outputPath),
483			}
484			if tc.duration != "" {
485				args = append(args, fmt.Sprintf("-duration=%s", tc.duration))
486			}
487			if tc.interval != "" {
488				args = append(args, fmt.Sprintf("-interval=%s", tc.interval))
489			}
490			if tc.metricsInterval != "" {
491				args = append(args, fmt.Sprintf("-metrics-interval=%s", tc.metricsInterval))
492			}
493
494			code := cmd.Run(args)
495			if exp := 0; code != exp {
496				t.Log(ui.ErrorWriter.String())
497				t.Fatalf("expected %d to be %d", code, exp)
498			}
499
500			if !strings.Contains(ui.OutputWriter.String(), "Duration: 5s") {
501				t.Fatal("expected minimum duration value")
502			}
503
504			if tc.interval != "" {
505				if !strings.Contains(ui.OutputWriter.String(), "  Interval: 5s") {
506					t.Fatal("expected minimum interval value")
507				}
508			}
509
510			if tc.metricsInterval != "" {
511				if !strings.Contains(ui.OutputWriter.String(), "Metrics Interval: 5s") {
512					t.Fatal("expected minimum metrics interval value")
513				}
514			}
515		})
516	}
517}
518
519func TestDebugCommand_NoConnection(t *testing.T) {
520	t.Parallel()
521
522	client, err := api.NewClient(nil)
523	if err != nil {
524		t.Fatal(err)
525	}
526
527	_, cmd := testDebugCommand(t)
528	cmd.client = client
529	cmd.skipTimingChecks = true
530
531	args := []string{
532		"-duration=1s",
533		"-target=server-status",
534	}
535
536	code := cmd.Run(args)
537	if exp := 1; code != exp {
538		t.Fatalf("expected %d to be %d", code, exp)
539	}
540}
541
542func TestDebugCommand_OutputExists(t *testing.T) {
543	t.Parallel()
544
545	cases := []struct {
546		name          string
547		compress      bool
548		outputFile    string
549		expectedError string
550	}{
551		{
552			"no-compress",
553			false,
554			"output-exists",
555			"output directory already exists",
556		},
557		{
558			"compress",
559			true,
560			"output-exist.tar.gz",
561			"output file already exists",
562		},
563	}
564
565	for _, tc := range cases {
566		tc := tc
567
568		t.Run(tc.name, func(t *testing.T) {
569			t.Parallel()
570
571			testDir, err := ioutil.TempDir("", "vault-debug")
572			if err != nil {
573				t.Fatal(err)
574			}
575			defer os.RemoveAll(testDir)
576
577			client, closer := testVaultServer(t)
578			defer closer()
579
580			ui, cmd := testDebugCommand(t)
581			cmd.client = client
582			cmd.skipTimingChecks = true
583
584			outputPath := filepath.Join(testDir, tc.outputFile)
585
586			// Create a conflicting file/directory
587			if tc.compress {
588				_, err = os.Create(outputPath)
589				if err != nil {
590					t.Fatal(err)
591				}
592			} else {
593				err = os.Mkdir(outputPath, 0o755)
594				if err != nil {
595					t.Fatal(err)
596				}
597			}
598
599			args := []string{
600				fmt.Sprintf("-compress=%t", tc.compress),
601				"-duration=1s",
602				"-interval=1s",
603				"-metrics-interval=1s",
604				fmt.Sprintf("-output=%s", outputPath),
605			}
606
607			code := cmd.Run(args)
608			if exp := 1; code != exp {
609				t.Log(ui.OutputWriter.String())
610				t.Log(ui.ErrorWriter.String())
611				t.Errorf("expected %d to be %d", code, exp)
612			}
613
614			output := ui.ErrorWriter.String() + ui.OutputWriter.String()
615			if !strings.Contains(output, tc.expectedError) {
616				t.Fatalf("expected %s, got: %s", tc.expectedError, output)
617			}
618		})
619	}
620}
621
622func TestDebugCommand_PartialPermissions(t *testing.T) {
623	t.Parallel()
624
625	testDir, err := ioutil.TempDir("", "vault-debug")
626	if err != nil {
627		t.Fatal(err)
628	}
629	defer os.RemoveAll(testDir)
630
631	client, closer := testVaultServer(t)
632	defer closer()
633
634	// Create a new token with default policy
635	resp, err := client.Logical().Write("auth/token/create", map[string]interface{}{
636		"policies": "default",
637	})
638	if err != nil {
639		t.Fatal(err)
640	}
641
642	client.SetToken(resp.Auth.ClientToken)
643
644	ui, cmd := testDebugCommand(t)
645	cmd.client = client
646	cmd.skipTimingChecks = true
647
648	basePath := "with-default-policy-token"
649	args := []string{
650		"-duration=1s",
651		fmt.Sprintf("-output=%s/%s", testDir, basePath),
652	}
653
654	code := cmd.Run(args)
655	if exp := 0; code != exp {
656		t.Log(ui.ErrorWriter.String())
657		t.Fatalf("expected %d to be %d", code, exp)
658	}
659
660	bundlePath := filepath.Join(testDir, basePath+debugCompressionExt)
661	_, err = os.Open(bundlePath)
662	if err != nil {
663		t.Fatalf("failed to open archive: %s", err)
664	}
665
666	tgz := archiver.NewTarGz()
667	err = tgz.Walk(bundlePath, func(f archiver.File) error {
668		fh, ok := f.Header.(*tar.Header)
669		if !ok {
670			t.Fatalf("invalid file header: %#v", f.Header)
671		}
672
673		// Ignore base directory and index file
674		if fh.Name == basePath+"/" {
675			return nil
676		}
677
678		// Ignore directories, which still get created by pprof but should
679		// otherwise be empty.
680		if fh.FileInfo().IsDir() {
681			return nil
682		}
683
684		switch {
685		case fh.Name == filepath.Join(basePath, "index.json"):
686		case fh.Name == filepath.Join(basePath, "replication_status.json"):
687		case fh.Name == filepath.Join(basePath, "server_status.json"):
688		case fh.Name == filepath.Join(basePath, "vault.log"):
689		default:
690			return fmt.Errorf("unexpected file: %s", fh.Name)
691		}
692
693		return nil
694	})
695	if err != nil {
696		t.Fatal(err)
697	}
698}
699