1package command
2
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/command/agent/auth"
	"github.com/hashicorp/vault/command/agent/auth/alicloud"
	"github.com/hashicorp/vault/command/agent/auth/approle"
	"github.com/hashicorp/vault/command/agent/auth/aws"
	"github.com/hashicorp/vault/command/agent/auth/azure"
	"github.com/hashicorp/vault/command/agent/auth/cert"
	"github.com/hashicorp/vault/command/agent/auth/cf"
	"github.com/hashicorp/vault/command/agent/auth/gcp"
	"github.com/hashicorp/vault/command/agent/auth/jwt"
	"github.com/hashicorp/vault/command/agent/auth/kerberos"
	"github.com/hashicorp/vault/command/agent/auth/kubernetes"
	"github.com/hashicorp/vault/command/agent/cache"
	"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
	"github.com/hashicorp/vault/command/agent/cache/cachememdb"
	"github.com/hashicorp/vault/command/agent/cache/keymanager"
	agentConfig "github.com/hashicorp/vault/command/agent/config"
	"github.com/hashicorp/vault/command/agent/sink"
	"github.com/hashicorp/vault/command/agent/sink/file"
	"github.com/hashicorp/vault/command/agent/sink/inmem"
	"github.com/hashicorp/vault/command/agent/template"
	"github.com/hashicorp/vault/command/agent/winsvc"
	"github.com/hashicorp/vault/internalshared/gatedwriter"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/sdk/version"
	"github.com/kr/pretty"
	"github.com/mitchellh/cli"
	"github.com/oklog/run"
	"github.com/posener/complete"
)
53
54var (
55	_ cli.Command             = (*AgentCommand)(nil)
56	_ cli.CommandAutocomplete = (*AgentCommand)(nil)
57)
58
59type AgentCommand struct {
60	*BaseCommand
61
62	ShutdownCh chan struct{}
63	SighupCh   chan struct{}
64
65	logWriter io.Writer
66	logGate   *gatedwriter.Writer
67	logger    log.Logger
68
69	cleanupGuard sync.Once
70
71	startedCh chan (struct{}) // for tests
72
73	flagConfigs       []string
74	flagLogLevel      string
75	flagExitAfterAuth bool
76
77	flagTestVerifyOnly bool
78	flagCombineLogs    bool
79}
80
81func (c *AgentCommand) Synopsis() string {
82	return "Start a Vault agent"
83}
84
85func (c *AgentCommand) Help() string {
86	helpText := `
87Usage: vault agent [options]
88
89  This command starts a Vault agent that can perform automatic authentication
90  in certain environments.
91
92  Start an agent with a configuration file:
93
94      $ vault agent -config=/etc/vault/config.hcl
95
96  For a full list of examples, please see the documentation.
97
98` + c.Flags().Help()
99	return strings.TrimSpace(helpText)
100}
101
102func (c *AgentCommand) Flags() *FlagSets {
103	set := c.flagSet(FlagSetHTTP)
104
105	f := set.NewFlagSet("Command Options")
106
107	f.StringSliceVar(&StringSliceVar{
108		Name:   "config",
109		Target: &c.flagConfigs,
110		Completion: complete.PredictOr(
111			complete.PredictFiles("*.hcl"),
112			complete.PredictFiles("*.json"),
113		),
114		Usage: "Path to a configuration file. This configuration file should " +
115			"contain only agent directives.",
116	})
117
118	f.StringVar(&StringVar{
119		Name:       "log-level",
120		Target:     &c.flagLogLevel,
121		Default:    "info",
122		EnvVar:     "VAULT_LOG_LEVEL",
123		Completion: complete.PredictSet("trace", "debug", "info", "warn", "err"),
124		Usage: "Log verbosity level. Supported values (in order of detail) are " +
125			"\"trace\", \"debug\", \"info\", \"warn\", and \"err\".",
126	})
127
128	f.BoolVar(&BoolVar{
129		Name:    "exit-after-auth",
130		Target:  &c.flagExitAfterAuth,
131		Default: false,
132		Usage: "If set to true, the agent will exit with code 0 after a single " +
133			"successful auth, where success means that a token was retrieved and " +
134			"all sinks successfully wrote it",
135	})
136
137	// Internal-only flags to follow.
138	//
139	// Why hello there little source code reader! Welcome to the Vault source
140	// code. The remaining options are intentionally undocumented and come with
141	// no warranty or backwards-compatibility promise. Do not use these flags
142	// in production. Do not build automation using these flags. Unless you are
143	// developing against Vault, you should not need any of these flags.
144
145	// TODO: should the below flags be public?
146	f.BoolVar(&BoolVar{
147		Name:    "combine-logs",
148		Target:  &c.flagCombineLogs,
149		Default: false,
150		Hidden:  true,
151	})
152
153	f.BoolVar(&BoolVar{
154		Name:    "test-verify-only",
155		Target:  &c.flagTestVerifyOnly,
156		Default: false,
157		Hidden:  true,
158	})
159
160	// End internal-only flags.
161
162	return set
163}
164
165func (c *AgentCommand) AutocompleteArgs() complete.Predictor {
166	return complete.PredictNothing
167}
168
169func (c *AgentCommand) AutocompleteFlags() complete.Flags {
170	return c.Flags().Completions()
171}
172
173func (c *AgentCommand) Run(args []string) int {
174	f := c.Flags()
175
176	if err := f.Parse(args); err != nil {
177		c.UI.Error(err.Error())
178		return 1
179	}
180
181	// Create a logger. We wrap it in a gated writer so that it doesn't
182	// start logging too early.
183	c.logGate = gatedwriter.NewWriter(os.Stderr)
184	c.logWriter = c.logGate
185	if c.flagCombineLogs {
186		c.logWriter = os.Stdout
187	}
188	var level log.Level
189	c.flagLogLevel = strings.ToLower(strings.TrimSpace(c.flagLogLevel))
190	switch c.flagLogLevel {
191	case "trace":
192		level = log.Trace
193	case "debug":
194		level = log.Debug
195	case "notice", "info", "":
196		level = log.Info
197	case "warn", "warning":
198		level = log.Warn
199	case "err", "error":
200		level = log.Error
201	default:
202		c.UI.Error(fmt.Sprintf("Unknown log level: %s", c.flagLogLevel))
203		return 1
204	}
205
206	if c.logger == nil {
207		c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level)
208	}
209
210	// Validation
211	if len(c.flagConfigs) != 1 {
212		c.UI.Error("Must specify exactly one config path using -config")
213		return 1
214	}
215
216	// Load the configuration
217	config, err := agentConfig.LoadConfig(c.flagConfigs[0])
218	if err != nil {
219		c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
220		return 1
221	}
222
223	// Ensure at least one config was found.
224	if config == nil {
225		c.UI.Output(wrapAtLength(
226			"No configuration read. Please provide the configuration with the " +
227				"-config flag."))
228		return 1
229	}
230	if config.AutoAuth == nil && config.Cache == nil {
231		c.UI.Error("No auto_auth or cache block found in config file")
232		return 1
233	}
234	if config.AutoAuth == nil {
235		c.UI.Info("No auto_auth block found in config file, not starting automatic authentication feature")
236	}
237
238	exitAfterAuth := config.ExitAfterAuth
239	f.Visit(func(fl *flag.Flag) {
240		if fl.Name == "exit-after-auth" {
241			exitAfterAuth = c.flagExitAfterAuth
242		}
243	})
244
245	c.setStringFlag(f, config.Vault.Address, &StringVar{
246		Name:    flagNameAddress,
247		Target:  &c.flagAddress,
248		Default: "https://127.0.0.1:8200",
249		EnvVar:  api.EnvVaultAddress,
250	})
251	config.Vault.Address = c.flagAddress
252	c.setStringFlag(f, config.Vault.CACert, &StringVar{
253		Name:    flagNameCACert,
254		Target:  &c.flagCACert,
255		Default: "",
256		EnvVar:  api.EnvVaultCACert,
257	})
258	config.Vault.CACert = c.flagCACert
259	c.setStringFlag(f, config.Vault.CAPath, &StringVar{
260		Name:    flagNameCAPath,
261		Target:  &c.flagCAPath,
262		Default: "",
263		EnvVar:  api.EnvVaultCAPath,
264	})
265	config.Vault.CAPath = c.flagCAPath
266	c.setStringFlag(f, config.Vault.ClientCert, &StringVar{
267		Name:    flagNameClientCert,
268		Target:  &c.flagClientCert,
269		Default: "",
270		EnvVar:  api.EnvVaultClientCert,
271	})
272	config.Vault.ClientCert = c.flagClientCert
273	c.setStringFlag(f, config.Vault.ClientKey, &StringVar{
274		Name:    flagNameClientKey,
275		Target:  &c.flagClientKey,
276		Default: "",
277		EnvVar:  api.EnvVaultClientKey,
278	})
279	config.Vault.ClientKey = c.flagClientKey
280	c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{
281		Name:    flagNameTLSSkipVerify,
282		Target:  &c.flagTLSSkipVerify,
283		Default: false,
284		EnvVar:  api.EnvVaultSkipVerify,
285	})
286	config.Vault.TLSSkipVerify = c.flagTLSSkipVerify
287	c.setStringFlag(f, config.Vault.TLSServerName, &StringVar{
288		Name:    flagTLSServerName,
289		Target:  &c.flagTLSServerName,
290		Default: "",
291		EnvVar:  api.EnvVaultTLSServerName,
292	})
293	config.Vault.TLSServerName = c.flagTLSServerName
294
295	infoKeys := make([]string, 0, 10)
296	info := make(map[string]string)
297	info["log level"] = c.flagLogLevel
298	infoKeys = append(infoKeys, "log level")
299
300	infoKeys = append(infoKeys, "version")
301	verInfo := version.GetVersion()
302	info["version"] = verInfo.FullVersionNumber(false)
303	if verInfo.Revision != "" {
304		info["version sha"] = strings.Trim(verInfo.Revision, "'")
305		infoKeys = append(infoKeys, "version sha")
306	}
307	infoKeys = append(infoKeys, "cgo")
308	info["cgo"] = "disabled"
309	if version.CgoEnabled {
310		info["cgo"] = "enabled"
311	}
312
313	// Tests might not want to start a vault server and just want to verify
314	// the configuration.
315	if c.flagTestVerifyOnly {
316		if os.Getenv("VAULT_TEST_VERIFY_ONLY_DUMP_CONFIG") != "" {
317			c.UI.Output(fmt.Sprintf(
318				"\nConfiguration:\n%s\n",
319				pretty.Sprint(*config)))
320		}
321		return 0
322	}
323
324	// Ignore any setting of agent's address. This client is used by the agent
325	// to reach out to Vault. This should never loop back to agent.
326	c.flagAgentAddress = ""
327	client, err := c.Client()
328	if err != nil {
329		c.UI.Error(fmt.Sprintf(
330			"Error fetching client: %v",
331			err))
332		return 1
333	}
334
335	// ctx and cancelFunc are passed to the AuthHandler, SinkServer, and
336	// TemplateServer that periodically listen for ctx.Done() to fire and shut
337	// down accordingly.
338	ctx, cancelFunc := context.WithCancel(context.Background())
339
340	var method auth.AuthMethod
341	var sinks []*sink.SinkConfig
342	var namespace string
343	if config.AutoAuth != nil {
344		for _, sc := range config.AutoAuth.Sinks {
345			switch sc.Type {
346			case "file":
347				config := &sink.SinkConfig{
348					Logger:    c.logger.Named("sink.file"),
349					Config:    sc.Config,
350					Client:    client,
351					WrapTTL:   sc.WrapTTL,
352					DHType:    sc.DHType,
353					DeriveKey: sc.DeriveKey,
354					DHPath:    sc.DHPath,
355					AAD:       sc.AAD,
356				}
357				s, err := file.NewFileSink(config)
358				if err != nil {
359					c.UI.Error(fmt.Errorf("Error creating file sink: %w", err).Error())
360					return 1
361				}
362				config.Sink = s
363				sinks = append(sinks, config)
364			default:
365				c.UI.Error(fmt.Sprintf("Unknown sink type %q", sc.Type))
366				return 1
367			}
368		}
369
370		// Check if a default namespace has been set
371		mountPath := config.AutoAuth.Method.MountPath
372		if cns := config.AutoAuth.Method.Namespace; cns != "" {
373			namespace = cns
374			// Only set this value if the env var is empty, otherwise we end up with a nested namespace
375			if ens := os.Getenv(api.EnvVaultNamespace); ens == "" {
376				mountPath = path.Join(cns, mountPath)
377			}
378		}
379
380		authConfig := &auth.AuthConfig{
381			Logger:    c.logger.Named(fmt.Sprintf("auth.%s", config.AutoAuth.Method.Type)),
382			MountPath: mountPath,
383			Config:    config.AutoAuth.Method.Config,
384		}
385		switch config.AutoAuth.Method.Type {
386		case "alicloud":
387			method, err = alicloud.NewAliCloudAuthMethod(authConfig)
388		case "aws":
389			method, err = aws.NewAWSAuthMethod(authConfig)
390		case "azure":
391			method, err = azure.NewAzureAuthMethod(authConfig)
392		case "cert":
393			method, err = cert.NewCertAuthMethod(authConfig)
394		case "cf":
395			method, err = cf.NewCFAuthMethod(authConfig)
396		case "gcp":
397			method, err = gcp.NewGCPAuthMethod(authConfig)
398		case "jwt":
399			method, err = jwt.NewJWTAuthMethod(authConfig)
400		case "kerberos":
401			method, err = kerberos.NewKerberosAuthMethod(authConfig)
402		case "kubernetes":
403			method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
404		case "approle":
405			method, err = approle.NewApproleAuthMethod(authConfig)
406		case "pcf": // Deprecated.
407			method, err = cf.NewCFAuthMethod(authConfig)
408		default:
409			c.UI.Error(fmt.Sprintf("Unknown auth method %q", config.AutoAuth.Method.Type))
410			return 1
411		}
412		if err != nil {
413			c.UI.Error(fmt.Errorf("Error creating %s auth method: %w", config.AutoAuth.Method.Type, err).Error())
414			return 1
415		}
416	}
417
418	// We do this after auto-auth has been configured, because we don't want to
419	// confuse the issue of retries for auth failures which have their own
420	// config and are handled a bit differently.
421	if os.Getenv(api.EnvVaultMaxRetries) == "" {
422		client.SetMaxRetries(config.Vault.Retry.NumRetries)
423	}
424
425	enforceConsistency := cache.EnforceConsistencyNever
426	whenInconsistent := cache.WhenInconsistentFail
427	if config.Cache != nil {
428		switch config.Cache.EnforceConsistency {
429		case "always":
430			enforceConsistency = cache.EnforceConsistencyAlways
431		case "never", "":
432		default:
433			c.UI.Error(fmt.Sprintf("Unknown cache setting for enforce_consistency: %q", config.Cache.EnforceConsistency))
434			return 1
435		}
436
437		switch config.Cache.WhenInconsistent {
438		case "retry":
439			whenInconsistent = cache.WhenInconsistentRetry
440		case "forward":
441			whenInconsistent = cache.WhenInconsistentForward
442		case "fail", "":
443		default:
444			c.UI.Error(fmt.Sprintf("Unknown cache setting for when_inconsistent: %q", config.Cache.WhenInconsistent))
445			return 1
446		}
447	}
448
449	// Warn if cache _and_ cert auto-auth is enabled but certificates were not
450	// provided in the auto_auth.method["cert"].config stanza.
451	if config.Cache != nil && (config.AutoAuth != nil && config.AutoAuth.Method != nil && config.AutoAuth.Method.Type == "cert") {
452		_, okCertFile := config.AutoAuth.Method.Config["client_cert"]
453		_, okCertKey := config.AutoAuth.Method.Config["client_key"]
454
455		// If neither of these exists in the cert stanza, agent will use the
456		// certs from the vault stanza.
457		if !okCertFile && !okCertKey {
458			c.UI.Warn(wrapAtLength("WARNING! Cache is enabled and using the same certificates " +
459				"from the 'cert' auto-auth method specified in the 'vault' stanza. Consider " +
460				"specifying certificate information in the 'cert' auto-auth's config stanza."))
461		}
462
463	}
464
465	// Output the header that the agent has started
466	if !c.flagCombineLogs {
467		c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
468	}
469
470	var leaseCache *cache.LeaseCache
471	var previousToken string
472	// Parse agent listener configurations
473	if config.Cache != nil && len(config.Listeners) != 0 {
474		cacheLogger := c.logger.Named("cache")
475
476		// Create the API proxier
477		apiProxy, err := cache.NewAPIProxy(&cache.APIProxyConfig{
478			Client:                 client,
479			Logger:                 cacheLogger.Named("apiproxy"),
480			EnforceConsistency:     enforceConsistency,
481			WhenInconsistentAction: whenInconsistent,
482		})
483		if err != nil {
484			c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err))
485			return 1
486		}
487
488		// Create the lease cache proxier and set its underlying proxier to
489		// the API proxier.
490		leaseCache, err = cache.NewLeaseCache(&cache.LeaseCacheConfig{
491			Client:      client,
492			BaseContext: ctx,
493			Proxier:     apiProxy,
494			Logger:      cacheLogger.Named("leasecache"),
495		})
496		if err != nil {
497			c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err))
498			return 1
499		}
500
501		// Configure persistent storage and add to LeaseCache
502		if config.Cache.Persist != nil {
503			if config.Cache.Persist.Path == "" {
504				c.UI.Error("must specify persistent cache path")
505				return 1
506			}
507
508			// Set AAD based on key protection type
509			var aad string
510			switch config.Cache.Persist.Type {
511			case "kubernetes":
512				aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile)
513				if err != nil {
514					c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err))
515					return 1
516				}
517			default:
518				c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type))
519				return 1
520			}
521
522			// Check if bolt file exists already
523			dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path)
524			if err != nil {
525				c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err))
526				return 1
527			}
528			if dbFileExists {
529				// Open the bolt file, but wait to setup Encryption
530				ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
531					Path:   config.Cache.Persist.Path,
532					Logger: cacheLogger.Named("cacheboltdb"),
533				})
534				if err != nil {
535					c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
536					return 1
537				}
538
539				// Get the token from bolt for retrieving the encryption key,
540				// then setup encryption so that restore is possible
541				token, err := ps.GetRetrievalToken()
542				if err != nil {
543					c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err))
544				}
545
546				if err := ps.Close(); err != nil {
547					c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err))
548				}
549
550				km, err := keymanager.NewPassthroughKeyManager(token)
551				if err != nil {
552					c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
553					return 1
554				}
555
556				// Open the bolt file with the wrapper provided
557				ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
558					Path:    config.Cache.Persist.Path,
559					Logger:  cacheLogger.Named("cacheboltdb"),
560					Wrapper: km.Wrapper(),
561					AAD:     aad,
562				})
563				if err != nil {
564					c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
565					return 1
566				}
567
568				// Restore anything in the persistent cache to the memory cache
569				if err := leaseCache.Restore(ctx, ps); err != nil {
570					c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err))
571					if config.Cache.Persist.ExitOnErr {
572						return 1
573					}
574				}
575				cacheLogger.Info("loaded memcache from persistent storage")
576
577				// Check for previous auto-auth token
578				oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
579				if err != nil {
580					c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err))
581					if config.Cache.Persist.ExitOnErr {
582						return 1
583					}
584				}
585				if len(oldTokenBytes) > 0 {
586					oldToken, err := cachememdb.Deserialize(oldTokenBytes)
587					if err != nil {
588						c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err))
589						if config.Cache.Persist.ExitOnErr {
590							return 1
591						}
592					}
593					previousToken = oldToken.Token
594				}
595
596				// If keep_after_import true, set persistent storage layer in
597				// leaseCache, else remove db file
598				if config.Cache.Persist.KeepAfterImport {
599					defer ps.Close()
600					leaseCache.SetPersistentStorage(ps)
601				} else {
602					if err := ps.Close(); err != nil {
603						c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
604					}
605					dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName)
606					if err := os.Remove(dbFile); err != nil {
607						c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err))
608						if config.Cache.Persist.ExitOnErr {
609							return 1
610						}
611					}
612				}
613			} else {
614				km, err := keymanager.NewPassthroughKeyManager(nil)
615				if err != nil {
616					c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
617					return 1
618				}
619				ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
620					Path:    config.Cache.Persist.Path,
621					Logger:  cacheLogger.Named("cacheboltdb"),
622					Wrapper: km.Wrapper(),
623					AAD:     aad,
624				})
625				if err != nil {
626					c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
627					return 1
628				}
629				cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path)
630
631				// Stash the key material in bolt
632				token, err := km.RetrievalToken()
633				if err != nil {
634					c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err))
635					return 1
636				}
637				if err := ps.StoreRetrievalToken(token); err != nil {
638					c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err))
639					return 1
640				}
641
642				defer ps.Close()
643				leaseCache.SetPersistentStorage(ps)
644			}
645		}
646
647		var inmemSink sink.Sink
648		if config.Cache.UseAutoAuthToken {
649			cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink")
650			inmemSink, err = inmem.New(&sink.SinkConfig{
651				Logger: cacheLogger,
652			}, leaseCache)
653			if err != nil {
654				c.UI.Error(fmt.Sprintf("Error creating inmem sink for cache: %v", err))
655				return 1
656			}
657			sinks = append(sinks, &sink.SinkConfig{
658				Logger: cacheLogger,
659				Sink:   inmemSink,
660			})
661		}
662
663		proxyVaultToken := !config.Cache.ForceAutoAuthToken
664
665		// Create the request handler
666		cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink, proxyVaultToken)
667
668		var listeners []net.Listener
669		for i, lnConfig := range config.Listeners {
670			ln, tlsConf, err := cache.StartListener(lnConfig)
671			if err != nil {
672				c.UI.Error(fmt.Sprintf("Error starting listener: %v", err))
673				return 1
674			}
675
676			listeners = append(listeners, ln)
677
678			// Parse 'require_request_header' listener config option, and wrap
679			// the request handler if necessary
680			muxHandler := cacheHandler
681			if lnConfig.RequireRequestHeader {
682				muxHandler = verifyRequestHeader(muxHandler)
683			}
684
685			// Create a muxer and add paths relevant for the lease cache layer
686			mux := http.NewServeMux()
687			mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
688			mux.Handle("/", muxHandler)
689
690			scheme := "https://"
691			if tlsConf == nil {
692				scheme = "http://"
693			}
694			if ln.Addr().Network() == "unix" {
695				scheme = "unix://"
696			}
697
698			infoKey := fmt.Sprintf("api address %d", i+1)
699			info[infoKey] = scheme + ln.Addr().String()
700			infoKeys = append(infoKeys, infoKey)
701
702			server := &http.Server{
703				Addr:              ln.Addr().String(),
704				TLSConfig:         tlsConf,
705				Handler:           mux,
706				ReadHeaderTimeout: 10 * time.Second,
707				ReadTimeout:       30 * time.Second,
708				IdleTimeout:       5 * time.Minute,
709				ErrorLog:          cacheLogger.StandardLogger(nil),
710			}
711
712			go server.Serve(ln)
713		}
714
715		// Ensure that listeners are closed at all the exits
716		listenerCloseFunc := func() {
717			for _, ln := range listeners {
718				ln.Close()
719			}
720		}
721		defer c.cleanupGuard.Do(listenerCloseFunc)
722	}
723
724	// Inform any tests that the server is ready
725	if c.startedCh != nil {
726		close(c.startedCh)
727	}
728
729	// Listen for signals
730	// TODO: implement support for SIGHUP reloading of configuration
731	// signal.Notify(c.signalCh)
732
733	var g run.Group
734
735	// This run group watches for signal termination
736	g.Add(func() error {
737		for {
738			select {
739			case <-c.ShutdownCh:
740				c.UI.Output("==> Vault agent shutdown triggered")
741				// Let the lease cache know this is a shutdown; no need to evict
742				// everything
743				if leaseCache != nil {
744					leaseCache.SetShuttingDown(true)
745				}
746				return nil
747			case <-ctx.Done():
748				return nil
749			case <-winsvc.ShutdownChannel():
750				return nil
751			}
752		}
753	}, func(error) {})
754
755	// Start auto-auth and sink servers
756	if method != nil {
757		enableTokenCh := len(config.Templates) > 0
758		ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{
759			Logger:                       c.logger.Named("auth.handler"),
760			Client:                       c.client,
761			WrapTTL:                      config.AutoAuth.Method.WrapTTL,
762			MaxBackoff:                   config.AutoAuth.Method.MaxBackoff,
763			EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
764			EnableTemplateTokenCh:        enableTokenCh,
765			Token:                        previousToken,
766		})
767
768		ss := sink.NewSinkServer(&sink.SinkServerConfig{
769			Logger:        c.logger.Named("sink.server"),
770			Client:        client,
771			ExitAfterAuth: exitAfterAuth,
772		})
773
774		ts := template.NewServer(&template.ServerConfig{
775			Logger:        c.logger.Named("template.server"),
776			LogLevel:      level,
777			LogWriter:     c.logWriter,
778			AgentConfig:   config,
779			Namespace:     namespace,
780			ExitAfterAuth: exitAfterAuth,
781		})
782
783		g.Add(func() error {
784			return ah.Run(ctx, method)
785		}, func(error) {
786			// Let the lease cache know this is a shutdown; no need to evict
787			// everything
788			if leaseCache != nil {
789				leaseCache.SetShuttingDown(true)
790			}
791			cancelFunc()
792		})
793
794		g.Add(func() error {
795			err := ss.Run(ctx, ah.OutputCh, sinks)
796			c.logger.Info("sinks finished, exiting")
797
798			// Start goroutine to drain from ah.OutputCh from this point onward
799			// to prevent ah.Run from being blocked.
800			go func() {
801				for {
802					select {
803					case <-ctx.Done():
804						return
805					case <-ah.OutputCh:
806					}
807				}
808			}()
809
810			// Wait until templates are rendered
811			if len(config.Templates) > 0 {
812				<-ts.DoneCh
813			}
814
815			return err
816		}, func(error) {
817			// Let the lease cache know this is a shutdown; no need to evict
818			// everything
819			if leaseCache != nil {
820				leaseCache.SetShuttingDown(true)
821			}
822			cancelFunc()
823		})
824
825		g.Add(func() error {
826			return ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
827		}, func(error) {
828			// Let the lease cache know this is a shutdown; no need to evict
829			// everything
830			if leaseCache != nil {
831				leaseCache.SetShuttingDown(true)
832			}
833			cancelFunc()
834			ts.Stop()
835		})
836
837	}
838
839	// Server configuration output
840	padding := 24
841	sort.Strings(infoKeys)
842	c.UI.Output("==> Vault agent configuration:\n")
843	for _, k := range infoKeys {
844		c.UI.Output(fmt.Sprintf(
845			"%s%s: %s",
846			strings.Repeat(" ", padding-len(k)),
847			strings.Title(k),
848			info[k]))
849	}
850	c.UI.Output("")
851
852	// Release the log gate.
853	c.logGate.Flush()
854
855	// Write out the PID to the file now that server has successfully started
856	if err := c.storePidFile(config.PidFile); err != nil {
857		c.UI.Error(fmt.Sprintf("Error storing PID: %s", err))
858		return 1
859	}
860
861	defer func() {
862		if err := c.removePidFile(config.PidFile); err != nil {
863			c.UI.Error(fmt.Sprintf("Error deleting the PID file: %s", err))
864		}
865	}()
866
867	if err := g.Run(); err != nil {
868		c.logger.Error("runtime error encountered", "error", err)
869		c.UI.Error("Error encountered during run, refer to logs for more details.")
870		return 1
871	}
872
873	return 0
874}
875
876// verifyRequestHeader wraps an http.Handler inside a Handler that checks for
877// the request header that is used for SSRF protection.
878func verifyRequestHeader(handler http.Handler) http.Handler {
879	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
880		if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" {
881			logical.RespondError(w,
882				http.StatusPreconditionFailed,
883				errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)))
884			return
885		}
886
887		handler.ServeHTTP(w, r)
888	})
889}
890
891func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
892	var isFlagSet bool
893	f.Visit(func(f *flag.Flag) {
894		if f.Name == fVar.Name {
895			isFlagSet = true
896		}
897	})
898
899	flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
900	switch {
901	case isFlagSet:
902		// Don't do anything as the flag is already set from the command line
903	case flagEnvSet:
904		// Use value from env var
905		*fVar.Target = flagEnvValue
906	case configVal != "":
907		// Use value from config
908		*fVar.Target = configVal
909	default:
910		// Use the default value
911		*fVar.Target = fVar.Default
912	}
913}
914
915func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) {
916	var isFlagSet bool
917	f.Visit(func(f *flag.Flag) {
918		if f.Name == fVar.Name {
919			isFlagSet = true
920		}
921	})
922
923	flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
924	switch {
925	case isFlagSet:
926		// Don't do anything as the flag is already set from the command line
927	case flagEnvSet:
928		// Use value from env var
929		*fVar.Target = flagEnvValue != ""
930	case configVal == true:
931		// Use value from config
932		*fVar.Target = configVal
933	default:
934		// Use the default value
935		*fVar.Target = fVar.Default
936	}
937}
938
939// storePidFile is used to write out our PID to a file if necessary
940func (c *AgentCommand) storePidFile(pidPath string) error {
941	// Quit fast if no pidfile
942	if pidPath == "" {
943		return nil
944	}
945
946	// Open the PID file
947	pidFile, err := os.OpenFile(pidPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
948	if err != nil {
949		return fmt.Errorf("could not open pid file: %w", err)
950	}
951	defer pidFile.Close()
952
953	// Write out the PID
954	pid := os.Getpid()
955	_, err = pidFile.WriteString(fmt.Sprintf("%d", pid))
956	if err != nil {
957		return fmt.Errorf("could not write to pid file: %w", err)
958	}
959	return nil
960}
961
962// removePidFile is used to cleanup the PID file if necessary
963func (c *AgentCommand) removePidFile(pidPath string) error {
964	if pidPath == "" {
965		return nil
966	}
967	return os.Remove(pidPath)
968}
969
// getServiceAccountJWT reads the service account JWT stored in tokenFile and
// returns it with leading and trailing whitespace removed. When tokenFile is
// empty, the default Kubernetes service account token path is used. A read
// failure is returned unchanged.
func getServiceAccountJWT(tokenFile string) (string, error) {
	const defaultTokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"

	tokenPath := tokenFile
	if tokenPath == "" {
		tokenPath = defaultTokenFile
	}

	raw, err := ioutil.ReadFile(tokenPath)
	if err != nil {
		return "", err
	}

	jwt := strings.TrimSpace(string(raw))
	return jwt, nil
}
982