1package command 2 3import ( 4 "archive/tar" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "os" 9 "path/filepath" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/hashicorp/vault/api" 15 "github.com/mholt/archiver" 16 "github.com/mitchellh/cli" 17) 18 19func testDebugCommand(tb testing.TB) (*cli.MockUi, *DebugCommand) { 20 tb.Helper() 21 22 ui := cli.NewMockUi() 23 return ui, &DebugCommand{ 24 BaseCommand: &BaseCommand{ 25 UI: ui, 26 }, 27 } 28} 29 30func TestDebugCommand_Run(t *testing.T) { 31 t.Parallel() 32 33 testDir, err := ioutil.TempDir("", "vault-debug") 34 if err != nil { 35 t.Fatal(err) 36 } 37 defer os.RemoveAll(testDir) 38 39 cases := []struct { 40 name string 41 args []string 42 out string 43 code int 44 }{ 45 { 46 "valid", 47 []string{ 48 "-duration=1s", 49 fmt.Sprintf("-output=%s/valid", testDir), 50 }, 51 "", 52 0, 53 }, 54 { 55 "too_many_args", 56 []string{ 57 "-duration=1s", 58 fmt.Sprintf("-output=%s/too_many_args", testDir), 59 "foo", 60 }, 61 "Too many arguments", 62 1, 63 }, 64 { 65 "invalid_target", 66 []string{ 67 "-duration=1s", 68 fmt.Sprintf("-output=%s/invalid_target", testDir), 69 "-target=foo", 70 }, 71 "Ignoring invalid targets: foo", 72 0, 73 }, 74 } 75 76 for _, tc := range cases { 77 tc := tc 78 79 t.Run(tc.name, func(t *testing.T) { 80 t.Parallel() 81 82 client, closer := testVaultServer(t) 83 defer closer() 84 85 ui, cmd := testDebugCommand(t) 86 cmd.client = client 87 cmd.skipTimingChecks = true 88 89 code := cmd.Run(tc.args) 90 if code != tc.code { 91 t.Errorf("expected %d to be %d", code, tc.code) 92 } 93 94 combined := ui.OutputWriter.String() + ui.ErrorWriter.String() 95 if !strings.Contains(combined, tc.out) { 96 t.Fatalf("expected %q to contain %q", combined, tc.out) 97 } 98 }) 99 } 100} 101 102func TestDebugCommand_Archive(t *testing.T) { 103 t.Parallel() 104 105 cases := []struct { 106 name string 107 ext string 108 expectError bool 109 }{ 110 { 111 "no-ext", 112 "", 113 false, 114 }, 115 { 116 "with-ext-tar-gz", 117 ".tar.gz", 118 false, 119 }, 120 { 121 "with-ext-tgz", 122 ".tgz", 123 false, 124 }, 125 } 126 127 for _, tc := range cases { 128 tc := tc 129 130 t.Run(tc.name, func(t *testing.T) { 131 t.Parallel() 132 133 // Create temp dirs for each test case since os.Stat and tgz.Walk 134 // (called down below) exhibits raciness otherwise. 135 testDir, err := ioutil.TempDir("", "vault-debug") 136 if err != nil { 137 t.Fatal(err) 138 } 139 defer os.RemoveAll(testDir) 140 141 client, closer := testVaultServer(t) 142 defer closer() 143 144 ui, cmd := testDebugCommand(t) 145 cmd.client = client 146 cmd.skipTimingChecks = true 147 148 // We use tc.name as the base path and apply the extension per 149 // test case. 150 basePath := tc.name 151 outputPath := filepath.Join(testDir, basePath+tc.ext) 152 args := []string{ 153 "-duration=1s", 154 fmt.Sprintf("-output=%s", outputPath), 155 "-target=server-status", 156 } 157 158 code := cmd.Run(args) 159 if exp := 0; code != exp { 160 t.Log(ui.OutputWriter.String()) 161 t.Log(ui.ErrorWriter.String()) 162 t.Fatalf("expected %d to be %d", code, exp) 163 } 164 // If we expect an error we're done here 165 if tc.expectError { 166 return 167 } 168 169 expectedExt := tc.ext 170 if expectedExt == "" { 171 expectedExt = debugCompressionExt 172 } 173 174 bundlePath := filepath.Join(testDir, basePath+expectedExt) 175 _, err = os.Stat(bundlePath) 176 if os.IsNotExist(err) { 177 t.Log(ui.OutputWriter.String()) 178 t.Fatal(err) 179 } 180 181 tgz := archiver.NewTarGz() 182 err = tgz.Walk(bundlePath, func(f archiver.File) error { 183 fh, ok := f.Header.(*tar.Header) 184 if !ok { 185 t.Fatalf("invalid file header: %#v", f.Header) 186 } 187 188 // Ignore base directory and index file 189 if fh.Name == basePath+"/" || fh.Name == filepath.Join(basePath, "index.json") { 190 return nil 191 } 192 193 if fh.Name != filepath.Join(basePath, "server_status.json") { 194 t.Fatalf("unxexpected file: %s", fh.Name) 195 } 196 return nil 197 }) 198 }) 199 } 200} 201 202func TestDebugCommand_CaptureTargets(t *testing.T) { 203 t.Parallel() 204 205 cases := []struct { 206 name string 207 targets []string 208 expectedFiles []string 209 }{ 210 { 211 "config", 212 []string{"config"}, 213 []string{"config.json"}, 214 }, 215 { 216 "host-info", 217 []string{"host"}, 218 []string{"host_info.json"}, 219 }, 220 { 221 "metrics", 222 []string{"metrics"}, 223 []string{"metrics.json"}, 224 }, 225 { 226 "replication-status", 227 []string{"replication-status"}, 228 []string{"replication_status.json"}, 229 }, 230 { 231 "server-status", 232 []string{"server-status"}, 233 []string{"server_status.json"}, 234 }, 235 { 236 "all-minus-pprof", 237 []string{"config", "host", "metrics", "replication-status", "server-status"}, 238 []string{"config.json", "host_info.json", "metrics.json", "replication_status.json", "server_status.json"}, 239 }, 240 } 241 242 for _, tc := range cases { 243 tc := tc 244 245 t.Run(tc.name, func(t *testing.T) { 246 t.Parallel() 247 248 testDir, err := ioutil.TempDir("", "vault-debug") 249 if err != nil { 250 t.Fatal(err) 251 } 252 defer os.RemoveAll(testDir) 253 254 client, closer := testVaultServer(t) 255 defer closer() 256 257 ui, cmd := testDebugCommand(t) 258 cmd.client = client 259 cmd.skipTimingChecks = true 260 261 basePath := tc.name 262 args := []string{ 263 "-duration=1s", 264 fmt.Sprintf("-output=%s/%s", testDir, basePath), 265 } 266 for _, target := range tc.targets { 267 args = append(args, fmt.Sprintf("-target=%s", target)) 268 } 269 270 code := cmd.Run(args) 271 if exp := 0; code != exp { 272 t.Log(ui.ErrorWriter.String()) 273 t.Fatalf("expected %d to be %d", code, exp) 274 } 275 276 bundlePath := filepath.Join(testDir, basePath+debugCompressionExt) 277 _, err = os.Open(bundlePath) 278 if err != nil { 279 t.Fatalf("failed to open archive: %s", err) 280 } 281 282 tgz := archiver.NewTarGz() 283 err = tgz.Walk(bundlePath, func(f archiver.File) error { 284 fh, ok := f.Header.(*tar.Header) 285 if !ok { 286 t.Fatalf("invalid file header: %#v", f.Header) 287 } 288 289 // Ignore base directory and index file 290 if fh.Name == basePath+"/" || fh.Name == filepath.Join(basePath, "index.json") { 291 return nil 292 } 293 294 for _, fileName := range tc.expectedFiles { 295 if fh.Name == filepath.Join(basePath, fileName) { 296 return nil 297 } 298 } 299 300 // If we reach here, it means that this is an unexpected file 301 return fmt.Errorf("unexpected file: %s", fh.Name) 302 }) 303 if err != nil { 304 t.Fatal(err) 305 } 306 }) 307 } 308} 309 310func TestDebugCommand_Pprof(t *testing.T) { 311 testDir, err := ioutil.TempDir("", "vault-debug") 312 if err != nil { 313 t.Fatal(err) 314 } 315 defer os.RemoveAll(testDir) 316 317 client, closer := testVaultServer(t) 318 defer closer() 319 320 ui, cmd := testDebugCommand(t) 321 cmd.client = client 322 cmd.skipTimingChecks = true 323 324 basePath := "pprof" 325 outputPath := filepath.Join(testDir, basePath) 326 // pprof requires a minimum interval of 1s, we set it to 2 to ensure it 327 // runs through and reduce flakiness on slower systems. 328 args := []string{ 329 "-compress=false", 330 "-duration=2s", 331 "-interval=2s", 332 fmt.Sprintf("-output=%s", outputPath), 333 "-target=pprof", 334 } 335 336 code := cmd.Run(args) 337 if exp := 0; code != exp { 338 t.Log(ui.ErrorWriter.String()) 339 t.Fatalf("expected %d to be %d", code, exp) 340 } 341 342 profiles := []string{"heap.prof", "goroutine.prof"} 343 pollingProfiles := []string{"profile.prof", "trace.out"} 344 345 // These are captures on the first (0th) and last (1st) frame 346 for _, v := range profiles { 347 files, _ := filepath.Glob(fmt.Sprintf("%s/*/%s", outputPath, v)) 348 if len(files) != 2 { 349 t.Errorf("2 output files should exist for %s: got: %v", v, files) 350 } 351 } 352 353 // Since profile and trace are polling outputs, these only get captured 354 // on the first (0th) frame. 355 for _, v := range pollingProfiles { 356 files, _ := filepath.Glob(fmt.Sprintf("%s/*/%s", outputPath, v)) 357 if len(files) != 1 { 358 t.Errorf("1 output file should exist for %s: got: %v", v, files) 359 } 360 } 361 362 t.Log(ui.OutputWriter.String()) 363 t.Log(ui.ErrorWriter.String()) 364} 365 366func TestDebugCommand_IndexFile(t *testing.T) { 367 t.Parallel() 368 369 testDir, err := ioutil.TempDir("", "vault-debug") 370 if err != nil { 371 t.Fatal(err) 372 } 373 defer os.RemoveAll(testDir) 374 375 client, closer := testVaultServer(t) 376 defer closer() 377 378 ui, cmd := testDebugCommand(t) 379 cmd.client = client 380 cmd.skipTimingChecks = true 381 382 basePath := "index-test" 383 outputPath := filepath.Join(testDir, basePath) 384 // pprof requires a minimum interval of 1s 385 args := []string{ 386 "-compress=false", 387 "-duration=1s", 388 "-interval=1s", 389 "-metrics-interval=1s", 390 fmt.Sprintf("-output=%s", outputPath), 391 } 392 393 code := cmd.Run(args) 394 if exp := 0; code != exp { 395 t.Log(ui.ErrorWriter.String()) 396 t.Fatalf("expected %d to be %d", code, exp) 397 } 398 399 content, err := ioutil.ReadFile(filepath.Join(outputPath, "index.json")) 400 if err != nil { 401 t.Fatal(err) 402 } 403 404 index := &debugIndex{} 405 if err := json.Unmarshal(content, index); err != nil { 406 t.Fatal(err) 407 } 408 if len(index.Output) == 0 { 409 t.Fatalf("expected valid index file: got: %v", index) 410 } 411} 412 413func TestDebugCommand_TimingChecks(t *testing.T) { 414 t.Parallel() 415 416 testDir, err := ioutil.TempDir("", "vault-debug") 417 if err != nil { 418 t.Fatal(err) 419 } 420 defer os.RemoveAll(testDir) 421 422 cases := []struct { 423 name string 424 duration string 425 interval string 426 metricsInterval string 427 }{ 428 { 429 "short-values-all", 430 "10ms", 431 "10ms", 432 "10ms", 433 }, 434 { 435 "short-duration", 436 "10ms", 437 "", 438 "", 439 }, 440 { 441 "short-interval", 442 debugMinInterval.String(), 443 "10ms", 444 "", 445 }, 446 { 447 "short-metrics-interval", 448 debugMinInterval.String(), 449 "", 450 "10ms", 451 }, 452 } 453 454 for _, tc := range cases { 455 tc := tc 456 457 t.Run(tc.name, func(t *testing.T) { 458 t.Parallel() 459 460 client, closer := testVaultServer(t) 461 defer closer() 462 463 // If we are past the minimum duration + some grace, trigger shutdown 464 // to prevent hanging 465 grace := 10 * time.Second 466 shutdownCh := make(chan struct{}) 467 go func() { 468 time.AfterFunc(grace, func() { 469 close(shutdownCh) 470 }) 471 }() 472 473 ui, cmd := testDebugCommand(t) 474 cmd.client = client 475 cmd.ShutdownCh = shutdownCh 476 477 basePath := tc.name 478 outputPath := filepath.Join(testDir, basePath) 479 // pprof requires a minimum interval of 1s 480 args := []string{ 481 "-target=server-status", 482 fmt.Sprintf("-output=%s", outputPath), 483 } 484 if tc.duration != "" { 485 args = append(args, fmt.Sprintf("-duration=%s", tc.duration)) 486 } 487 if tc.interval != "" { 488 args = append(args, fmt.Sprintf("-interval=%s", tc.interval)) 489 } 490 if tc.metricsInterval != "" { 491 args = append(args, fmt.Sprintf("-metrics-interval=%s", tc.metricsInterval)) 492 } 493 494 code := cmd.Run(args) 495 if exp := 0; code != exp { 496 t.Log(ui.ErrorWriter.String()) 497 t.Fatalf("expected %d to be %d", code, exp) 498 } 499 500 if !strings.Contains(ui.OutputWriter.String(), "Duration: 5s") { 501 t.Fatal("expected minimum duration value") 502 } 503 504 if tc.interval != "" { 505 if !strings.Contains(ui.OutputWriter.String(), " Interval: 5s") { 506 t.Fatal("expected minimum interval value") 507 } 508 } 509 510 if tc.metricsInterval != "" { 511 if !strings.Contains(ui.OutputWriter.String(), "Metrics Interval: 5s") { 512 t.Fatal("expected minimum metrics interval value") 513 } 514 } 515 }) 516 } 517} 518 519func TestDebugCommand_NoConnection(t *testing.T) { 520 t.Parallel() 521 522 client, err := api.NewClient(nil) 523 if err != nil { 524 t.Fatal(err) 525 } 526 527 _, cmd := testDebugCommand(t) 528 cmd.client = client 529 cmd.skipTimingChecks = true 530 531 args := []string{ 532 "-duration=1s", 533 "-target=server-status", 534 } 535 536 code := cmd.Run(args) 537 if exp := 1; code != exp { 538 t.Fatalf("expected %d to be %d", code, exp) 539 } 540} 541 542func TestDebugCommand_OutputExists(t *testing.T) { 543 t.Parallel() 544 545 cases := []struct { 546 name string 547 compress bool 548 outputFile string 549 expectedError string 550 }{ 551 { 552 "no-compress", 553 false, 554 "output-exists", 555 "output directory already exists", 556 }, 557 { 558 "compress", 559 true, 560 "output-exist.tar.gz", 561 "output file already exists", 562 }, 563 } 564 565 for _, tc := range cases { 566 tc := tc 567 568 t.Run(tc.name, func(t *testing.T) { 569 t.Parallel() 570 571 testDir, err := ioutil.TempDir("", "vault-debug") 572 if err != nil { 573 t.Fatal(err) 574 } 575 defer os.RemoveAll(testDir) 576 577 client, closer := testVaultServer(t) 578 defer closer() 579 580 ui, cmd := testDebugCommand(t) 581 cmd.client = client 582 cmd.skipTimingChecks = true 583 584 outputPath := filepath.Join(testDir, tc.outputFile) 585 586 // Create a conflicting file/directory 587 if tc.compress { 588 _, err = os.Create(outputPath) 589 if err != nil { 590 t.Fatal(err) 591 } 592 } else { 593 err = os.Mkdir(outputPath, 0o755) 594 if err != nil { 595 t.Fatal(err) 596 } 597 } 598 599 args := []string{ 600 fmt.Sprintf("-compress=%t", tc.compress), 601 "-duration=1s", 602 "-interval=1s", 603 "-metrics-interval=1s", 604 fmt.Sprintf("-output=%s", outputPath), 605 } 606 607 code := cmd.Run(args) 608 if exp := 1; code != exp { 609 t.Log(ui.OutputWriter.String()) 610 t.Log(ui.ErrorWriter.String()) 611 t.Errorf("expected %d to be %d", code, exp) 612 } 613 614 output := ui.ErrorWriter.String() + ui.OutputWriter.String() 615 if !strings.Contains(output, tc.expectedError) { 616 t.Fatalf("expected %s, got: %s", tc.expectedError, output) 617 } 618 }) 619 } 620} 621 622func TestDebugCommand_PartialPermissions(t *testing.T) { 623 t.Parallel() 624 625 testDir, err := ioutil.TempDir("", "vault-debug") 626 if err != nil { 627 t.Fatal(err) 628 } 629 defer os.RemoveAll(testDir) 630 631 client, closer := testVaultServer(t) 632 defer closer() 633 634 // Create a new token with default policy 635 resp, err := client.Logical().Write("auth/token/create", map[string]interface{}{ 636 "policies": "default", 637 }) 638 if err != nil { 639 t.Fatal(err) 640 } 641 642 client.SetToken(resp.Auth.ClientToken) 643 644 ui, cmd := testDebugCommand(t) 645 cmd.client = client 646 cmd.skipTimingChecks = true 647 648 basePath := "with-default-policy-token" 649 args := []string{ 650 "-duration=1s", 651 fmt.Sprintf("-output=%s/%s", testDir, basePath), 652 } 653 654 code := cmd.Run(args) 655 if exp := 0; code != exp { 656 t.Log(ui.ErrorWriter.String()) 657 t.Fatalf("expected %d to be %d", code, exp) 658 } 659 660 bundlePath := filepath.Join(testDir, basePath+debugCompressionExt) 661 _, err = os.Open(bundlePath) 662 if err != nil { 663 t.Fatalf("failed to open archive: %s", err) 664 } 665 666 tgz := archiver.NewTarGz() 667 err = tgz.Walk(bundlePath, func(f archiver.File) error { 668 fh, ok := f.Header.(*tar.Header) 669 if !ok { 670 t.Fatalf("invalid file header: %#v", f.Header) 671 } 672 673 // Ignore base directory and index file 674 if fh.Name == basePath+"/" { 675 return nil 676 } 677 678 // Ignore directories, which still get created by pprof but should 679 // otherwise be empty. 680 if fh.FileInfo().IsDir() { 681 return nil 682 } 683 684 switch { 685 case fh.Name == filepath.Join(basePath, "index.json"): 686 case fh.Name == filepath.Join(basePath, "replication_status.json"): 687 case fh.Name == filepath.Join(basePath, "server_status.json"): 688 case fh.Name == filepath.Join(basePath, "vault.log"): 689 default: 690 return fmt.Errorf("unexpected file: %s", fh.Name) 691 } 692 693 return nil 694 }) 695 if err != nil { 696 t.Fatal(err) 697 } 698} 699