1package command 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "os" 11 "path" 12 "sort" 13 "strings" 14 "sync" 15 "time" 16 17 "github.com/hashicorp/errwrap" 18 log "github.com/hashicorp/go-hclog" 19 "github.com/hashicorp/vault/api" 20 "github.com/hashicorp/vault/command/agent/auth" 21 "github.com/hashicorp/vault/command/agent/auth/alicloud" 22 "github.com/hashicorp/vault/command/agent/auth/approle" 23 "github.com/hashicorp/vault/command/agent/auth/aws" 24 "github.com/hashicorp/vault/command/agent/auth/azure" 25 "github.com/hashicorp/vault/command/agent/auth/cert" 26 "github.com/hashicorp/vault/command/agent/auth/gcp" 27 "github.com/hashicorp/vault/command/agent/auth/jwt" 28 "github.com/hashicorp/vault/command/agent/auth/kubernetes" 29 "github.com/hashicorp/vault/command/agent/auth/pcf" 30 "github.com/hashicorp/vault/command/agent/cache" 31 "github.com/hashicorp/vault/command/agent/config" 32 "github.com/hashicorp/vault/command/agent/sink" 33 "github.com/hashicorp/vault/command/agent/sink/file" 34 "github.com/hashicorp/vault/command/agent/sink/inmem" 35 gatedwriter "github.com/hashicorp/vault/helper/gated-writer" 36 "github.com/hashicorp/vault/sdk/helper/consts" 37 "github.com/hashicorp/vault/sdk/helper/logging" 38 "github.com/hashicorp/vault/sdk/version" 39 "github.com/kr/pretty" 40 "github.com/mitchellh/cli" 41 "github.com/posener/complete" 42) 43 44var _ cli.Command = (*AgentCommand)(nil) 45var _ cli.CommandAutocomplete = (*AgentCommand)(nil) 46 47type AgentCommand struct { 48 *BaseCommand 49 50 ShutdownCh chan struct{} 51 SighupCh chan struct{} 52 53 logWriter io.Writer 54 logGate *gatedwriter.Writer 55 logger log.Logger 56 57 cleanupGuard sync.Once 58 59 startedCh chan (struct{}) // for tests 60 61 flagConfigs []string 62 flagLogLevel string 63 64 flagTestVerifyOnly bool 65 flagCombineLogs bool 66} 67 68func (c *AgentCommand) Synopsis() string { 69 return "Start a Vault agent" 70} 71 72func (c *AgentCommand) Help() string { 73 helpText := ` 74Usage: vault agent [options] 75 76 This command starts a Vault agent that can perform automatic authentication 77 in certain environments. 78 79 Start an agent with a configuration file: 80 81 $ vault agent -config=/etc/vault/config.hcl 82 83 For a full list of examples, please see the documentation. 84 85` + c.Flags().Help() 86 return strings.TrimSpace(helpText) 87} 88 89func (c *AgentCommand) Flags() *FlagSets { 90 set := c.flagSet(FlagSetHTTP) 91 92 f := set.NewFlagSet("Command Options") 93 94 f.StringSliceVar(&StringSliceVar{ 95 Name: "config", 96 Target: &c.flagConfigs, 97 Completion: complete.PredictOr( 98 complete.PredictFiles("*.hcl"), 99 complete.PredictFiles("*.json"), 100 ), 101 Usage: "Path to a configuration file. This configuration file should " + 102 "contain only agent directives.", 103 }) 104 105 f.StringVar(&StringVar{ 106 Name: "log-level", 107 Target: &c.flagLogLevel, 108 Default: "info", 109 EnvVar: "VAULT_LOG_LEVEL", 110 Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"), 111 Usage: "Log verbosity level. Supported values (in order of detail) are " + 112 "\"trace\", \"debug\", \"info\", \"warn\", and \"err\".", 113 }) 114 115 // Internal-only flags to follow. 116 // 117 // Why hello there little source code reader! Welcome to the Vault source 118 // code. The remaining options are intentionally undocumented and come with 119 // no warranty or backwards-compatibility promise. Do not use these flags 120 // in production. Do not build automation using these flags. Unless you are 121 // developing against Vault, you should not need any of these flags. 122 123 // TODO: should the below flags be public? 124 f.BoolVar(&BoolVar{ 125 Name: "combine-logs", 126 Target: &c.flagCombineLogs, 127 Default: false, 128 Hidden: true, 129 }) 130 131 f.BoolVar(&BoolVar{ 132 Name: "test-verify-only", 133 Target: &c.flagTestVerifyOnly, 134 Default: false, 135 Hidden: true, 136 }) 137 138 // End internal-only flags. 139 140 return set 141} 142 143func (c *AgentCommand) AutocompleteArgs() complete.Predictor { 144 return complete.PredictNothing 145} 146 147func (c *AgentCommand) AutocompleteFlags() complete.Flags { 148 return c.Flags().Completions() 149} 150 151func (c *AgentCommand) Run(args []string) int { 152 f := c.Flags() 153 154 if err := f.Parse(args); err != nil { 155 c.UI.Error(err.Error()) 156 return 1 157 } 158 159 // Create a logger. We wrap it in a gated writer so that it doesn't 160 // start logging too early. 161 c.logGate = &gatedwriter.Writer{Writer: os.Stderr} 162 c.logWriter = c.logGate 163 if c.flagCombineLogs { 164 c.logWriter = os.Stdout 165 } 166 var level log.Level 167 c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel)) 168 switch c.flagLogLevel { 169 case "trace": 170 level = log.Trace 171 case "debug": 172 level = log.Debug 173 case "notice", "info", "": 174 level = log.Info 175 case "warn", "warning": 176 level = log.Warn 177 case "err", "error": 178 level = log.Error 179 default: 180 c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel)) 181 return 1 182 } 183 184 if c.logger == nil { 185 c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level) 186 } 187 188 // Validation 189 if len(c.flagConfigs) != 1 { 190 c.UI.Error("Must specify exactly one config path using -config") 191 return 1 192 } 193 194 // Load the configuration 195 config, err := config.LoadConfig(c.flagConfigs[0], c.logger) 196 if err != nil { 197 c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err)) 198 return 1 199 } 200 201 // Ensure at least one config was found. 202 if config == nil { 203 c.UI.Output(wrapAtLength( 204 "No configuration read. Please provide the configuration with the " + 205 "-config flag.")) 206 return 1 207 } 208 if config.AutoAuth == nil && config.Cache == nil { 209 c.UI.Error("No auto_auth or cache block found in config file") 210 return 1 211 } 212 if config.AutoAuth == nil { 213 c.UI.Info("No auto_auth block found in config file, not starting automatic authentication feature") 214 } 215 216 if config.Vault != nil { 217 c.setStringFlag(f, config.Vault.Address, &StringVar{ 218 Name: flagNameAddress, 219 Target: &c.flagAddress, 220 Default: "https://127.0.0.1:8200", 221 EnvVar: api.EnvVaultAddress, 222 }) 223 c.setStringFlag(f, config.Vault.CACert, &StringVar{ 224 Name: flagNameCACert, 225 Target: &c.flagCACert, 226 Default: "", 227 EnvVar: api.EnvVaultCACert, 228 }) 229 c.setStringFlag(f, config.Vault.CAPath, &StringVar{ 230 Name: flagNameCAPath, 231 Target: &c.flagCAPath, 232 Default: "", 233 EnvVar: api.EnvVaultCAPath, 234 }) 235 c.setStringFlag(f, config.Vault.ClientCert, &StringVar{ 236 Name: flagNameClientCert, 237 Target: &c.flagClientCert, 238 Default: "", 239 EnvVar: api.EnvVaultClientCert, 240 }) 241 c.setStringFlag(f, config.Vault.ClientKey, &StringVar{ 242 Name: flagNameClientKey, 243 Target: &c.flagClientKey, 244 Default: "", 245 EnvVar: api.EnvVaultClientKey, 246 }) 247 c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{ 248 Name: flagNameTLSSkipVerify, 249 Target: &c.flagTLSSkipVerify, 250 Default: false, 251 EnvVar: api.EnvVaultSkipVerify, 252 }) 253 } 254 255 infoKeys := make([]string, 0, 10) 256 info := make(map[string]string) 257 info["log level"] = c.flagLogLevel 258 infoKeys = append(infoKeys, "log level") 259 260 infoKeys = append(infoKeys, "version") 261 verInfo := version.GetVersion() 262 info["version"] = verInfo.FullVersionNumber(false) 263 if verInfo.Revision != "" { 264 info["version sha"] = strings.Trim(verInfo.Revision, "'") 265 infoKeys = append(infoKeys, "version sha") 266 } 267 infoKeys = append(infoKeys, "cgo") 268 info["cgo"] = "disabled" 269 if version.CgoEnabled { 270 info["cgo"] = "enabled" 271 } 272 273 // Tests might not want to start a vault server and just want to verify 274 // the configuration. 275 if c.flagTestVerifyOnly { 276 if os.Getenv("VAULT_TEST_VERIFY_ONLY_DUMP_CONFIG") != "" { 277 c.UI.Output(fmt.Sprintf( 278 "\nConfiguration:\n%s\n", 279 pretty.Sprint(*config))) 280 } 281 return 0 282 } 283 284 // Ignore any setting of agent's address. This client is used by the agent 285 // to reach out to Vault. This should never loop back to agent. 286 c.flagAgentAddress = "" 287 client, err := c.Client() 288 if err != nil { 289 c.UI.Error(fmt.Sprintf( 290 "Error fetching client: %v", 291 err)) 292 return 1 293 } 294 295 ctx, cancelFunc := context.WithCancel(context.Background()) 296 297 var method auth.AuthMethod 298 var sinks []*sink.SinkConfig 299 if config.AutoAuth != nil { 300 for _, sc := range config.AutoAuth.Sinks { 301 switch sc.Type { 302 case "file": 303 config := &sink.SinkConfig{ 304 Logger: c.logger.Named("sink.file"), 305 Config: sc.Config, 306 Client: client, 307 WrapTTL: sc.WrapTTL, 308 DHType: sc.DHType, 309 DHPath: sc.DHPath, 310 AAD: sc.AAD, 311 } 312 s, err := file.NewFileSink(config) 313 if err != nil { 314 c.UI.Error(errwrap.Wrapf("Error creating file sink: {{err}}", err).Error()) 315 return 1 316 } 317 config.Sink = s 318 sinks = append(sinks, config) 319 default: 320 c.UI.Error(fmt.Sprintf("Unknown sink type %q", sc.Type)) 321 return 1 322 } 323 } 324 325 // Check if a default namespace has been set 326 mountPath := config.AutoAuth.Method.MountPath 327 if config.AutoAuth.Method.Namespace != "" { 328 mountPath = path.Join(config.AutoAuth.Method.Namespace, mountPath) 329 } 330 331 authConfig := &auth.AuthConfig{ 332 Logger: c.logger.Named(fmt.Sprintf("auth.%s", config.AutoAuth.Method.Type)), 333 MountPath: mountPath, 334 Config: config.AutoAuth.Method.Config, 335 } 336 switch config.AutoAuth.Method.Type { 337 case "alicloud": 338 method, err = alicloud.NewAliCloudAuthMethod(authConfig) 339 case "aws": 340 method, err = aws.NewAWSAuthMethod(authConfig) 341 case "azure": 342 method, err = azure.NewAzureAuthMethod(authConfig) 343 case "cert": 344 method, err = cert.NewCertAuthMethod(authConfig) 345 case "gcp": 346 method, err = gcp.NewGCPAuthMethod(authConfig) 347 case "jwt": 348 method, err = jwt.NewJWTAuthMethod(authConfig) 349 case "kubernetes": 350 method, err = kubernetes.NewKubernetesAuthMethod(authConfig) 351 case "approle": 352 method, err = approle.NewApproleAuthMethod(authConfig) 353 case "pcf": 354 method, err = pcf.NewPCFAuthMethod(authConfig) 355 default: 356 c.UI.Error(fmt.Sprintf("Unknown auth method %q", config.AutoAuth.Method.Type)) 357 return 1 358 } 359 if err != nil { 360 c.UI.Error(errwrap.Wrapf(fmt.Sprintf("Error creating %s auth method: {{err}}", config.AutoAuth.Method.Type), err).Error()) 361 return 1 362 } 363 } 364 365 // Output the header that the server has started 366 if !c.flagCombineLogs { 367 c.UI.Output("==> Vault server started! Log data will stream in below:\n") 368 } 369 370 // Inform any tests that the server is ready 371 select { 372 case c.startedCh <- struct{}{}: 373 default: 374 } 375 376 // Parse agent listener configurations 377 if config.Cache != nil && len(config.Listeners) != 0 { 378 cacheLogger := c.logger.Named("cache") 379 380 // Create the API proxier 381 apiProxy, err := cache.NewAPIProxy(&cache.APIProxyConfig{ 382 Client: client, 383 Logger: cacheLogger.Named("apiproxy"), 384 }) 385 if err != nil { 386 c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err)) 387 return 1 388 } 389 390 // Create the lease cache proxier and set its underlying proxier to 391 // the API proxier. 392 leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{ 393 Client: client, 394 BaseContext: ctx, 395 Proxier: apiProxy, 396 Logger: cacheLogger.Named("leasecache"), 397 }) 398 if err != nil { 399 c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err)) 400 return 1 401 } 402 403 var inmemSink sink.Sink 404 if config.Cache.UseAutoAuthToken { 405 cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink") 406 inmemSink, err = inmem.New(&sink.SinkConfig{ 407 Logger: cacheLogger, 408 }, leaseCache) 409 if err != nil { 410 c.UI.Error(fmt.Sprintf("Error creating inmem sink for cache: %v", err)) 411 return 1 412 } 413 sinks = append(sinks, &sink.SinkConfig{ 414 Logger: cacheLogger, 415 Sink: inmemSink, 416 }) 417 } 418 419 // Create a muxer and add paths relevant for the lease cache layer 420 mux := http.NewServeMux() 421 mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx)) 422 423 mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)) 424 425 var listeners []net.Listener 426 for i, lnConfig := range config.Listeners { 427 ln, tlsConf, err := cache.StartListener(lnConfig) 428 if err != nil { 429 c.UI.Error(fmt.Sprintf("Error starting listener: %v", err)) 430 return 1 431 } 432 433 listeners = append(listeners, ln) 434 435 scheme := "https://" 436 if tlsConf == nil { 437 scheme = "http://" 438 } 439 if ln.Addr().Network() == "unix" { 440 scheme = "unix://" 441 } 442 443 infoKey := fmt.Sprintf("api address %d", i+1) 444 info[infoKey] = scheme + ln.Addr().String() 445 infoKeys = append(infoKeys, infoKey) 446 447 server := &http.Server{ 448 Addr: ln.Addr().String(), 449 TLSConfig: tlsConf, 450 Handler: mux, 451 ReadHeaderTimeout: 10 * time.Second, 452 ReadTimeout: 30 * time.Second, 453 IdleTimeout: 5 * time.Minute, 454 ErrorLog: cacheLogger.StandardLogger(nil), 455 } 456 457 go server.Serve(ln) 458 } 459 460 // Ensure that listeners are closed at all the exits 461 listenerCloseFunc := func() { 462 for _, ln := range listeners { 463 ln.Close() 464 } 465 } 466 defer c.cleanupGuard.Do(listenerCloseFunc) 467 } 468 469 var ssDoneCh, ahDoneCh chan struct{} 470 // Start auto-auth and sink servers 471 if method != nil { 472 ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{ 473 Logger: c.logger.Named("auth.handler"), 474 Client: c.client, 475 WrapTTL: config.AutoAuth.Method.WrapTTL, 476 EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials, 477 }) 478 ahDoneCh = ah.DoneCh 479 480 ss := sink.NewSinkServer(&sink.SinkServerConfig{ 481 Logger: c.logger.Named("sink.server"), 482 Client: client, 483 ExitAfterAuth: config.ExitAfterAuth, 484 }) 485 ssDoneCh = ss.DoneCh 486 487 go ah.Run(ctx, method) 488 go ss.Run(ctx, ah.OutputCh, sinks) 489 } 490 491 // Server configuration output 492 padding := 24 493 sort.Strings(infoKeys) 494 c.UI.Output("==> Vault agent configuration:\n") 495 for _, k := range infoKeys { 496 c.UI.Output(fmt.Sprintf( 497 "%s%s: %s", 498 strings.Repeat(" ", padding-len(k)), 499 strings.Title(k), 500 info[k])) 501 } 502 c.UI.Output("") 503 504 // Release the log gate. 505 c.logGate.Flush() 506 507 // Write out the PID to the file now that server has successfully started 508 if err := c.storePidFile(config.PidFile); err != nil { 509 c.UI.Error(fmt.Sprintf("Error storing PID: %s", err)) 510 return 1 511 } 512 513 defer func() { 514 if err := c.removePidFile(config.PidFile); err != nil { 515 c.UI.Error(fmt.Sprintf("Error deleting the PID file: %s", err)) 516 } 517 }() 518 519 select { 520 case <-ssDoneCh: 521 // This will happen if we exit-on-auth 522 c.logger.Info("sinks finished, exiting") 523 case <-c.ShutdownCh: 524 c.UI.Output("==> Vault agent shutdown triggered") 525 cancelFunc() 526 if ahDoneCh != nil { 527 <-ahDoneCh 528 } 529 if ssDoneCh != nil { 530 <-ssDoneCh 531 } 532 } 533 534 return 0 535} 536 537func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) { 538 var isFlagSet bool 539 f.Visit(func(f *flag.Flag) { 540 if f.Name == fVar.Name { 541 isFlagSet = true 542 } 543 }) 544 545 flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar) 546 switch { 547 case isFlagSet: 548 // Don't do anything as the flag is already set from the command line 549 case flagEnvSet: 550 // Use value from env var 551 *fVar.Target = flagEnvValue 552 case configVal != "": 553 // Use value from config 554 *fVar.Target = configVal 555 default: 556 // Use the default value 557 *fVar.Target = fVar.Default 558 } 559} 560 561func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) { 562 var isFlagSet bool 563 f.Visit(func(f *flag.Flag) { 564 if f.Name == fVar.Name { 565 isFlagSet = true 566 } 567 }) 568 569 flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar) 570 switch { 571 case isFlagSet: 572 // Don't do anything as the flag is already set from the command line 573 case flagEnvSet: 574 // Use value from env var 575 *fVar.Target = flagEnvValue != "" 576 case configVal == true: 577 // Use value from config 578 *fVar.Target = configVal 579 default: 580 // Use the default value 581 *fVar.Target = fVar.Default 582 } 583} 584 585// storePidFile is used to write out our PID to a file if necessary 586func (c *AgentCommand) storePidFile(pidPath string) error { 587 // Quit fast if no pidfile 588 if pidPath == "" { 589 return nil 590 } 591 592 // Open the PID file 593 pidFile, err := os.OpenFile(pidPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) 594 if err != nil { 595 return errwrap.Wrapf("could not open pid file: {{err}}", err) 596 } 597 defer pidFile.Close() 598 599 // Write out the PID 600 pid := os.Getpid() 601 _, err = pidFile.WriteString(fmt.Sprintf("%d", pid)) 602 if err != nil { 603 return errwrap.Wrapf("could not write to pid file: {{err}}", err) 604 } 605 return nil 606} 607 608// removePidFile is used to cleanup the PID file if necessary 609func (c *AgentCommand) removePidFile(pidPath string) error { 610 if pidPath == "" { 611 return nil 612 } 613 return os.Remove(pidPath) 614} 615