1package command
2
3import (
4	"reflect"
5	"testing"
6
7	"github.com/hashicorp/vault/api"
8	"github.com/hashicorp/vault/sdk/helper/strutil"
9	"github.com/posener/complete"
10)
11
12func TestPredictVaultPaths(t *testing.T) {
13	t.Parallel()
14
15	client, closer := testVaultServer(t)
16	defer closer()
17
18	data := map[string]interface{}{"a": "b"}
19	if _, err := client.Logical().Write("secret/bar", data); err != nil {
20		t.Fatal(err)
21	}
22	if _, err := client.Logical().Write("secret/foo", data); err != nil {
23		t.Fatal(err)
24	}
25	if _, err := client.Logical().Write("secret/zip/zap", data); err != nil {
26		t.Fatal(err)
27	}
28	if _, err := client.Logical().Write("secret/zip/zonk", data); err != nil {
29		t.Fatal(err)
30	}
31	if _, err := client.Logical().Write("secret/zip/twoot", data); err != nil {
32		t.Fatal(err)
33	}
34
35	cases := []struct {
36		name         string
37		args         complete.Args
38		includeFiles bool
39		exp          []string
40	}{
41		{
42			"has_args",
43			complete.Args{
44				All:  []string{"read", "secret/foo", "a=b"},
45				Last: "a=b",
46			},
47			true,
48			nil,
49		},
50		{
51			"has_args_no_files",
52			complete.Args{
53				All:  []string{"read", "secret/foo", "a=b"},
54				Last: "a=b",
55			},
56			false,
57			nil,
58		},
59		{
60			"part_mount",
61			complete.Args{
62				All:  []string{"read", "s"},
63				Last: "s",
64			},
65			true,
66			[]string{"secret/", "sys/"},
67		},
68		{
69			"part_mount_no_files",
70			complete.Args{
71				All:  []string{"read", "s"},
72				Last: "s",
73			},
74			false,
75			[]string{"secret/", "sys/"},
76		},
77		{
78			"only_mount",
79			complete.Args{
80				All:  []string{"read", "sec"},
81				Last: "sec",
82			},
83			true,
84			[]string{"secret/bar", "secret/foo", "secret/zip/"},
85		},
86		{
87			"only_mount_no_files",
88			complete.Args{
89				All:  []string{"read", "sec"},
90				Last: "sec",
91			},
92			false,
93			[]string{"secret/zip/"},
94		},
95		{
96			"full_mount",
97			complete.Args{
98				All:  []string{"read", "secret"},
99				Last: "secret",
100			},
101			true,
102			[]string{"secret/bar", "secret/foo", "secret/zip/"},
103		},
104		{
105			"full_mount_no_files",
106			complete.Args{
107				All:  []string{"read", "secret"},
108				Last: "secret",
109			},
110			false,
111			[]string{"secret/zip/"},
112		},
113		{
114			"full_mount_slash",
115			complete.Args{
116				All:  []string{"read", "secret/"},
117				Last: "secret/",
118			},
119			true,
120			[]string{"secret/bar", "secret/foo", "secret/zip/"},
121		},
122		{
123			"full_mount_slash_no_files",
124			complete.Args{
125				All:  []string{"read", "secret/"},
126				Last: "secret/",
127			},
128			false,
129			[]string{"secret/zip/"},
130		},
131		{
132			"path_partial",
133			complete.Args{
134				All:  []string{"read", "secret/z"},
135				Last: "secret/z",
136			},
137			true,
138			[]string{"secret/zip/twoot", "secret/zip/zap", "secret/zip/zonk"},
139		},
140		{
141			"path_partial_no_files",
142			complete.Args{
143				All:  []string{"read", "secret/z"},
144				Last: "secret/z",
145			},
146			false,
147			[]string{"secret/zip/"},
148		},
149		{
150			"subpath_partial_z",
151			complete.Args{
152				All:  []string{"read", "secret/zip/z"},
153				Last: "secret/zip/z",
154			},
155			true,
156			[]string{"secret/zip/zap", "secret/zip/zonk"},
157		},
158		{
159			"subpath_partial_z_no_files",
160			complete.Args{
161				All:  []string{"read", "secret/zip/z"},
162				Last: "secret/zip/z",
163			},
164			false,
165			[]string{"secret/zip/z"},
166		},
167		{
168			"subpath_partial_t",
169			complete.Args{
170				All:  []string{"read", "secret/zip/t"},
171				Last: "secret/zip/t",
172			},
173			true,
174			[]string{"secret/zip/twoot"},
175		},
176		{
177			"subpath_partial_t_no_files",
178			complete.Args{
179				All:  []string{"read", "secret/zip/t"},
180				Last: "secret/zip/t",
181			},
182			false,
183			[]string{"secret/zip/t"},
184		},
185	}
186
187	t.Run("group", func(t *testing.T) {
188		for _, tc := range cases {
189			tc := tc
190			t.Run(tc.name, func(t *testing.T) {
191				t.Parallel()
192
193				p := NewPredict()
194				p.client = client
195
196				f := p.vaultPaths(tc.includeFiles)
197				act := f(tc.args)
198				if !reflect.DeepEqual(act, tc.exp) {
199					t.Errorf("expected %q to be %q", act, tc.exp)
200				}
201			})
202		}
203	})
204}
205
206func TestPredict_Audits(t *testing.T) {
207	t.Parallel()
208
209	client, closer := testVaultServer(t)
210	defer closer()
211
212	badClient, badCloser := testVaultServerBad(t)
213	defer badCloser()
214
215	if err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{
216		Type: "file",
217		Options: map[string]string{
218			"file_path": "discard",
219		},
220	}); err != nil {
221		t.Fatal(err)
222	}
223
224	cases := []struct {
225		name   string
226		client *api.Client
227		exp    []string
228	}{
229		{
230			"not_connected_client",
231			badClient,
232			nil,
233		},
234		{
235			"good_path",
236			client,
237			[]string{"file/"},
238		},
239	}
240
241	t.Run("group", func(t *testing.T) {
242		for _, tc := range cases {
243			tc := tc
244			t.Run(tc.name, func(t *testing.T) {
245				t.Parallel()
246
247				p := NewPredict()
248				p.client = tc.client
249
250				act := p.audits()
251				if !reflect.DeepEqual(act, tc.exp) {
252					t.Errorf("expected %q to be %q", act, tc.exp)
253				}
254			})
255		}
256	})
257}
258
259func TestPredict_Mounts(t *testing.T) {
260	t.Parallel()
261
262	client, closer := testVaultServer(t)
263	defer closer()
264
265	badClient, badCloser := testVaultServerBad(t)
266	defer badCloser()
267
268	cases := []struct {
269		name   string
270		client *api.Client
271		exp    []string
272	}{
273		{
274			"not_connected_client",
275			badClient,
276			defaultPredictVaultMounts,
277		},
278		{
279			"good_path",
280			client,
281			[]string{"cubbyhole/", "identity/", "secret/", "sys/"},
282		},
283	}
284
285	t.Run("group", func(t *testing.T) {
286		for _, tc := range cases {
287			tc := tc
288			t.Run(tc.name, func(t *testing.T) {
289				t.Parallel()
290
291				p := NewPredict()
292				p.client = tc.client
293
294				act := p.mounts()
295				if !reflect.DeepEqual(act, tc.exp) {
296					t.Errorf("expected %q to be %q", act, tc.exp)
297				}
298			})
299		}
300	})
301}
302
303func TestPredict_Plugins(t *testing.T) {
304	t.Parallel()
305
306	client, closer := testVaultServer(t)
307	defer closer()
308
309	badClient, badCloser := testVaultServerBad(t)
310	defer badCloser()
311
312	cases := []struct {
313		name   string
314		client *api.Client
315		exp    []string
316	}{
317		{
318			"not_connected_client",
319			badClient,
320			nil,
321		},
322		{
323			"good_path",
324			client,
325			[]string{
326				"ad",
327				"alicloud",
328				"app-id",
329				"approle",
330				"aws",
331				"azure",
332				"cassandra",
333				"cassandra-database-plugin",
334				"centrify",
335				"cert",
336				"consul",
337				"elasticsearch-database-plugin",
338				"gcp",
339				"gcpkms",
340				"github",
341				"hana-database-plugin",
342				"influxdb-database-plugin",
343				"jwt",
344				"kmip",
345				"kubernetes",
346				"kv",
347				"ldap",
348				"mongodb",
349				"mongodb-database-plugin",
350				"mssql",
351				"mssql-database-plugin",
352				"mysql",
353				"mysql-aurora-database-plugin",
354				"mysql-database-plugin",
355				"mysql-legacy-database-plugin",
356				"mysql-rds-database-plugin",
357				"nomad",
358				"oidc",
359				"okta",
360				"pcf",
361				"pki",
362				"postgresql",
363				"postgresql-database-plugin",
364				"rabbitmq",
365				"radius",
366				"ssh",
367				"totp",
368				"transit",
369				"userpass",
370			},
371		},
372	}
373
374	t.Run("group", func(t *testing.T) {
375		for _, tc := range cases {
376			tc := tc
377			t.Run(tc.name, func(t *testing.T) {
378				t.Parallel()
379
380				p := NewPredict()
381				p.client = tc.client
382
383				act := p.plugins()
384
385				if !strutil.StrListContains(act, "kmip") {
386					for i, v := range tc.exp {
387						if v == "kmip" {
388							tc.exp = append(tc.exp[:i], tc.exp[i+1:]...)
389							break
390						}
391					}
392				}
393				if !reflect.DeepEqual(act, tc.exp) {
394					t.Errorf("expected %q to be %q", act, tc.exp)
395				}
396			})
397		}
398	})
399}
400
401func TestPredict_Policies(t *testing.T) {
402	t.Parallel()
403
404	client, closer := testVaultServer(t)
405	defer closer()
406
407	badClient, badCloser := testVaultServerBad(t)
408	defer badCloser()
409
410	cases := []struct {
411		name   string
412		client *api.Client
413		exp    []string
414	}{
415		{
416			"not_connected_client",
417			badClient,
418			nil,
419		},
420		{
421			"good_path",
422			client,
423			[]string{"default", "root"},
424		},
425	}
426
427	t.Run("group", func(t *testing.T) {
428		for _, tc := range cases {
429			tc := tc
430			t.Run(tc.name, func(t *testing.T) {
431				t.Parallel()
432
433				p := NewPredict()
434				p.client = tc.client
435
436				act := p.policies()
437				if !reflect.DeepEqual(act, tc.exp) {
438					t.Errorf("expected %q to be %q", act, tc.exp)
439				}
440			})
441		}
442	})
443}
444
445func TestPredict_Paths(t *testing.T) {
446	t.Parallel()
447
448	client, closer := testVaultServer(t)
449	defer closer()
450
451	data := map[string]interface{}{"a": "b"}
452	if _, err := client.Logical().Write("secret/bar", data); err != nil {
453		t.Fatal(err)
454	}
455	if _, err := client.Logical().Write("secret/foo", data); err != nil {
456		t.Fatal(err)
457	}
458	if _, err := client.Logical().Write("secret/zip/zap", data); err != nil {
459		t.Fatal(err)
460	}
461
462	cases := []struct {
463		name         string
464		path         string
465		includeFiles bool
466		exp          []string
467	}{
468		{
469			"bad_path",
470			"nope/not/a/real/path/ever",
471			true,
472			[]string{"nope/not/a/real/path/ever"},
473		},
474		{
475			"good_path",
476			"secret/",
477			true,
478			[]string{"secret/bar", "secret/foo", "secret/zip/"},
479		},
480		{
481			"good_path_no_files",
482			"secret/",
483			false,
484			[]string{"secret/zip/"},
485		},
486		{
487			"partial_match",
488			"secret/z",
489			true,
490			[]string{"secret/zip/"},
491		},
492		{
493			"partial_match_no_files",
494			"secret/z",
495			false,
496			[]string{"secret/zip/"},
497		},
498	}
499
500	t.Run("group", func(t *testing.T) {
501		for _, tc := range cases {
502			tc := tc
503			t.Run(tc.name, func(t *testing.T) {
504				t.Parallel()
505
506				p := NewPredict()
507				p.client = client
508
509				act := p.paths(tc.path, tc.includeFiles)
510				if !reflect.DeepEqual(act, tc.exp) {
511					t.Errorf("expected %q to be %q", act, tc.exp)
512				}
513			})
514		}
515	})
516}
517
518func TestPredict_ListPaths(t *testing.T) {
519	t.Parallel()
520
521	client, closer := testVaultServer(t)
522	defer closer()
523
524	badClient, badCloser := testVaultServerBad(t)
525	defer badCloser()
526
527	data := map[string]interface{}{"a": "b"}
528	if _, err := client.Logical().Write("secret/bar", data); err != nil {
529		t.Fatal(err)
530	}
531	if _, err := client.Logical().Write("secret/foo", data); err != nil {
532		t.Fatal(err)
533	}
534
535	cases := []struct {
536		name   string
537		client *api.Client
538		path   string
539		exp    []string
540	}{
541		{
542			"bad_path",
543			client,
544			"nope/not/a/real/path/ever",
545			nil,
546		},
547		{
548			"good_path",
549			client,
550			"secret/",
551			[]string{"bar", "foo"},
552		},
553		{
554			"not_connected_client",
555			badClient,
556			"secret/",
557			nil,
558		},
559	}
560
561	t.Run("group", func(t *testing.T) {
562		for _, tc := range cases {
563			tc := tc
564			t.Run(tc.name, func(t *testing.T) {
565				t.Parallel()
566
567				p := NewPredict()
568				p.client = tc.client
569
570				act := p.listPaths(tc.path)
571				if !reflect.DeepEqual(act, tc.exp) {
572					t.Errorf("expected %q to be %q", act, tc.exp)
573				}
574			})
575		}
576	})
577}
578
579func TestPredict_HasPathArg(t *testing.T) {
580	t.Parallel()
581
582	cases := []struct {
583		name string
584		args []string
585		exp  bool
586	}{
587		{
588			"nil",
589			nil,
590			false,
591		},
592		{
593			"empty",
594			[]string{},
595			false,
596		},
597		{
598			"empty_string",
599			[]string{""},
600			false,
601		},
602		{
603			"single",
604			[]string{"foo"},
605			false,
606		},
607		{
608			"multiple",
609			[]string{"foo", "bar", "baz"},
610			true,
611		},
612	}
613
614	for _, tc := range cases {
615		tc := tc
616		t.Run(tc.name, func(t *testing.T) {
617			t.Parallel()
618
619			p := NewPredict()
620			if act := p.hasPathArg(tc.args); act != tc.exp {
621				t.Errorf("expected %t to be %t", act, tc.exp)
622			}
623		})
624	}
625}
626