1package command 2 3import ( 4 "os" 5 "sort" 6 "strings" 7 "sync" 8 9 "github.com/hashicorp/vault/api" 10 "github.com/hashicorp/vault/sdk/helper/consts" 11 "github.com/posener/complete" 12) 13 14type Predict struct { 15 client *api.Client 16 clientOnce sync.Once 17} 18 19func NewPredict() *Predict { 20 return &Predict{} 21} 22 23func (p *Predict) Client() *api.Client { 24 p.clientOnce.Do(func() { 25 if p.client == nil { // For tests 26 client, _ := api.NewClient(nil) 27 28 if client.Token() == "" { 29 helper, err := DefaultTokenHelper() 30 if err != nil { 31 return 32 } 33 token, err := helper.Get() 34 if err != nil { 35 return 36 } 37 client.SetToken(token) 38 } 39 40 // Turn off retries for prediction 41 if os.Getenv(api.EnvVaultMaxRetries) == "" { 42 client.SetMaxRetries(0) 43 } 44 45 p.client = client 46 } 47 }) 48 return p.client 49} 50 51// defaultPredictVaultMounts is the default list of mounts to return to the 52// user. This is a best-guess, given we haven't communicated with the Vault 53// server. If the user has no token or if the token does not have the default 54// policy attached, it won't be able to read cubbyhole/, but it's a better UX 55// that returning nothing. 56var defaultPredictVaultMounts = []string{"cubbyhole/"} 57 58// predictClient is the API client to use for prediction. We create this at the 59// beginning once, because completions are generated for each command (and this 60// doesn't change), and the only way to configure the predict/autocomplete 61// client is via environment variables. Even if the user specifies a flag, we 62// can't parse that flag until after the command is submitted. 63var predictClient *api.Client 64var predictClientOnce sync.Once 65 66// PredictClient returns the cached API client for the predictor. 67func PredictClient() *api.Client { 68 predictClientOnce.Do(func() { 69 if predictClient == nil { // For tests 70 predictClient, _ = api.NewClient(nil) 71 } 72 }) 73 return predictClient 74} 75 76// PredictVaultAvailableMounts returns a predictor for the available mounts in 77// Vault. For now, there is no way to programmatically get this list. If, in the 78// future, such a list exists, we can adapt it here. Until then, it's 79// hard-coded. 80func (b *BaseCommand) PredictVaultAvailableMounts() complete.Predictor { 81 // This list does not contain deprecated backends. At present, there is no 82 // API that lists all available secret backends, so this is hard-coded :(. 83 return complete.PredictSet( 84 "aws", 85 "consul", 86 "database", 87 "generic", 88 "pki", 89 "plugin", 90 "rabbitmq", 91 "ssh", 92 "totp", 93 "transit", 94 ) 95} 96 97// PredictVaultAvailableAuths returns a predictor for the available auths in 98// Vault. For now, there is no way to programmatically get this list. If, in the 99// future, such a list exists, we can adapt it here. Until then, it's 100// hard-coded. 101func (b *BaseCommand) PredictVaultAvailableAuths() complete.Predictor { 102 return complete.PredictSet( 103 "app-id", 104 "approle", 105 "aws", 106 "cert", 107 "gcp", 108 "github", 109 "ldap", 110 "okta", 111 "plugin", 112 "radius", 113 "userpass", 114 ) 115} 116 117// PredictVaultFiles returns a predictor for Vault mounts and paths based on the 118// configured client for the base command. Unfortunately this happens pre-flag 119// parsing, so users must rely on environment variables for autocomplete if they 120// are not using Vault at the default endpoints. 121func (b *BaseCommand) PredictVaultFiles() complete.Predictor { 122 return NewPredict().VaultFiles() 123} 124 125// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles 126// for more information and restrictions. 127func (b *BaseCommand) PredictVaultFolders() complete.Predictor { 128 return NewPredict().VaultFolders() 129} 130 131// PredictVaultMounts returns a predictor for "folders". See PredictVaultFiles 132// for more information and restrictions. 133func (b *BaseCommand) PredictVaultMounts() complete.Predictor { 134 return NewPredict().VaultMounts() 135} 136 137// PredictVaultAudits returns a predictor for "folders". See PredictVaultFiles 138// for more information and restrictions. 139func (b *BaseCommand) PredictVaultAudits() complete.Predictor { 140 return NewPredict().VaultAudits() 141} 142 143// PredictVaultAuths returns a predictor for "folders". See PredictVaultFiles 144// for more information and restrictions. 145func (b *BaseCommand) PredictVaultAuths() complete.Predictor { 146 return NewPredict().VaultAuths() 147} 148 149// PredictVaultPlugins returns a predictor for installed plugins. 150func (b *BaseCommand) PredictVaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor { 151 return NewPredict().VaultPlugins(pluginTypes...) 152} 153 154// PredictVaultPolicies returns a predictor for "folders". See PredictVaultFiles 155// for more information and restrictions. 156func (b *BaseCommand) PredictVaultPolicies() complete.Predictor { 157 return NewPredict().VaultPolicies() 158} 159 160func (b *BaseCommand) PredictVaultDebugTargets() complete.Predictor { 161 return complete.PredictSet( 162 "config", 163 "host", 164 "metrics", 165 "pprof", 166 "replication-status", 167 "server-status", 168 ) 169} 170 171// VaultFiles returns a predictor for Vault "files". This is a public API for 172// consumers, but you probably want BaseCommand.PredictVaultFiles instead. 173func (p *Predict) VaultFiles() complete.Predictor { 174 return p.vaultPaths(true) 175} 176 177// VaultFolders returns a predictor for Vault "folders". This is a public 178// API for consumers, but you probably want BaseCommand.PredictVaultFolders 179// instead. 180func (p *Predict) VaultFolders() complete.Predictor { 181 return p.vaultPaths(false) 182} 183 184// VaultMounts returns a predictor for Vault "folders". This is a public 185// API for consumers, but you probably want BaseCommand.PredictVaultMounts 186// instead. 187func (p *Predict) VaultMounts() complete.Predictor { 188 return p.filterFunc(p.mounts) 189} 190 191// VaultAudits returns a predictor for Vault "folders". This is a public API for 192// consumers, but you probably want BaseCommand.PredictVaultAudits instead. 193func (p *Predict) VaultAudits() complete.Predictor { 194 return p.filterFunc(p.audits) 195} 196 197// VaultAuths returns a predictor for Vault "folders". This is a public API for 198// consumers, but you probably want BaseCommand.PredictVaultAuths instead. 199func (p *Predict) VaultAuths() complete.Predictor { 200 return p.filterFunc(p.auths) 201} 202 203// VaultPlugins returns a predictor for Vault's plugin catalog. This is a public 204// API for consumers, but you probably want BaseCommand.PredictVaultPlugins 205// instead. 206func (p *Predict) VaultPlugins(pluginTypes ...consts.PluginType) complete.Predictor { 207 filterFunc := func() []string { 208 return p.plugins(pluginTypes...) 209 } 210 return p.filterFunc(filterFunc) 211} 212 213// VaultPolicies returns a predictor for Vault "folders". This is a public API for 214// consumers, but you probably want BaseCommand.PredictVaultPolicies instead. 215func (p *Predict) VaultPolicies() complete.Predictor { 216 return p.filterFunc(p.policies) 217} 218 219// vaultPaths parses the CLI options and returns the "best" list of possible 220// paths. If there are any errors, this function returns an empty result. All 221// errors are suppressed since this is a prediction function. 222func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { 223 return func(args complete.Args) []string { 224 // Do not predict more than one paths 225 if p.hasPathArg(args.All) { 226 return nil 227 } 228 229 client := p.Client() 230 if client == nil { 231 return nil 232 } 233 234 path := args.Last 235 236 var predictions []string 237 if strings.Contains(path, "/") { 238 predictions = p.paths(path, includeFiles) 239 } else { 240 predictions = p.filter(p.mounts(), path) 241 } 242 243 // Either no results or many results, so return. 244 if len(predictions) != 1 { 245 return predictions 246 } 247 248 // If this is not a "folder", do not try to recurse. 249 if !strings.HasSuffix(predictions[0], "/") { 250 return predictions 251 } 252 253 // If the prediction is the same as the last guess, return it (we have no 254 // new information and we won't get anymore). 255 if predictions[0] == args.Last { 256 return predictions 257 } 258 259 // Re-predict with the remaining path 260 args.Last = predictions[0] 261 return p.vaultPaths(includeFiles).Predict(args) 262 } 263} 264 265// paths predicts all paths which start with the given path. 266func (p *Predict) paths(path string, includeFiles bool) []string { 267 client := p.Client() 268 if client == nil { 269 return nil 270 } 271 272 // Vault does not support listing based on a sub-key, so we have to back-pedal 273 // to the last "/" and return all paths on that "folder". Then we perform 274 // client-side filtering. 275 root := path 276 idx := strings.LastIndex(root, "/") 277 if idx > 0 && idx < len(root) { 278 root = root[:idx+1] 279 } 280 281 paths := p.listPaths(root) 282 283 var predictions []string 284 for _, p := range paths { 285 // Calculate the absolute "path" for matching. 286 p = root + p 287 288 if strings.HasPrefix(p, path) { 289 // Ensure this is a directory or we've asked to include files. 290 if includeFiles || strings.HasSuffix(p, "/") { 291 predictions = append(predictions, p) 292 } 293 } 294 } 295 296 // Add root to the path 297 if len(predictions) == 0 { 298 predictions = append(predictions, path) 299 } 300 301 return predictions 302} 303 304// audits returns a sorted list of the audit backends for Vault server for 305// which the client is configured to communicate with. 306func (p *Predict) audits() []string { 307 client := p.Client() 308 if client == nil { 309 return nil 310 } 311 312 audits, err := client.Sys().ListAudit() 313 if err != nil { 314 return nil 315 } 316 317 list := make([]string, 0, len(audits)) 318 for m := range audits { 319 list = append(list, m) 320 } 321 sort.Strings(list) 322 return list 323} 324 325// auths returns a sorted list of the enabled auth provides for Vault server for 326// which the client is configured to communicate with. 327func (p *Predict) auths() []string { 328 client := p.Client() 329 if client == nil { 330 return nil 331 } 332 333 auths, err := client.Sys().ListAuth() 334 if err != nil { 335 return nil 336 } 337 338 list := make([]string, 0, len(auths)) 339 for m := range auths { 340 list = append(list, m) 341 } 342 sort.Strings(list) 343 return list 344} 345 346// plugins returns a sorted list of the plugins in the catalog. 347func (p *Predict) plugins(pluginTypes ...consts.PluginType) []string { 348 // This method's signature doesn't enforce that a pluginType must be passed in. 349 // If it's not, it's likely the caller's intent is go get a list of all of them, 350 // so let's help them out. 351 if len(pluginTypes) == 0 { 352 pluginTypes = append(pluginTypes, consts.PluginTypeUnknown) 353 } 354 355 client := p.Client() 356 if client == nil { 357 return nil 358 } 359 360 var plugins []string 361 pluginsAdded := make(map[string]bool) 362 for _, pluginType := range pluginTypes { 363 result, err := client.Sys().ListPlugins(&api.ListPluginsInput{Type: pluginType}) 364 if err != nil { 365 return nil 366 } 367 if result == nil { 368 return nil 369 } 370 for _, names := range result.PluginsByType { 371 for _, name := range names { 372 if _, ok := pluginsAdded[name]; !ok { 373 plugins = append(plugins, name) 374 pluginsAdded[name] = true 375 } 376 } 377 } 378 } 379 sort.Strings(plugins) 380 return plugins 381} 382 383// policies returns a sorted list of the policies stored in this Vault 384// server. 385func (p *Predict) policies() []string { 386 client := p.Client() 387 if client == nil { 388 return nil 389 } 390 391 policies, err := client.Sys().ListPolicies() 392 if err != nil { 393 return nil 394 } 395 sort.Strings(policies) 396 return policies 397} 398 399// mounts returns a sorted list of the mount paths for Vault server for 400// which the client is configured to communicate with. This function returns the 401// default list of mounts if an error occurs. 402func (p *Predict) mounts() []string { 403 client := p.Client() 404 if client == nil { 405 return nil 406 } 407 408 mounts, err := client.Sys().ListMounts() 409 if err != nil { 410 return defaultPredictVaultMounts 411 } 412 413 list := make([]string, 0, len(mounts)) 414 for m := range mounts { 415 list = append(list, m) 416 } 417 sort.Strings(list) 418 return list 419} 420 421// listPaths returns a list of paths (HTTP LIST) for the given path. This 422// function returns an empty list of any errors occur. 423func (p *Predict) listPaths(path string) []string { 424 client := p.Client() 425 if client == nil { 426 return nil 427 } 428 429 secret, err := client.Logical().List(path) 430 if err != nil || secret == nil || secret.Data == nil { 431 return nil 432 } 433 434 paths, ok := secret.Data["keys"].([]interface{}) 435 if !ok { 436 return nil 437 } 438 439 list := make([]string, 0, len(paths)) 440 for _, p := range paths { 441 if str, ok := p.(string); ok { 442 list = append(list, str) 443 } 444 } 445 sort.Strings(list) 446 return list 447} 448 449// hasPathArg determines if the args have already accepted a path. 450func (p *Predict) hasPathArg(args []string) bool { 451 var nonFlags []string 452 for _, a := range args { 453 if !strings.HasPrefix(a, "-") { 454 nonFlags = append(nonFlags, a) 455 } 456 } 457 458 return len(nonFlags) > 2 459} 460 461// filterFunc is used to compose a complete predictor that filters an array 462// of strings as per the filter function. 463func (p *Predict) filterFunc(f func() []string) complete.Predictor { 464 return complete.PredictFunc(func(args complete.Args) []string { 465 if p.hasPathArg(args.All) { 466 return nil 467 } 468 469 client := p.Client() 470 if client == nil { 471 return nil 472 } 473 474 return p.filter(f(), args.Last) 475 }) 476} 477 478// filter filters the given list for items that start with the prefix. 479func (p *Predict) filter(list []string, prefix string) []string { 480 var predictions []string 481 for _, item := range list { 482 if strings.HasPrefix(item, prefix) { 483 predictions = append(predictions, item) 484 } 485 } 486 return predictions 487} 488