1package command
2
3import (
4	"os"
5	"sort"
6	"strings"
7	"sync"
8
9	"github.com/hashicorp/vault/api"
10	"github.com/hashicorp/vault/sdk/helper/consts"
11	"github.com/posener/complete"
12)
13
14type Predict struct {
15	client     *api.Client
16	clientOnce sync.Once
17}
18
19func NewPredict() *Predict {
20	return &Predict{}
21}
22
23func (p *Predict) Client() *api.Client {
24	p.clientOnce.Do(func() {
25		if p.client == nil { // For tests
26			client, _ := api.NewClient(nil)
27
28			if client.Token() == "" {
29				helper, err := DefaultTokenHelper()
30				if err != nil {
31					return
32				}
33				token, err := helper.Get()
34				if err != nil {
35					return
36				}
37				client.SetToken(token)
38			}
39
40			// Turn off retries for prediction
41			if os.Getenv(api.EnvVaultMaxRetries) == "" {
42				client.SetMaxRetries(0)
43			}
44
45			p.client = client
46		}
47	})
48	return p.client
49}
50
// defaultPredictVaultMounts is the fallback list of mounts offered to the user
// when the mount list cannot be read from the Vault server. It is only a best
// guess: if the user has no token, or the token lacks the default policy, it
// won't be able to read cubbyhole/ either, but suggesting it is a better UX
// than returning nothing.
var defaultPredictVaultMounts = []string{
	"cubbyhole/",
}
57
58// predictClient is the API client to use for prediction. We create this at the
59// beginning once, because completions are generated for each command (and this
60// doesn't change), and the only way to configure the predict/autocomplete
61// client is via environment variables. Even if the user specifies a flag, we
62// can't parse that flag until after the command is submitted.
63var predictClient *api.Client
64var predictClientOnce sync.Once
65
66// PredictClient returns the cached API client for the predictor.
67func PredictClient() *api.Client {
68	predictClientOnce.Do(func() {
69		if predictClient == nil { // For tests
70			predictClient, _ = api.NewClient(nil)
71		}
72	})
73	return predictClient
74}
75
76// PredictVaultAvailableMounts returns a predictor for the available mounts in
77// Vault. For now, there is no way to programmatically get this list. If, in the
78// future, such a list exists, we can adapt it here. Until then, it's
79// hard-coded.
80func (b *BaseCommand) PredictVaultAvailableMounts() complete.Predictor {
81	// This list does not contain deprecated backends. At present, there is no
82	// API that lists all available secret backends, so this is hard-coded :(.
83	return complete.PredictSet(
84		"aws",
85		"consul",
86		"database",
87		"generic",
88		"pki",
89		"plugin",
90		"rabbitmq",
91		"ssh",
92		"totp",
93		"transit",
94	)
95}
96
97// PredictVaultAvailableAuths returns a predictor for the available auths in
98// Vault. For now, there is no way to programmatically get this list. If, in the
99// future, such a list exists, we can adapt it here. Until then, it's
100// hard-coded.
101func (b *BaseCommand) PredictVaultAvailableAuths() complete.Predictor {
102	return complete.PredictSet(
103		"app-id",
104		"approle",
105		"aws",
106		"cert",
107		"gcp",
108		"github",
109		"ldap",
110		"okta",
111		"plugin",
112		"radius",
113		"userpass",
114	)
115}
116
117// PredictVaultFiles returns a predictor for Vault mounts and paths based on the
118// configured client for the base command. Unfortunately this happens pre-flag
119// parsing, so users must rely on environment variables for autocomplete if they
120// are not using Vault at the default endpoints.
121func (b *BaseCommand) PredictVaultFiles() complete.Predictor {
122	return NewPredict().VaultFiles()
123}
124
125// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles
126// for more information and restrictions.
127func (b *BaseCommand) PredictVaultFolders() complete.Predictor {
128	return NewPredict().VaultFolders()
129}
130
131// PredictVaultMounts returns a predictor for "folders". See PredictVaultFiles
132// for more information and restrictions.
133func (b *BaseCommand) PredictVaultMounts() complete.Predictor {
134	return NewPredict().VaultMounts()
135}
136
137// PredictVaultAudits returns a predictor for "folders". See PredictVaultFiles
138// for more information and restrictions.
139func (b *BaseCommand) PredictVaultAudits() complete.Predictor {
140	return NewPredict().VaultAudits()
141}
142
143// PredictVaultAuths returns a predictor for "folders". See PredictVaultFiles
144// for more information and restrictions.
145func (b *BaseCommand) PredictVaultAuths() complete.Predictor {
146	return NewPredict().VaultAuths()
147}
148
149// PredictVaultPlugins returns a predictor for installed plugins.
150func (b *BaseCommand) PredictVaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor {
151	return NewPredict().VaultPlugins(pluginTypes...)
152}
153
154// PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles
155// for more information and restrictions.
156func (b *BaseCommand) PredictVaultPolicies() complete.Predictor {
157	return NewPredict().VaultPolicies()
158}
159
160func (b *BaseCommand) PredictVaultDebugTargets() complete.Predictor {
161	return complete.PredictSet(
162		"config",
163		"host",
164		"metrics",
165		"pprof",
166		"replication-status",
167		"server-status",
168	)
169}
170
171// VaultFiles returns a predictor for Vault "files". This is a public API for
172// consumers, but you probably want BaseCommand.PredictVaultFiles instead.
173func (p *Predict) VaultFiles() complete.Predictor {
174	return p.vaultPaths(true)
175}
176
177// VaultFolders returns a predictor for Vault "folders". This is a public
178// API for consumers, but you probably want BaseCommand.PredictVaultFolders
179// instead.
180func (p *Predict) VaultFolders() complete.Predictor {
181	return p.vaultPaths(false)
182}
183
184// VaultMounts returns a predictor for Vault "folders". This is a public
185// API for consumers, but you probably want BaseCommand.PredictVaultMounts
186// instead.
187func (p *Predict) VaultMounts() complete.Predictor {
188	return p.filterFunc(p.mounts)
189}
190
191// VaultAudits returns a predictor for Vault "folders". This is a public API for
192// consumers, but you probably want BaseCommand.PredictVaultAudits instead.
193func (p *Predict) VaultAudits() complete.Predictor {
194	return p.filterFunc(p.audits)
195}
196
197// VaultAuths returns a predictor for Vault "folders". This is a public API for
198// consumers, but you probably want BaseCommand.PredictVaultAuths instead.
199func (p *Predict) VaultAuths() complete.Predictor {
200	return p.filterFunc(p.auths)
201}
202
203// VaultPlugins returns a predictor for Vault's plugin catalog. This is a public
204// API for consumers, but you probably want BaseCommand.PredictVaultPlugins
205// instead.
206func (p *Predict) VaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor {
207	filterFunc := func() []string {
208		return p.plugins(pluginTypes...)
209	}
210	return p.filterFunc(filterFunc)
211}
212
213// VaultPolicies returns a predictor for Vault "folders". This is a public API for
214// consumers, but you probably want BaseCommand.PredictVaultPolicies instead.
215func (p *Predict) VaultPolicies() complete.Predictor {
216	return p.filterFunc(p.policies)
217}
218
219// vaultPaths parses the CLI options and returns the "best" list of possible
220// paths. If there are any errors, this function returns an empty result. All
221// errors are suppressed since this is a prediction function.
222func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
223	return func(args complete.Args) []string {
224		// Do not predict more than one paths
225		if p.hasPathArg(args.All) {
226			return nil
227		}
228
229		client := p.Client()
230		if client == nil {
231			return nil
232		}
233
234		path := args.Last
235
236		var predictions []string
237		if strings.Contains(path, "/") {
238			predictions = p.paths(path, includeFiles)
239		} else {
240			predictions = p.filter(p.mounts(), path)
241		}
242
243		// Either no results or many results, so return.
244		if len(predictions) != 1 {
245			return predictions
246		}
247
248		// If this is not a "folder", do not try to recurse.
249		if !strings.HasSuffix(predictions[0], "/") {
250			return predictions
251		}
252
253		// If the prediction is the same as the last guess, return it (we have no
254		// new information and we won't get anymore).
255		if predictions[0] == args.Last {
256			return predictions
257		}
258
259		// Re-predict with the remaining path
260		args.Last = predictions[0]
261		return p.vaultPaths(includeFiles).Predict(args)
262	}
263}
264
265// paths predicts all paths which start with the given path.
266func (p *Predict) paths(path string, includeFiles bool) []string {
267	client := p.Client()
268	if client == nil {
269		return nil
270	}
271
272	// Vault does not support listing based on a sub-key, so we have to back-pedal
273	// to the last "/" and return all paths on that "folder". Then we perform
274	// client-side filtering.
275	root := path
276	idx := strings.LastIndex(root, "/")
277	if idx > 0 && idx < len(root) {
278		root = root[:idx+1]
279	}
280
281	paths := p.listPaths(root)
282
283	var predictions []string
284	for _, p := range paths {
285		// Calculate the absolute "path" for matching.
286		p = root + p
287
288		if strings.HasPrefix(p, path) {
289			// Ensure this is a directory or we've asked to include files.
290			if includeFiles || strings.HasSuffix(p, "/") {
291				predictions = append(predictions, p)
292			}
293		}
294	}
295
296	// Add root to the path
297	if len(predictions) == 0 {
298		predictions = append(predictions, path)
299	}
300
301	return predictions
302}
303
304// audits returns a sorted list of the audit backends for Vault server for
305// which the client is configured to communicate with.
306func (p *Predict) audits() []string {
307	client := p.Client()
308	if client == nil {
309		return nil
310	}
311
312	audits, err := client.Sys().ListAudit()
313	if err != nil {
314		return nil
315	}
316
317	list := make([]string, 0, len(audits))
318	for m := range audits {
319		list = append(list, m)
320	}
321	sort.Strings(list)
322	return list
323}
324
325// auths returns a sorted list of the enabled auth provides for Vault server for
326// which the client is configured to communicate with.
327func (p *Predict) auths() []string {
328	client := p.Client()
329	if client == nil {
330		return nil
331	}
332
333	auths, err := client.Sys().ListAuth()
334	if err != nil {
335		return nil
336	}
337
338	list := make([]string, 0, len(auths))
339	for m := range auths {
340		list = append(list, m)
341	}
342	sort.Strings(list)
343	return list
344}
345
346// plugins returns a sorted list of the plugins in the catalog.
347func (p *Predict) plugins(pluginTypes ...consts.PluginType) []string {
348	// This method's signature doesn't enforce that a pluginType must be passed in.
349	// If it's not, it's likely the caller's intent is go get a list of all of them,
350	// so let's help them out.
351	if len(pluginTypes) == 0 {
352		pluginTypes = append(pluginTypes, consts.PluginTypeUnknown)
353	}
354
355	client := p.Client()
356	if client == nil {
357		return nil
358	}
359
360	var plugins []string
361	pluginsAdded := make(map[string]bool)
362	for _, pluginType := range pluginTypes {
363		result, err := client.Sys().ListPlugins(&api.ListPluginsInput{Type: pluginType})
364		if err != nil {
365			return nil
366		}
367		if result == nil {
368			return nil
369		}
370		for _, names := range result.PluginsByType {
371			for _, name := range names {
372				if _, ok := pluginsAdded[name]; !ok {
373					plugins = append(plugins, name)
374					pluginsAdded[name] = true
375				}
376			}
377		}
378	}
379	sort.Strings(plugins)
380	return plugins
381}
382
383// policies returns a sorted list of the policies stored in this Vault
384// server.
385func (p *Predict) policies() []string {
386	client := p.Client()
387	if client == nil {
388		return nil
389	}
390
391	policies, err := client.Sys().ListPolicies()
392	if err != nil {
393		return nil
394	}
395	sort.Strings(policies)
396	return policies
397}
398
399// mounts returns a sorted list of the mount paths for Vault server for
400// which the client is configured to communicate with. This function returns the
401// default list of mounts if an error occurs.
402func (p *Predict) mounts() []string {
403	client := p.Client()
404	if client == nil {
405		return nil
406	}
407
408	mounts, err := client.Sys().ListMounts()
409	if err != nil {
410		return defaultPredictVaultMounts
411	}
412
413	list := make([]string, 0, len(mounts))
414	for m := range mounts {
415		list = append(list, m)
416	}
417	sort.Strings(list)
418	return list
419}
420
421// listPaths returns a list of paths (HTTP LIST) for the given path. This
422// function returns an empty list of any errors occur.
423func (p *Predict) listPaths(path string) []string {
424	client := p.Client()
425	if client == nil {
426		return nil
427	}
428
429	secret, err := client.Logical().List(path)
430	if err != nil || secret == nil || secret.Data == nil {
431		return nil
432	}
433
434	paths, ok := secret.Data["keys"].([]interface{})
435	if !ok {
436		return nil
437	}
438
439	list := make([]string, 0, len(paths))
440	for _, p := range paths {
441		if str, ok := p.(string); ok {
442			list = append(list, str)
443		}
444	}
445	sort.Strings(list)
446	return list
447}
448
449// hasPathArg determines if the args have already accepted a path.
450func (p *Predict) hasPathArg(args []string) bool {
451	var nonFlags []string
452	for _, a := range args {
453		if !strings.HasPrefix(a, "-") {
454			nonFlags = append(nonFlags, a)
455		}
456	}
457
458	return len(nonFlags) > 2
459}
460
461// filterFunc is used to compose a complete predictor that filters an array
462// of strings as per the filter function.
463func (p *Predict) filterFunc(f func() []string) complete.Predictor {
464	return complete.PredictFunc(func(args complete.Args) []string {
465		if p.hasPathArg(args.All) {
466			return nil
467		}
468
469		client := p.Client()
470		if client == nil {
471			return nil
472		}
473
474		return p.filter(f(), args.Last)
475	})
476}
477
478// filter filters the given list for items that start with the prefix.
479func (p *Predict) filter(list []string, prefix string) []string {
480	var predictions []string
481	for _, item := range list {
482		if strings.HasPrefix(item, prefix) {
483			predictions = append(predictions, item)
484		}
485	}
486	return predictions
487}
488