1package command
2
import (
	"context"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/hashicorp/errwrap"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/command/agent/auth"
	"github.com/hashicorp/vault/command/agent/auth/alicloud"
	"github.com/hashicorp/vault/command/agent/auth/approle"
	"github.com/hashicorp/vault/command/agent/auth/aws"
	"github.com/hashicorp/vault/command/agent/auth/azure"
	"github.com/hashicorp/vault/command/agent/auth/cert"
	"github.com/hashicorp/vault/command/agent/auth/gcp"
	"github.com/hashicorp/vault/command/agent/auth/jwt"
	"github.com/hashicorp/vault/command/agent/auth/kubernetes"
	"github.com/hashicorp/vault/command/agent/auth/pcf"
	"github.com/hashicorp/vault/command/agent/cache"
	"github.com/hashicorp/vault/command/agent/config"
	"github.com/hashicorp/vault/command/agent/sink"
	"github.com/hashicorp/vault/command/agent/sink/file"
	"github.com/hashicorp/vault/command/agent/sink/inmem"
	gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/version"
	"github.com/kr/pretty"
	"github.com/mitchellh/cli"
	"github.com/posener/complete"
)
43
44var _ cli.Command = (*AgentCommand)(nil)
45var _ cli.CommandAutocomplete = (*AgentCommand)(nil)
46
47type AgentCommand struct {
48	*BaseCommand
49
50	ShutdownCh chan struct{}
51	SighupCh   chan struct{}
52
53	logWriter io.Writer
54	logGate   *gatedwriter.Writer
55	logger    log.Logger
56
57	cleanupGuard sync.Once
58
59	startedCh chan (struct{}) // for tests
60
61	flagConfigs  []string
62	flagLogLevel string
63
64	flagTestVerifyOnly bool
65	flagCombineLogs    bool
66}
67
68func (c *AgentCommand) Synopsis() string {
69	return "Start a Vault agent"
70}
71
72func (c *AgentCommand) Help() string {
73	helpText := `
74Usage: vault agent [options]
75
76  This command starts a Vault agent that can perform automatic authentication
77  in certain environments.
78
79  Start an agent with a configuration file:
80
81      $ vault agent -config=/etc/vault/config.hcl
82
83  For a full list of examples, please see the documentation.
84
85` + c.Flags().Help()
86	return strings.TrimSpace(helpText)
87}
88
89func (c *AgentCommand) Flags() *FlagSets {
90	set := c.flagSet(FlagSetHTTP)
91
92	f := set.NewFlagSet("Command Options")
93
94	f.StringSliceVar(&StringSliceVar{
95		Name:   "config",
96		Target: &c.flagConfigs,
97		Completion: complete.PredictOr(
98			complete.PredictFiles("*.hcl"),
99			complete.PredictFiles("*.json"),
100		),
101		Usage: "Path to a configuration file. This configuration file should " +
102			"contain only agent directives.",
103	})
104
105	f.StringVar(&StringVar{
106		Name:       "log-level",
107		Target:     &c.flagLogLevel,
108		Default:    "info",
109		EnvVar:     "VAULT_LOG_LEVEL",
110		Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"),
111		Usage: "Log verbosity level. Supported values (in order of detail) are " +
112			"\"trace\", \"debug\", \"info\", \"warn\", and \"err\".",
113	})
114
115	// Internal-only flags to follow.
116	//
117	// Why hello there little source code reader! Welcome to the Vault source
118	// code. The remaining options are intentionally undocumented and come with
119	// no warranty or backwards-compatibility promise. Do not use these flags
120	// in production. Do not build automation using these flags. Unless you are
121	// developing against Vault, you should not need any of these flags.
122
123	// TODO: should the below flags be public?
124	f.BoolVar(&BoolVar{
125		Name:    "combine-logs",
126		Target:  &c.flagCombineLogs,
127		Default: false,
128		Hidden:  true,
129	})
130
131	f.BoolVar(&BoolVar{
132		Name:    "test-verify-only",
133		Target:  &c.flagTestVerifyOnly,
134		Default: false,
135		Hidden:  true,
136	})
137
138	// End internal-only flags.
139
140	return set
141}
142
143func (c *AgentCommand) AutocompleteArgs() complete.Predictor {
144	return complete.PredictNothing
145}
146
147func (c *AgentCommand) AutocompleteFlags() complete.Flags {
148	return c.Flags().Completions()
149}
150
151func (c *AgentCommand) Run(args []string) int {
152	f := c.Flags()
153
154	if err := f.Parse(args); err != nil {
155		c.UI.Error(err.Error())
156		return 1
157	}
158
159	// Create a logger. We wrap it in a gated writer so that it doesn't
160	// start logging too early.
161	c.logGate = &gatedwriter.Writer{Writer: os.Stderr}
162	c.logWriter = c.logGate
163	if c.flagCombineLogs {
164		c.logWriter = os.Stdout
165	}
166	var level log.Level
167	c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
168	switch c.flagLogLevel {
169	case "trace":
170		level = log.Trace
171	case "debug":
172		level = log.Debug
173	case "notice", "info", "":
174		level = log.Info
175	case "warn", "warning":
176		level = log.Warn
177	case "err", "error":
178		level = log.Error
179	default:
180		c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel))
181		return 1
182	}
183
184	if c.logger == nil {
185		c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level)
186	}
187
188	// Validation
189	if len(c.flagConfigs) != 1 {
190		c.UI.Error("Must specify exactly one config path using -config")
191		return 1
192	}
193
194	// Load the configuration
195	config, err := config.LoadConfig(c.flagConfigs[0], c.logger)
196	if err != nil {
197		c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
198		return 1
199	}
200
201	// Ensure at least one config was found.
202	if config == nil {
203		c.UI.Output(wrapAtLength(
204			"No configuration read. Please provide the configuration with the " +
205				"-config flag."))
206		return 1
207	}
208	if config.AutoAuth == nil && config.Cache == nil {
209		c.UI.Error("No auto_auth or cache block found in config file")
210		return 1
211	}
212	if config.AutoAuth == nil {
213		c.UI.Info("No auto_auth block found in config file, not starting automatic authentication feature")
214	}
215
216	if config.Vault != nil {
217		c.setStringFlag(f, config.Vault.Address, &StringVar{
218			Name:    flagNameAddress,
219			Target:  &c.flagAddress,
220			Default: "https://127.0.0.1:8200",
221			EnvVar:  api.EnvVaultAddress,
222		})
223		c.setStringFlag(f, config.Vault.CACert, &StringVar{
224			Name:    flagNameCACert,
225			Target:  &c.flagCACert,
226			Default: "",
227			EnvVar:  api.EnvVaultCACert,
228		})
229		c.setStringFlag(f, config.Vault.CAPath, &StringVar{
230			Name:    flagNameCAPath,
231			Target:  &c.flagCAPath,
232			Default: "",
233			EnvVar:  api.EnvVaultCAPath,
234		})
235		c.setStringFlag(f, config.Vault.ClientCert, &StringVar{
236			Name:    flagNameClientCert,
237			Target:  &c.flagClientCert,
238			Default: "",
239			EnvVar:  api.EnvVaultClientCert,
240		})
241		c.setStringFlag(f, config.Vault.ClientKey, &StringVar{
242			Name:    flagNameClientKey,
243			Target:  &c.flagClientKey,
244			Default: "",
245			EnvVar:  api.EnvVaultClientKey,
246		})
247		c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{
248			Name:    flagNameTLSSkipVerify,
249			Target:  &c.flagTLSSkipVerify,
250			Default: false,
251			EnvVar:  api.EnvVaultSkipVerify,
252		})
253	}
254
255	infoKeys := make([]string, 0, 10)
256	info := make(map[string]string)
257	info["log level"] = c.flagLogLevel
258	infoKeys = append(infoKeys, "log level")
259
260	infoKeys = append(infoKeys, "version")
261	verInfo := version.GetVersion()
262	info["version"] = verInfo.FullVersionNumber(false)
263	if verInfo.Revision != "" {
264		info["version sha"] = strings.Trim(verInfo.Revision, "'")
265		infoKeys = append(infoKeys, "version sha")
266	}
267	infoKeys = append(infoKeys, "cgo")
268	info["cgo"] = "disabled"
269	if version.CgoEnabled {
270		info["cgo"] = "enabled"
271	}
272
273	// Tests might not want to start a vault server and just want to verify
274	// the configuration.
275	if c.flagTestVerifyOnly {
276		if os.Getenv("VAULT_TEST_VERIFY_ONLY_DUMP_CONFIG") != "" {
277			c.UI.Output(fmt.Sprintf(
278				"\nConfiguration:\n%s\n",
279				pretty.Sprint(*config)))
280		}
281		return 0
282	}
283
284	// Ignore any setting of agent's address. This client is used by the agent
285	// to reach out to Vault. This should never loop back to agent.
286	c.flagAgentAddress = ""
287	client, err := c.Client()
288	if err != nil {
289		c.UI.Error(fmt.Sprintf(
290			"Error fetching client: %v",
291			err))
292		return 1
293	}
294
295	ctx, cancelFunc := context.WithCancel(context.Background())
296
297	var method auth.AuthMethod
298	var sinks []*sink.SinkConfig
299	if config.AutoAuth != nil {
300		for _, sc := range config.AutoAuth.Sinks {
301			switch sc.Type {
302			case "file":
303				config := &sink.SinkConfig{
304					Logger:  c.logger.Named("sink.file"),
305					Config:  sc.Config,
306					Client:  client,
307					WrapTTL: sc.WrapTTL,
308					DHType:  sc.DHType,
309					DHPath:  sc.DHPath,
310					AAD:     sc.AAD,
311				}
312				s, err := file.NewFileSink(config)
313				if err != nil {
314					c.UI.Error(errwrap.Wrapf("Error creating file sink: {{err}}", err).Error())
315					return 1
316				}
317				config.Sink = s
318				sinks = append(sinks, config)
319			default:
320				c.UI.Error(fmt.Sprintf("Unknown sink type %q", sc.Type))
321				return 1
322			}
323		}
324
325		// Check if a default namespace has been set
326		mountPath := config.AutoAuth.Method.MountPath
327		if config.AutoAuth.Method.Namespace != "" {
328			mountPath = path.Join(config.AutoAuth.Method.Namespace, mountPath)
329		}
330
331		authConfig := &auth.AuthConfig{
332			Logger:    c.logger.Named(fmt.Sprintf("auth.%s", config.AutoAuth.Method.Type)),
333			MountPath: mountPath,
334			Config:    config.AutoAuth.Method.Config,
335		}
336		switch config.AutoAuth.Method.Type {
337		case "alicloud":
338			method, err = alicloud.NewAliCloudAuthMethod(authConfig)
339		case "aws":
340			method, err = aws.NewAWSAuthMethod(authConfig)
341		case "azure":
342			method, err = azure.NewAzureAuthMethod(authConfig)
343		case "cert":
344			method, err = cert.NewCertAuthMethod(authConfig)
345		case "gcp":
346			method, err = gcp.NewGCPAuthMethod(authConfig)
347		case "jwt":
348			method, err = jwt.NewJWTAuthMethod(authConfig)
349		case "kubernetes":
350			method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
351		case "approle":
352			method, err = approle.NewApproleAuthMethod(authConfig)
353		case "pcf":
354			method, err = pcf.NewPCFAuthMethod(authConfig)
355		default:
356			c.UI.Error(fmt.Sprintf("Unknown auth method %q", config.AutoAuth.Method.Type))
357			return 1
358		}
359		if err != nil {
360			c.UI.Error(errwrap.Wrapf(fmt.Sprintf("Error creating %s auth method: {{err}}", config.AutoAuth.Method.Type), err).Error())
361			return 1
362		}
363	}
364
365	// Output the header that the server has started
366	if !c.flagCombineLogs {
367		c.UI.Output("==> Vault server started! Log data will stream in below:\n")
368	}
369
370	// Inform any tests that the server is ready
371	select {
372	case c.startedCh <- struct{}{}:
373	default:
374	}
375
376	// Parse agent listener configurations
377	if config.Cache != nil && len(config.Listeners) != 0 {
378		cacheLogger := c.logger.Named("cache")
379
380		// Create the API proxier
381		apiProxy, err := cache.NewAPIProxy(&cache.APIProxyConfig{
382			Client: client,
383			Logger: cacheLogger.Named("apiproxy"),
384		})
385		if err != nil {
386			c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err))
387			return 1
388		}
389
390		// Create the lease cache proxier and set its underlying proxier to
391		// the API proxier.
392		leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
393			Client:      client,
394			BaseContext: ctx,
395			Proxier:     apiProxy,
396			Logger:      cacheLogger.Named("leasecache"),
397		})
398		if err != nil {
399			c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err))
400			return 1
401		}
402
403		var inmemSink sink.Sink
404		if config.Cache.UseAutoAuthToken {
405			cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink")
406			inmemSink, err = inmem.New(&sink.SinkConfig{
407				Logger: cacheLogger,
408			}, leaseCache)
409			if err != nil {
410				c.UI.Error(fmt.Sprintf("Error creating inmem sink for cache: %v", err))
411				return 1
412			}
413			sinks = append(sinks, &sink.SinkConfig{
414				Logger: cacheLogger,
415				Sink:   inmemSink,
416			})
417		}
418
419		// Create a muxer and add paths relevant for the lease cache layer
420		mux := http.NewServeMux()
421		mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
422
423		mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))
424
425		var listeners []net.Listener
426		for i, lnConfig := range config.Listeners {
427			ln, tlsConf, err := cache.StartListener(lnConfig)
428			if err != nil {
429				c.UI.Error(fmt.Sprintf("Error starting listener: %v", err))
430				return 1
431			}
432
433			listeners = append(listeners, ln)
434
435			scheme := "https://"
436			if tlsConf == nil {
437				scheme = "http://"
438			}
439			if ln.Addr().Network() == "unix" {
440				scheme = "unix://"
441			}
442
443			infoKey := fmt.Sprintf("api address %d", i+1)
444			info[infoKey] = scheme + ln.Addr().String()
445			infoKeys = append(infoKeys, infoKey)
446
447			server := &http.Server{
448				Addr:              ln.Addr().String(),
449				TLSConfig:         tlsConf,
450				Handler:           mux,
451				ReadHeaderTimeout: 10 * time.Second,
452				ReadTimeout:       30 * time.Second,
453				IdleTimeout:       5 * time.Minute,
454				ErrorLog:          cacheLogger.StandardLogger(nil),
455			}
456
457			go server.Serve(ln)
458		}
459
460		// Ensure that listeners are closed at all the exits
461		listenerCloseFunc := func() {
462			for _, ln := range listeners {
463				ln.Close()
464			}
465		}
466		defer c.cleanupGuard.Do(listenerCloseFunc)
467	}
468
469	var ssDoneCh, ahDoneCh chan struct{}
470	// Start auto-auth and sink servers
471	if method != nil {
472		ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{
473			Logger:                       c.logger.Named("auth.handler"),
474			Client:                       c.client,
475			WrapTTL:                      config.AutoAuth.Method.WrapTTL,
476			EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
477		})
478		ahDoneCh = ah.DoneCh
479
480		ss := sink.NewSinkServer(&sink.SinkServerConfig{
481			Logger:        c.logger.Named("sink.server"),
482			Client:        client,
483			ExitAfterAuth: config.ExitAfterAuth,
484		})
485		ssDoneCh = ss.DoneCh
486
487		go ah.Run(ctx, method)
488		go ss.Run(ctx, ah.OutputCh, sinks)
489	}
490
491	// Server configuration output
492	padding := 24
493	sort.Strings(infoKeys)
494	c.UI.Output("==> Vault agent configuration:\n")
495	for _, k := range infoKeys {
496		c.UI.Output(fmt.Sprintf(
497			"%s%s: %s",
498			strings.Repeat(" ", padding-len(k)),
499			strings.Title(k),
500			info[k]))
501	}
502	c.UI.Output("")
503
504	// Release the log gate.
505	c.logGate.Flush()
506
507	// Write out the PID to the file now that server has successfully started
508	if err := c.storePidFile(config.PidFile); err != nil {
509		c.UI.Error(fmt.Sprintf("Error storing PID: %s", err))
510		return 1
511	}
512
513	defer func() {
514		if err := c.removePidFile(config.PidFile); err != nil {
515			c.UI.Error(fmt.Sprintf("Error deleting the PID file: %s", err))
516		}
517	}()
518
519	select {
520	case <-ssDoneCh:
521		// This will happen if we exit-on-auth
522		c.logger.Info("sinks finished, exiting")
523	case <-c.ShutdownCh:
524		c.UI.Output("==> Vault agent shutdown triggered")
525		cancelFunc()
526		if ahDoneCh != nil {
527			<-ahDoneCh
528		}
529		if ssDoneCh != nil {
530			<-ssDoneCh
531		}
532	}
533
534	return 0
535}
536
537func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
538	var isFlagSet bool
539	f.Visit(func(f *flag.Flag) {
540		if f.Name == fVar.Name {
541			isFlagSet = true
542		}
543	})
544
545	flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
546	switch {
547	case isFlagSet:
548		// Don't do anything as the flag is already set from the command line
549	case flagEnvSet:
550		// Use value from env var
551		*fVar.Target = flagEnvValue
552	case configVal != "":
553		// Use value from config
554		*fVar.Target = configVal
555	default:
556		// Use the default value
557		*fVar.Target = fVar.Default
558	}
559}
560
561func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) {
562	var isFlagSet bool
563	f.Visit(func(f *flag.Flag) {
564		if f.Name == fVar.Name {
565			isFlagSet = true
566		}
567	})
568
569	flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
570	switch {
571	case isFlagSet:
572		// Don't do anything as the flag is already set from the command line
573	case flagEnvSet:
574		// Use value from env var
575		*fVar.Target = flagEnvValue != ""
576	case configVal == true:
577		// Use value from config
578		*fVar.Target = configVal
579	default:
580		// Use the default value
581		*fVar.Target = fVar.Default
582	}
583}
584
585// storePidFile is used to write out our PID to a file if necessary
586func (c *AgentCommand) storePidFile(pidPath string) error {
587	// Quit fast if no pidfile
588	if pidPath == "" {
589		return nil
590	}
591
592	// Open the PID file
593	pidFile, err := os.OpenFile(pidPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
594	if err != nil {
595		return errwrap.Wrapf("could not open pid file: {{err}}", err)
596	}
597	defer pidFile.Close()
598
599	// Write out the PID
600	pid := os.Getpid()
601	_, err = pidFile.WriteString(fmt.Sprintf("%d", pid))
602	if err != nil {
603		return errwrap.Wrapf("could not write to pid file: {{err}}", err)
604	}
605	return nil
606}
607
608// removePidFile is used to cleanup the PID file if necessary
609func (c *AgentCommand) removePidFile(pidPath string) error {
610	if pidPath == "" {
611		return nil
612	}
613	return os.Remove(pidPath)
614}
615