1package command 2 3import ( 4 "reflect" 5 "testing" 6 7 "github.com/hashicorp/vault/api" 8 "github.com/hashicorp/vault/sdk/helper/strutil" 9 "github.com/posener/complete" 10) 11 12func TestPredictVaultPaths(t *testing.T) { 13 t.Parallel() 14 15 client, closer := testVaultServer(t) 16 defer closer() 17 18 data := map[string]interface{}{"a": "b"} 19 if _, err := client.Logical().Write("secret/bar", data); err != nil { 20 t.Fatal(err) 21 } 22 if _, err := client.Logical().Write("secret/foo", data); err != nil { 23 t.Fatal(err) 24 } 25 if _, err := client.Logical().Write("secret/zip/zap", data); err != nil { 26 t.Fatal(err) 27 } 28 if _, err := client.Logical().Write("secret/zip/zonk", data); err != nil { 29 t.Fatal(err) 30 } 31 if _, err := client.Logical().Write("secret/zip/twoot", data); err != nil { 32 t.Fatal(err) 33 } 34 35 cases := []struct { 36 name string 37 args complete.Args 38 includeFiles bool 39 exp []string 40 }{ 41 { 42 "has_args", 43 complete.Args{ 44 All: []string{"read", "secret/foo", "a=b"}, 45 Last: "a=b", 46 }, 47 true, 48 nil, 49 }, 50 { 51 "has_args_no_files", 52 complete.Args{ 53 All: []string{"read", "secret/foo", "a=b"}, 54 Last: "a=b", 55 }, 56 false, 57 nil, 58 }, 59 { 60 "part_mount", 61 complete.Args{ 62 All: []string{"read", "s"}, 63 Last: "s", 64 }, 65 true, 66 []string{"secret/", "sys/"}, 67 }, 68 { 69 "part_mount_no_files", 70 complete.Args{ 71 All: []string{"read", "s"}, 72 Last: "s", 73 }, 74 false, 75 []string{"secret/", "sys/"}, 76 }, 77 { 78 "only_mount", 79 complete.Args{ 80 All: []string{"read", "sec"}, 81 Last: "sec", 82 }, 83 true, 84 []string{"secret/bar", "secret/foo", "secret/zip/"}, 85 }, 86 { 87 "only_mount_no_files", 88 complete.Args{ 89 All: []string{"read", "sec"}, 90 Last: "sec", 91 }, 92 false, 93 []string{"secret/zip/"}, 94 }, 95 { 96 "full_mount", 97 complete.Args{ 98 All: []string{"read", "secret"}, 99 Last: "secret", 100 }, 101 true, 102 []string{"secret/bar", "secret/foo", "secret/zip/"}, 103 }, 104 { 105 "full_mount_no_files", 106 complete.Args{ 107 All: []string{"read", "secret"}, 108 Last: "secret", 109 }, 110 false, 111 []string{"secret/zip/"}, 112 }, 113 { 114 "full_mount_slash", 115 complete.Args{ 116 All: []string{"read", "secret/"}, 117 Last: "secret/", 118 }, 119 true, 120 []string{"secret/bar", "secret/foo", "secret/zip/"}, 121 }, 122 { 123 "full_mount_slash_no_files", 124 complete.Args{ 125 All: []string{"read", "secret/"}, 126 Last: "secret/", 127 }, 128 false, 129 []string{"secret/zip/"}, 130 }, 131 { 132 "path_partial", 133 complete.Args{ 134 All: []string{"read", "secret/z"}, 135 Last: "secret/z", 136 }, 137 true, 138 []string{"secret/zip/twoot", "secret/zip/zap", "secret/zip/zonk"}, 139 }, 140 { 141 "path_partial_no_files", 142 complete.Args{ 143 All: []string{"read", "secret/z"}, 144 Last: "secret/z", 145 }, 146 false, 147 []string{"secret/zip/"}, 148 }, 149 { 150 "subpath_partial_z", 151 complete.Args{ 152 All: []string{"read", "secret/zip/z"}, 153 Last: "secret/zip/z", 154 }, 155 true, 156 []string{"secret/zip/zap", "secret/zip/zonk"}, 157 }, 158 { 159 "subpath_partial_z_no_files", 160 complete.Args{ 161 All: []string{"read", "secret/zip/z"}, 162 Last: "secret/zip/z", 163 }, 164 false, 165 []string{"secret/zip/z"}, 166 }, 167 { 168 "subpath_partial_t", 169 complete.Args{ 170 All: []string{"read", "secret/zip/t"}, 171 Last: "secret/zip/t", 172 }, 173 true, 174 []string{"secret/zip/twoot"}, 175 }, 176 { 177 "subpath_partial_t_no_files", 178 complete.Args{ 179 All: []string{"read", "secret/zip/t"}, 180 Last: "secret/zip/t", 181 }, 182 false, 183 []string{"secret/zip/t"}, 184 }, 185 } 186 187 t.Run("group", func(t *testing.T) { 188 for _, tc := range cases { 189 tc := tc 190 t.Run(tc.name, func(t *testing.T) { 191 t.Parallel() 192 193 p := NewPredict() 194 p.client = client 195 196 f := p.vaultPaths(tc.includeFiles) 197 act := f(tc.args) 198 if !reflect.DeepEqual(act, tc.exp) { 199 t.Errorf("expected %q to be %q", act, tc.exp) 200 } 201 }) 202 } 203 }) 204} 205 206func TestPredict_Audits(t *testing.T) { 207 t.Parallel() 208 209 client, closer := testVaultServer(t) 210 defer closer() 211 212 badClient, badCloser := testVaultServerBad(t) 213 defer badCloser() 214 215 if err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ 216 Type: "file", 217 Options: map[string]string{ 218 "file_path": "discard", 219 }, 220 }); err != nil { 221 t.Fatal(err) 222 } 223 224 cases := []struct { 225 name string 226 client *api.Client 227 exp []string 228 }{ 229 { 230 "not_connected_client", 231 badClient, 232 nil, 233 }, 234 { 235 "good_path", 236 client, 237 []string{"file/"}, 238 }, 239 } 240 241 t.Run("group", func(t *testing.T) { 242 for _, tc := range cases { 243 tc := tc 244 t.Run(tc.name, func(t *testing.T) { 245 t.Parallel() 246 247 p := NewPredict() 248 p.client = tc.client 249 250 act := p.audits() 251 if !reflect.DeepEqual(act, tc.exp) { 252 t.Errorf("expected %q to be %q", act, tc.exp) 253 } 254 }) 255 } 256 }) 257} 258 259func TestPredict_Mounts(t *testing.T) { 260 t.Parallel() 261 262 client, closer := testVaultServer(t) 263 defer closer() 264 265 badClient, badCloser := testVaultServerBad(t) 266 defer badCloser() 267 268 cases := []struct { 269 name string 270 client *api.Client 271 exp []string 272 }{ 273 { 274 "not_connected_client", 275 badClient, 276 defaultPredictVaultMounts, 277 }, 278 { 279 "good_path", 280 client, 281 []string{"cubbyhole/", "identity/", "secret/", "sys/"}, 282 }, 283 } 284 285 t.Run("group", func(t *testing.T) { 286 for _, tc := range cases { 287 tc := tc 288 t.Run(tc.name, func(t *testing.T) { 289 t.Parallel() 290 291 p := NewPredict() 292 p.client = tc.client 293 294 act := p.mounts() 295 if !reflect.DeepEqual(act, tc.exp) { 296 t.Errorf("expected %q to be %q", act, tc.exp) 297 } 298 }) 299 } 300 }) 301} 302 303func TestPredict_Plugins(t *testing.T) { 304 t.Parallel() 305 306 client, closer := testVaultServer(t) 307 defer closer() 308 309 badClient, badCloser := testVaultServerBad(t) 310 defer badCloser() 311 312 cases := []struct { 313 name string 314 client *api.Client 315 exp []string 316 }{ 317 { 318 "not_connected_client", 319 badClient, 320 nil, 321 }, 322 { 323 "good_path", 324 client, 325 []string{ 326 "ad", 327 "alicloud", 328 "app-id", 329 "approle", 330 "aws", 331 "azure", 332 "cassandra", 333 "cassandra-database-plugin", 334 "centrify", 335 "cert", 336 "consul", 337 "elasticsearch-database-plugin", 338 "gcp", 339 "gcpkms", 340 "github", 341 "hana-database-plugin", 342 "influxdb-database-plugin", 343 "jwt", 344 "kmip", 345 "kubernetes", 346 "kv", 347 "ldap", 348 "mongodb", 349 "mongodb-database-plugin", 350 "mssql", 351 "mssql-database-plugin", 352 "mysql", 353 "mysql-aurora-database-plugin", 354 "mysql-database-plugin", 355 "mysql-legacy-database-plugin", 356 "mysql-rds-database-plugin", 357 "nomad", 358 "oidc", 359 "okta", 360 "pcf", 361 "pki", 362 "postgresql", 363 "postgresql-database-plugin", 364 "rabbitmq", 365 "radius", 366 "ssh", 367 "totp", 368 "transit", 369 "userpass", 370 }, 371 }, 372 } 373 374 t.Run("group", func(t *testing.T) { 375 for _, tc := range cases { 376 tc := tc 377 t.Run(tc.name, func(t *testing.T) { 378 t.Parallel() 379 380 p := NewPredict() 381 p.client = tc.client 382 383 act := p.plugins() 384 385 if !strutil.StrListContains(act, "kmip") { 386 for i, v := range tc.exp { 387 if v == "kmip" { 388 tc.exp = append(tc.exp[:i], tc.exp[i+1:]...) 389 break 390 } 391 } 392 } 393 if !reflect.DeepEqual(act, tc.exp) { 394 t.Errorf("expected %q to be %q", act, tc.exp) 395 } 396 }) 397 } 398 }) 399} 400 401func TestPredict_Policies(t *testing.T) { 402 t.Parallel() 403 404 client, closer := testVaultServer(t) 405 defer closer() 406 407 badClient, badCloser := testVaultServerBad(t) 408 defer badCloser() 409 410 cases := []struct { 411 name string 412 client *api.Client 413 exp []string 414 }{ 415 { 416 "not_connected_client", 417 badClient, 418 nil, 419 }, 420 { 421 "good_path", 422 client, 423 []string{"default", "root"}, 424 }, 425 } 426 427 t.Run("group", func(t *testing.T) { 428 for _, tc := range cases { 429 tc := tc 430 t.Run(tc.name, func(t *testing.T) { 431 t.Parallel() 432 433 p := NewPredict() 434 p.client = tc.client 435 436 act := p.policies() 437 if !reflect.DeepEqual(act, tc.exp) { 438 t.Errorf("expected %q to be %q", act, tc.exp) 439 } 440 }) 441 } 442 }) 443} 444 445func TestPredict_Paths(t *testing.T) { 446 t.Parallel() 447 448 client, closer := testVaultServer(t) 449 defer closer() 450 451 data := map[string]interface{}{"a": "b"} 452 if _, err := client.Logical().Write("secret/bar", data); err != nil { 453 t.Fatal(err) 454 } 455 if _, err := client.Logical().Write("secret/foo", data); err != nil { 456 t.Fatal(err) 457 } 458 if _, err := client.Logical().Write("secret/zip/zap", data); err != nil { 459 t.Fatal(err) 460 } 461 462 cases := []struct { 463 name string 464 path string 465 includeFiles bool 466 exp []string 467 }{ 468 { 469 "bad_path", 470 "nope/not/a/real/path/ever", 471 true, 472 []string{"nope/not/a/real/path/ever"}, 473 }, 474 { 475 "good_path", 476 "secret/", 477 true, 478 []string{"secret/bar", "secret/foo", "secret/zip/"}, 479 }, 480 { 481 "good_path_no_files", 482 "secret/", 483 false, 484 []string{"secret/zip/"}, 485 }, 486 { 487 "partial_match", 488 "secret/z", 489 true, 490 []string{"secret/zip/"}, 491 }, 492 { 493 "partial_match_no_files", 494 "secret/z", 495 false, 496 []string{"secret/zip/"}, 497 }, 498 } 499 500 t.Run("group", func(t *testing.T) { 501 for _, tc := range cases { 502 tc := tc 503 t.Run(tc.name, func(t *testing.T) { 504 t.Parallel() 505 506 p := NewPredict() 507 p.client = client 508 509 act := p.paths(tc.path, tc.includeFiles) 510 if !reflect.DeepEqual(act, tc.exp) { 511 t.Errorf("expected %q to be %q", act, tc.exp) 512 } 513 }) 514 } 515 }) 516} 517 518func TestPredict_ListPaths(t *testing.T) { 519 t.Parallel() 520 521 client, closer := testVaultServer(t) 522 defer closer() 523 524 badClient, badCloser := testVaultServerBad(t) 525 defer badCloser() 526 527 data := map[string]interface{}{"a": "b"} 528 if _, err := client.Logical().Write("secret/bar", data); err != nil { 529 t.Fatal(err) 530 } 531 if _, err := client.Logical().Write("secret/foo", data); err != nil { 532 t.Fatal(err) 533 } 534 535 cases := []struct { 536 name string 537 client *api.Client 538 path string 539 exp []string 540 }{ 541 { 542 "bad_path", 543 client, 544 "nope/not/a/real/path/ever", 545 nil, 546 }, 547 { 548 "good_path", 549 client, 550 "secret/", 551 []string{"bar", "foo"}, 552 }, 553 { 554 "not_connected_client", 555 badClient, 556 "secret/", 557 nil, 558 }, 559 } 560 561 t.Run("group", func(t *testing.T) { 562 for _, tc := range cases { 563 tc := tc 564 t.Run(tc.name, func(t *testing.T) { 565 t.Parallel() 566 567 p := NewPredict() 568 p.client = tc.client 569 570 act := p.listPaths(tc.path) 571 if !reflect.DeepEqual(act, tc.exp) { 572 t.Errorf("expected %q to be %q", act, tc.exp) 573 } 574 }) 575 } 576 }) 577} 578 579func TestPredict_HasPathArg(t *testing.T) { 580 t.Parallel() 581 582 cases := []struct { 583 name string 584 args []string 585 exp bool 586 }{ 587 { 588 "nil", 589 nil, 590 false, 591 }, 592 { 593 "empty", 594 []string{}, 595 false, 596 }, 597 { 598 "empty_string", 599 []string{""}, 600 false, 601 }, 602 { 603 "single", 604 []string{"foo"}, 605 false, 606 }, 607 { 608 "multiple", 609 []string{"foo", "bar", "baz"}, 610 true, 611 }, 612 } 613 614 for _, tc := range cases { 615 tc := tc 616 t.Run(tc.name, func(t *testing.T) { 617 t.Parallel() 618 619 p := NewPredict() 620 if act := p.hasPathArg(tc.args); act != tc.exp { 621 t.Errorf("expected %t to be %t", act, tc.exp) 622 } 623 }) 624 } 625} 626