1package command
2
3import (
4	"os"
5	"sort"
6	"strings"
7	"sync"
8
9	"github.com/hashicorp/vault/api"
10	"github.com/hashicorp/vault/sdk/helper/consts"
11	"github.com/posener/complete"
12)
13
14type Predict struct {
15	client     *api.Client
16	clientOnce sync.Once
17}
18
19func NewPredict() *Predict {
20	return &Predict{}
21}
22
23func (p *Predict) Client() *api.Client {
24	p.clientOnce.Do(func() {
25		if p.client == nil { // For tests
26			client, _ := api.NewClient(nil)
27
28			if client.Token() == "" {
29				helper, err := DefaultTokenHelper()
30				if err != nil {
31					return
32				}
33				token, err := helper.Get()
34				if err != nil {
35					return
36				}
37				client.SetToken(token)
38			}
39
40			// Turn off retries for prediction
41			if os.Getenv(api.EnvVaultMaxRetries) == "" {
42				client.SetMaxRetries(0)
43			}
44
45			p.client = client
46		}
47	})
48	return p.client
49}
50
51// defaultPredictVaultMounts is the default list of mounts to return to the
52// user. This is a best-guess, given we haven't communicated with the Vault
53// server. If the user has no token or if the token does not have the default
54// policy attached, it won't be able to read cubbyhole/, but it's a better UX
55// that returning nothing.
56var defaultPredictVaultMounts = []string{"cubbyhole/"}
57
58// predictClient is the API client to use for prediction. We create this at the
59// beginning once, because completions are generated for each command (and this
60// doesn't change), and the only way to configure the predict/autocomplete
61// client is via environment variables. Even if the user specifies a flag, we
62// can't parse that flag until after the command is submitted.
63var predictClient *api.Client
64var predictClientOnce sync.Once
65
66// PredictClient returns the cached API client for the predictor.
67func PredictClient() *api.Client {
68	predictClientOnce.Do(func() {
69		if predictClient == nil { // For tests
70			predictClient, _ = api.NewClient(nil)
71		}
72	})
73	return predictClient
74}
75
76// PredictVaultAvailableMounts returns a predictor for the available mounts in
77// Vault. For now, there is no way to programmatically get this list. If, in the
78// future, such a list exists, we can adapt it here. Until then, it's
79// hard-coded.
80func (b *BaseCommand) PredictVaultAvailableMounts() complete.Predictor {
81	// This list does not contain deprecated backends. At present, there is no
82	// API that lists all available secret backends, so this is hard-coded :(.
83	return complete.PredictSet(
84		"aws",
85		"consul",
86		"database",
87		"generic",
88		"pki",
89		"plugin",
90		"rabbitmq",
91		"ssh",
92		"totp",
93		"transit",
94	)
95}
96
97// PredictVaultAvailableAuths returns a predictor for the available auths in
98// Vault. For now, there is no way to programmatically get this list. If, in the
99// future, such a list exists, we can adapt it here. Until then, it's
100// hard-coded.
101func (b *BaseCommand) PredictVaultAvailableAuths() complete.Predictor {
102	return complete.PredictSet(
103		"app-id",
104		"approle",
105		"aws",
106		"cert",
107		"gcp",
108		"github",
109		"ldap",
110		"okta",
111		"plugin",
112		"radius",
113		"userpass",
114	)
115}
116
117// PredictVaultFiles returns a predictor for Vault mounts and paths based on the
118// configured client for the base command. Unfortunately this happens pre-flag
119// parsing, so users must rely on environment variables for autocomplete if they
120// are not using Vault at the default endpoints.
121func (b *BaseCommand) PredictVaultFiles() complete.Predictor {
122	return NewPredict().VaultFiles()
123}
124
125// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles
126// for more information and restrictions.
127func (b *BaseCommand) PredictVaultFolders() complete.Predictor {
128	return NewPredict().VaultFolders()
129}
130
131// PredictVaultMounts returns a predictor for "folders". See PredictVaultFiles
132// for more information and restrictions.
133func (b *BaseCommand) PredictVaultMounts() complete.Predictor {
134	return NewPredict().VaultMounts()
135}
136
137// PredictVaultAudits returns a predictor for "folders". See PredictVaultFiles
138// for more information and restrictions.
139func (b *BaseCommand) PredictVaultAudits() complete.Predictor {
140	return NewPredict().VaultAudits()
141}
142
143// PredictVaultAuths returns a predictor for "folders". See PredictVaultFiles
144// for more information and restrictions.
145func (b *BaseCommand) PredictVaultAuths() complete.Predictor {
146	return NewPredict().VaultAuths()
147}
148
149// PredictVaultPlugins returns a predictor for installed plugins.
150func (b *BaseCommand) PredictVaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor {
151	return NewPredict().VaultPlugins(pluginTypes...)
152}
153
154// PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles
155// for more information and restrictions.
156func (b *BaseCommand) PredictVaultPolicies() complete.Predictor {
157	return NewPredict().VaultPolicies()
158}
159
160// VaultFiles returns a predictor for Vault "files". This is a public API for
161// consumers, but you probably want BaseCommand.PredictVaultFiles instead.
162func (p *Predict) VaultFiles() complete.Predictor {
163	return p.vaultPaths(true)
164}
165
166// VaultFolders returns a predictor for Vault "folders". This is a public
167// API for consumers, but you probably want BaseCommand.PredictVaultFolders
168// instead.
169func (p *Predict) VaultFolders() complete.Predictor {
170	return p.vaultPaths(false)
171}
172
173// VaultMounts returns a predictor for Vault "folders". This is a public
174// API for consumers, but you probably want BaseCommand.PredictVaultMounts
175// instead.
176func (p *Predict) VaultMounts() complete.Predictor {
177	return p.filterFunc(p.mounts)
178}
179
180// VaultAudits returns a predictor for Vault "folders". This is a public API for
181// consumers, but you probably want BaseCommand.PredictVaultAudits instead.
182func (p *Predict) VaultAudits() complete.Predictor {
183	return p.filterFunc(p.audits)
184}
185
186// VaultAuths returns a predictor for Vault "folders". This is a public API for
187// consumers, but you probably want BaseCommand.PredictVaultAuths instead.
188func (p *Predict) VaultAuths() complete.Predictor {
189	return p.filterFunc(p.auths)
190}
191
192// VaultPlugins returns a predictor for Vault's plugin catalog. This is a public
193// API for consumers, but you probably want BaseCommand.PredictVaultPlugins
194// instead.
195func (p *Predict) VaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor {
196	filterFunc := func() []string {
197		return p.plugins(pluginTypes...)
198	}
199	return p.filterFunc(filterFunc)
200}
201
202// VaultPolicies returns a predictor for Vault "folders". This is a public API for
203// consumers, but you probably want BaseCommand.PredictVaultPolicies instead.
204func (p *Predict) VaultPolicies() complete.Predictor {
205	return p.filterFunc(p.policies)
206}
207
208// vaultPaths parses the CLI options and returns the "best" list of possible
209// paths. If there are any errors, this function returns an empty result. All
210// errors are suppressed since this is a prediction function.
211func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
212	return func(args complete.Args) []string {
213		// Do not predict more than one paths
214		if p.hasPathArg(args.All) {
215			return nil
216		}
217
218		client := p.Client()
219		if client == nil {
220			return nil
221		}
222
223		path := args.Last
224
225		var predictions []string
226		if strings.Contains(path, "/") {
227			predictions = p.paths(path, includeFiles)
228		} else {
229			predictions = p.filter(p.mounts(), path)
230		}
231
232		// Either no results or many results, so return.
233		if len(predictions) != 1 {
234			return predictions
235		}
236
237		// If this is not a "folder", do not try to recurse.
238		if !strings.HasSuffix(predictions[0], "/") {
239			return predictions
240		}
241
242		// If the prediction is the same as the last guess, return it (we have no
243		// new information and we won't get anymore).
244		if predictions[0] == args.Last {
245			return predictions
246		}
247
248		// Re-predict with the remaining path
249		args.Last = predictions[0]
250		return p.vaultPaths(includeFiles).Predict(args)
251	}
252}
253
254// paths predicts all paths which start with the given path.
255func (p *Predict) paths(path string, includeFiles bool) []string {
256	client := p.Client()
257	if client == nil {
258		return nil
259	}
260
261	// Vault does not support listing based on a sub-key, so we have to back-pedal
262	// to the last "/" and return all paths on that "folder". Then we perform
263	// client-side filtering.
264	root := path
265	idx := strings.LastIndex(root, "/")
266	if idx > 0 && idx < len(root) {
267		root = root[:idx+1]
268	}
269
270	paths := p.listPaths(root)
271
272	var predictions []string
273	for _, p := range paths {
274		// Calculate the absolute "path" for matching.
275		p = root + p
276
277		if strings.HasPrefix(p, path) {
278			// Ensure this is a directory or we've asked to include files.
279			if includeFiles || strings.HasSuffix(p, "/") {
280				predictions = append(predictions, p)
281			}
282		}
283	}
284
285	// Add root to the path
286	if len(predictions) == 0 {
287		predictions = append(predictions, path)
288	}
289
290	return predictions
291}
292
293// audits returns a sorted list of the audit backends for Vault server for
294// which the client is configured to communicate with.
295func (p *Predict) audits() []string {
296	client := p.Client()
297	if client == nil {
298		return nil
299	}
300
301	audits, err := client.Sys().ListAudit()
302	if err != nil {
303		return nil
304	}
305
306	list := make([]string, 0, len(audits))
307	for m := range audits {
308		list = append(list, m)
309	}
310	sort.Strings(list)
311	return list
312}
313
314// auths returns a sorted list of the enabled auth provides for Vault server for
315// which the client is configured to communicate with.
316func (p *Predict) auths() []string {
317	client := p.Client()
318	if client == nil {
319		return nil
320	}
321
322	auths, err := client.Sys().ListAuth()
323	if err != nil {
324		return nil
325	}
326
327	list := make([]string, 0, len(auths))
328	for m := range auths {
329		list = append(list, m)
330	}
331	sort.Strings(list)
332	return list
333}
334
335// plugins returns a sorted list of the plugins in the catalog.
336func (p *Predict) plugins(pluginTypes ...consts.PluginType) []string {
337	// This method's signature doesn't enforce that a pluginType must be passed in.
338	// If it's not, it's likely the caller's intent is go get a list of all of them,
339	// so let's help them out.
340	if len(pluginTypes) == 0 {
341		pluginTypes = append(pluginTypes, consts.PluginTypeUnknown)
342	}
343
344	client := p.Client()
345	if client == nil {
346		return nil
347	}
348
349	var plugins []string
350	pluginsAdded := make(map[string]bool)
351	for _, pluginType := range pluginTypes {
352		result, err := client.Sys().ListPlugins(&api.ListPluginsInput{Type: pluginType})
353		if err != nil {
354			return nil
355		}
356		if result == nil {
357			return nil
358		}
359		for _, names := range result.PluginsByType {
360			for _, name := range names {
361				if _, ok := pluginsAdded[name]; !ok {
362					plugins = append(plugins, name)
363					pluginsAdded[name] = true
364				}
365			}
366		}
367	}
368	sort.Strings(plugins)
369	return plugins
370}
371
372// policies returns a sorted list of the policies stored in this Vault
373// server.
374func (p *Predict) policies() []string {
375	client := p.Client()
376	if client == nil {
377		return nil
378	}
379
380	policies, err := client.Sys().ListPolicies()
381	if err != nil {
382		return nil
383	}
384	sort.Strings(policies)
385	return policies
386}
387
388// mounts returns a sorted list of the mount paths for Vault server for
389// which the client is configured to communicate with. This function returns the
390// default list of mounts if an error occurs.
391func (p *Predict) mounts() []string {
392	client := p.Client()
393	if client == nil {
394		return nil
395	}
396
397	mounts, err := client.Sys().ListMounts()
398	if err != nil {
399		return defaultPredictVaultMounts
400	}
401
402	list := make([]string, 0, len(mounts))
403	for m := range mounts {
404		list = append(list, m)
405	}
406	sort.Strings(list)
407	return list
408}
409
410// listPaths returns a list of paths (HTTP LIST) for the given path. This
411// function returns an empty list of any errors occur.
412func (p *Predict) listPaths(path string) []string {
413	client := p.Client()
414	if client == nil {
415		return nil
416	}
417
418	secret, err := client.Logical().List(path)
419	if err != nil || secret == nil || secret.Data == nil {
420		return nil
421	}
422
423	paths, ok := secret.Data["keys"].([]interface{})
424	if !ok {
425		return nil
426	}
427
428	list := make([]string, 0, len(paths))
429	for _, p := range paths {
430		if str, ok := p.(string); ok {
431			list = append(list, str)
432		}
433	}
434	sort.Strings(list)
435	return list
436}
437
438// hasPathArg determines if the args have already accepted a path.
439func (p *Predict) hasPathArg(args []string) bool {
440	var nonFlags []string
441	for _, a := range args {
442		if !strings.HasPrefix(a, "-") {
443			nonFlags = append(nonFlags, a)
444		}
445	}
446
447	return len(nonFlags) > 2
448}
449
450// filterFunc is used to compose a complete predictor that filters an array
451// of strings as per the filter function.
452func (p *Predict) filterFunc(f func() []string) complete.Predictor {
453	return complete.PredictFunc(func(args complete.Args) []string {
454		if p.hasPathArg(args.All) {
455			return nil
456		}
457
458		client := p.Client()
459		if client == nil {
460			return nil
461		}
462
463		return p.filter(f(), args.Last)
464	})
465}
466
467// filter filters the given list for items that start with the prefix.
468func (p *Predict) filter(list []string, prefix string) []string {
469	var predictions []string
470	for _, item := range list {
471		if strings.HasPrefix(item, prefix) {
472			predictions = append(predictions, item)
473		}
474	}
475	return predictions
476}
477