1package consul
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io/ioutil"
8	"net"
9	"net/http"
10	"net/url"
11	"regexp"
12	"strconv"
13	"strings"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"golang.org/x/net/http2"
19
20	log "github.com/hashicorp/go-hclog"
21
22	"crypto/tls"
23	"crypto/x509"
24
25	metrics "github.com/armon/go-metrics"
26	"github.com/hashicorp/consul/api"
27	"github.com/hashicorp/consul/lib"
28	"github.com/hashicorp/errwrap"
29	multierror "github.com/hashicorp/go-multierror"
30	"github.com/hashicorp/vault/helper/consts"
31	"github.com/hashicorp/vault/helper/parseutil"
32	"github.com/hashicorp/vault/helper/strutil"
33	"github.com/hashicorp/vault/helper/tlsutil"
34	"github.com/hashicorp/vault/physical"
35)
36
const (
	// checkJitterFactor is the jitter factor used to stagger checks and
	// the periodic reconcile.
	checkJitterFactor = 16

	// checkMinBuffer provides a guarantee that a check will not
	// be executed too close to the TTL check timeout.
	checkMinBuffer = 100 * time.Millisecond

	// consulRetryInterval specifies the retry duration to use when an
	// API call to the Consul agent fails.
	consulRetryInterval = 1 * time.Second

	// defaultCheckTimeout is the TTL check timeout used when the
	// check_timeout configuration key is not set.
	defaultCheckTimeout = 5 * time.Second

	// DefaultServiceName is the default Consul service name used when
	// advertising a Vault instance.
	DefaultServiceName = "vault"

	// reconcileTimeout is the base interval at which Vault queries Consul
	// to detect and fix any state drift (a random stagger is subtracted
	// each time the reconcile timer is rearmed).
	reconcileTimeout = 60 * time.Second

	// consistencyModeDefault is the configuration value used to tell
	// consul to use default consistency.
	consistencyModeDefault = "default"

	// consistencyModeStrong is the configuration value used to tell
	// consul to use strong consistency.
	consistencyModeStrong = "strong"
)
68
// notifyEvent is the empty payload sent on the backend's notify channels;
// only the act of sending carries meaning.
type notifyEvent struct{}

// Verify ConsulBackend satisfies the correct interfaces
var _ physical.Backend = (*ConsulBackend)(nil)
var _ physical.HABackend = (*ConsulBackend)(nil)
var _ physical.Lock = (*ConsulLock)(nil)
var _ physical.Transactional = (*ConsulBackend)(nil)
var _ physical.ServiceDiscovery = (*ConsulBackend)(nil)

var (
	// hostnameRegex matches dot-separated labels made of alphanumerics and
	// interior dashes; it is used to validate the configured service name.
	hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
)
81
// ConsulBackend is a physical backend that stores data at specific
// prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
type ConsulBackend struct {
	path                string               // key prefix prepended to every KV key
	logger              log.Logger           // logger for configuration and service-discovery messages
	client              *api.Client          // Consul API client
	kv                  *api.KV              // KV handle obtained from client
	permitPool          *physical.PermitPool // acquired around KV operations (sized from max_parallel)
	serviceLock         sync.RWMutex         // guards the registered service ID handed between the demuxer and its reconcile handler
	redirectHost        string               // host part of the redirect address, set by setRedirectAddr
	redirectPort        int64                // port part of the redirect address, set by setRedirectAddr
	serviceName         string               // Consul service name used for registration
	serviceTags         []string             // extra tags attached to the registered service
	serviceAddress      *string              // optional registered-address override; nil means use redirectHost
	disableRegistration bool                 // when true, the reconcile and check handlers are never started
	checkTimeout        time.Duration        // TTL of the sealed-status check
	consistencyMode     string               // consistencyModeDefault or consistencyModeStrong

	notifyActiveCh      chan notifyEvent // signaled by NotifyActiveStateChange
	notifySealedCh      chan notifyEvent // signaled by NotifySealedStateChange
	notifyPerfStandbyCh chan notifyEvent // signaled by NotifyPerformanceStandbyStateChange

	sessionTTL   string        // passed as SessionTTL to api.LockOptions in LockWith
	lockWaitTime time.Duration // passed as LockWaitTime to api.LockOptions in LockWith
}
108
109// NewConsulBackend constructs a Consul backend using the given API client
110// and the prefix in the KV store.
111func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
112	// Get the path in Consul
113	path, ok := conf["path"]
114	if !ok {
115		path = "vault/"
116	}
117	if logger.IsDebug() {
118		logger.Debug("config path set", "path", path)
119	}
120
121	// Ensure path is suffixed but not prefixed
122	if !strings.HasSuffix(path, "/") {
123		logger.Warn("appending trailing forward slash to path")
124		path += "/"
125	}
126	if strings.HasPrefix(path, "/") {
127		logger.Warn("trimming path of its forward slash")
128		path = strings.TrimPrefix(path, "/")
129	}
130
131	// Allow admins to disable consul integration
132	disableReg, ok := conf["disable_registration"]
133	var disableRegistration bool
134	if ok && disableReg != "" {
135		b, err := parseutil.ParseBool(disableReg)
136		if err != nil {
137			return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err)
138		}
139		disableRegistration = b
140	}
141	if logger.IsDebug() {
142		logger.Debug("config disable_registration set", "disable_registration", disableRegistration)
143	}
144
145	// Get the service name to advertise in Consul
146	service, ok := conf["service"]
147	if !ok {
148		service = DefaultServiceName
149	}
150	if !hostnameRegex.MatchString(service) {
151		return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes")
152	}
153	if logger.IsDebug() {
154		logger.Debug("config service set", "service", service)
155	}
156
157	// Get the additional tags to attach to the registered service name
158	tags := conf["service_tags"]
159	if logger.IsDebug() {
160		logger.Debug("config service_tags set", "service_tags", tags)
161	}
162
163	// Get the service-specific address to override the use of the HA redirect address
164	var serviceAddr *string
165	serviceAddrStr, ok := conf["service_address"]
166	if ok {
167		serviceAddr = &serviceAddrStr
168	}
169	if logger.IsDebug() {
170		logger.Debug("config service_address set", "service_address", serviceAddr)
171	}
172
173	checkTimeout := defaultCheckTimeout
174	checkTimeoutStr, ok := conf["check_timeout"]
175	if ok {
176		d, err := parseutil.ParseDurationSecond(checkTimeoutStr)
177		if err != nil {
178			return nil, err
179		}
180
181		min, _ := lib.DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
182		if min < checkMinBuffer {
183			return nil, fmt.Errorf("consul check_timeout must be greater than %v", min)
184		}
185
186		checkTimeout = d
187		if logger.IsDebug() {
188			logger.Debug("config check_timeout set", "check_timeout", d)
189		}
190	}
191
192	sessionTTL := api.DefaultLockSessionTTL
193	sessionTTLStr, ok := conf["session_ttl"]
194	if ok {
195		_, err := parseutil.ParseDurationSecond(sessionTTLStr)
196		if err != nil {
197			return nil, errwrap.Wrapf("invalid session_ttl: {{err}}", err)
198		}
199		sessionTTL = sessionTTLStr
200		if logger.IsDebug() {
201			logger.Debug("config session_ttl set", "session_ttl", sessionTTL)
202		}
203	}
204
205	lockWaitTime := api.DefaultLockWaitTime
206	lockWaitTimeRaw, ok := conf["lock_wait_time"]
207	if ok {
208		d, err := parseutil.ParseDurationSecond(lockWaitTimeRaw)
209		if err != nil {
210			return nil, errwrap.Wrapf("invalid lock_wait_time: {{err}}", err)
211		}
212		lockWaitTime = d
213		if logger.IsDebug() {
214			logger.Debug("config lock_wait_time set", "lock_wait_time", d)
215		}
216	}
217
218	// Configure the client
219	consulConf := api.DefaultConfig()
220	// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
221	consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
222
223	if addr, ok := conf["address"]; ok {
224		consulConf.Address = addr
225		if logger.IsDebug() {
226			logger.Debug("config address set", "address", addr)
227		}
228	}
229	if scheme, ok := conf["scheme"]; ok {
230		consulConf.Scheme = scheme
231		if logger.IsDebug() {
232			logger.Debug("config scheme set", "scheme", scheme)
233		}
234	}
235	if token, ok := conf["token"]; ok {
236		consulConf.Token = token
237		logger.Debug("config token set")
238	}
239
240	if consulConf.Scheme == "https" {
241		tlsClientConfig, err := setupTLSConfig(conf)
242		if err != nil {
243			return nil, err
244		}
245
246		consulConf.Transport.TLSClientConfig = tlsClientConfig
247		if err := http2.ConfigureTransport(consulConf.Transport); err != nil {
248			return nil, err
249		}
250		logger.Debug("configured TLS")
251	}
252
253	consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
254	client, err := api.NewClient(consulConf)
255	if err != nil {
256		return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
257	}
258
259	maxParStr, ok := conf["max_parallel"]
260	var maxParInt int
261	if ok {
262		maxParInt, err = strconv.Atoi(maxParStr)
263		if err != nil {
264			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
265		}
266		if logger.IsDebug() {
267			logger.Debug("max_parallel set", "max_parallel", maxParInt)
268		}
269	}
270
271	consistencyMode, ok := conf["consistency_mode"]
272	if ok {
273		switch consistencyMode {
274		case consistencyModeDefault, consistencyModeStrong:
275		default:
276			return nil, fmt.Errorf("invalid consistency_mode value: %q", consistencyMode)
277		}
278	} else {
279		consistencyMode = consistencyModeDefault
280	}
281
282	// Setup the backend
283	c := &ConsulBackend{
284		path:                path,
285		logger:              logger,
286		client:              client,
287		kv:                  client.KV(),
288		permitPool:          physical.NewPermitPool(maxParInt),
289		serviceName:         service,
290		serviceTags:         strutil.ParseDedupLowercaseAndSortStrings(tags, ","),
291		serviceAddress:      serviceAddr,
292		checkTimeout:        checkTimeout,
293		disableRegistration: disableRegistration,
294		consistencyMode:     consistencyMode,
295		notifyActiveCh:      make(chan notifyEvent),
296		notifySealedCh:      make(chan notifyEvent),
297		notifyPerfStandbyCh: make(chan notifyEvent),
298		sessionTTL:          sessionTTL,
299		lockWaitTime:        lockWaitTime,
300	}
301	return c, nil
302}
303
304func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
305	serverName, _, err := net.SplitHostPort(conf["address"])
306	switch {
307	case err == nil:
308	case strings.Contains(err.Error(), "missing port"):
309		serverName = conf["address"]
310	default:
311		return nil, err
312	}
313
314	insecureSkipVerify := false
315	tlsSkipVerify, ok := conf["tls_skip_verify"]
316
317	if ok && tlsSkipVerify != "" {
318		b, err := parseutil.ParseBool(tlsSkipVerify)
319		if err != nil {
320			return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err)
321		}
322		insecureSkipVerify = b
323	}
324
325	tlsMinVersionStr, ok := conf["tls_min_version"]
326	if !ok {
327		// Set the default value
328		tlsMinVersionStr = "tls12"
329	}
330
331	tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr]
332	if !ok {
333		return nil, fmt.Errorf("invalid 'tls_min_version'")
334	}
335
336	tlsClientConfig := &tls.Config{
337		MinVersion:         tlsMinVersion,
338		InsecureSkipVerify: insecureSkipVerify,
339		ServerName:         serverName,
340	}
341
342	_, okCert := conf["tls_cert_file"]
343	_, okKey := conf["tls_key_file"]
344
345	if okCert && okKey {
346		tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"])
347		if err != nil {
348			return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err)
349		}
350
351		tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
352	}
353
354	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
355		caPool := x509.NewCertPool()
356
357		data, err := ioutil.ReadFile(tlsCaFile)
358		if err != nil {
359			return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err)
360		}
361
362		if !caPool.AppendCertsFromPEM(data) {
363			return nil, fmt.Errorf("failed to parse CA certificate")
364		}
365
366		tlsClientConfig.RootCAs = caPool
367	}
368
369	return tlsClientConfig, nil
370}
371
372// Used to run multiple entries via a transaction
373func (c *ConsulBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
374	if len(txns) == 0 {
375		return nil
376	}
377
378	ops := make([]*api.KVTxnOp, 0, len(txns))
379
380	for _, op := range txns {
381		cop := &api.KVTxnOp{
382			Key: c.path + op.Entry.Key,
383		}
384		switch op.Operation {
385		case physical.DeleteOperation:
386			cop.Verb = api.KVDelete
387		case physical.PutOperation:
388			cop.Verb = api.KVSet
389			cop.Value = op.Entry.Value
390		default:
391			return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
392		}
393
394		ops = append(ops, cop)
395	}
396
397	c.permitPool.Acquire()
398	defer c.permitPool.Release()
399
400	queryOpts := &api.QueryOptions{}
401	queryOpts = queryOpts.WithContext(ctx)
402
403	ok, resp, _, err := c.kv.Txn(ops, queryOpts)
404	if err != nil {
405		return err
406	}
407	if ok && len(resp.Errors) == 0 {
408		return nil
409	}
410
411	var retErr *multierror.Error
412	for _, res := range resp.Errors {
413		retErr = multierror.Append(retErr, errors.New(res.What))
414	}
415
416	return retErr
417}
418
419// Put is used to insert or update an entry
420func (c *ConsulBackend) Put(ctx context.Context, entry *physical.Entry) error {
421	defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
422
423	c.permitPool.Acquire()
424	defer c.permitPool.Release()
425
426	pair := &api.KVPair{
427		Key:   c.path + entry.Key,
428		Value: entry.Value,
429	}
430
431	writeOpts := &api.WriteOptions{}
432	writeOpts = writeOpts.WithContext(ctx)
433
434	_, err := c.kv.Put(pair, writeOpts)
435	return err
436}
437
438// Get is used to fetch an entry
439func (c *ConsulBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
440	defer metrics.MeasureSince([]string{"consul", "get"}, time.Now())
441
442	c.permitPool.Acquire()
443	defer c.permitPool.Release()
444
445	queryOpts := &api.QueryOptions{}
446	queryOpts = queryOpts.WithContext(ctx)
447
448	if c.consistencyMode == consistencyModeStrong {
449		queryOpts.RequireConsistent = true
450	}
451
452	pair, _, err := c.kv.Get(c.path+key, queryOpts)
453	if err != nil {
454		return nil, err
455	}
456	if pair == nil {
457		return nil, nil
458	}
459	ent := &physical.Entry{
460		Key:   key,
461		Value: pair.Value,
462	}
463	return ent, nil
464}
465
466// Delete is used to permanently delete an entry
467func (c *ConsulBackend) Delete(ctx context.Context, key string) error {
468	defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now())
469
470	c.permitPool.Acquire()
471	defer c.permitPool.Release()
472
473	writeOpts := &api.WriteOptions{}
474	writeOpts = writeOpts.WithContext(ctx)
475
476	_, err := c.kv.Delete(c.path+key, writeOpts)
477	return err
478}
479
480// List is used to list all the keys under a given
481// prefix, up to the next prefix.
482func (c *ConsulBackend) List(ctx context.Context, prefix string) ([]string, error) {
483	defer metrics.MeasureSince([]string{"consul", "list"}, time.Now())
484	scan := c.path + prefix
485
486	// The TrimPrefix call below will not work correctly if we have "//" at the
487	// end. This can happen in cases where you are e.g. listing the root of a
488	// prefix in a logical backend via "/" instead of ""
489	if strings.HasSuffix(scan, "//") {
490		scan = scan[:len(scan)-1]
491	}
492
493	c.permitPool.Acquire()
494	defer c.permitPool.Release()
495
496	queryOpts := &api.QueryOptions{}
497	queryOpts = queryOpts.WithContext(ctx)
498
499	out, _, err := c.kv.Keys(scan, "/", queryOpts)
500	for idx, val := range out {
501		out[idx] = strings.TrimPrefix(val, scan)
502	}
503
504	return out, err
505}
506
507// Lock is used for mutual exclusion based on the given key.
508func (c *ConsulBackend) LockWith(key, value string) (physical.Lock, error) {
509	// Create the lock
510	opts := &api.LockOptions{
511		Key:            c.path + key,
512		Value:          []byte(value),
513		SessionName:    "Vault Lock",
514		MonitorRetries: 5,
515		SessionTTL:     c.sessionTTL,
516		LockWaitTime:   c.lockWaitTime,
517	}
518	lock, err := c.client.LockOpts(opts)
519	if err != nil {
520		return nil, errwrap.Wrapf("failed to create lock: {{err}}", err)
521	}
522	cl := &ConsulLock{
523		client:          c.client,
524		key:             c.path + key,
525		lock:            lock,
526		consistencyMode: c.consistencyMode,
527	}
528	return cl, nil
529}
530
// HAEnabled reports whether the HA functionality should be exposed.
// This implementation unconditionally returns true.
func (c *ConsulBackend) HAEnabled() bool {
	return true
}
536
537// DetectHostAddr is used to detect the host address by asking the Consul agent
538func (c *ConsulBackend) DetectHostAddr() (string, error) {
539	agent := c.client.Agent()
540	self, err := agent.Self()
541	if err != nil {
542		return "", err
543	}
544	addr, ok := self["Member"]["Addr"].(string)
545	if !ok {
546		return "", fmt.Errorf("unable to convert an address to string")
547	}
548	return addr, nil
549}
550
// ConsulLock is used to provide the Lock interface backed by Consul
type ConsulLock struct {
	client          *api.Client // Consul API client used by Value to read the lock key
	key             string      // full KV key of the lock (backend path + lock key)
	lock            *api.Lock   // underlying Consul lock that Lock and Unlock delegate to
	consistencyMode string      // when consistencyModeStrong, Value requires a consistent read
}
558
// Lock delegates to the underlying Consul lock's Lock method, passing
// stopCh through and returning its channel and error unchanged.
func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	return c.lock.Lock(stopCh)
}
562
// Unlock delegates to the underlying Consul lock's Unlock method and
// returns its error unchanged.
func (c *ConsulLock) Unlock() error {
	return c.lock.Unlock()
}
566
567func (c *ConsulLock) Value() (bool, string, error) {
568	kv := c.client.KV()
569
570	var queryOptions *api.QueryOptions
571	if c.consistencyMode == consistencyModeStrong {
572		queryOptions = &api.QueryOptions{
573			RequireConsistent: true,
574		}
575	}
576
577	pair, _, err := kv.Get(c.key, queryOptions)
578	if err != nil {
579		return false, "", err
580	}
581	if pair == nil {
582		return false, "", nil
583	}
584	held := pair.Session != ""
585	value := string(pair.Value)
586	return held, value, nil
587}
588
// NotifyActiveStateChange signals on notifyActiveCh that Vault's active
// state changed. The send is non-blocking: when no receiver is ready the
// notification is dropped and a warning is logged. It always returns nil.
func (c *ConsulBackend) NotifyActiveStateChange() error {
	select {
	case c.notifyActiveCh <- notifyEvent{}:
	default:
		// NOTE: If this occurs Vault's active status could be out of
		// sync with Consul until reconcileTimer expires.
		c.logger.Warn("concurrent state change notify dropped")
	}

	return nil
}
600
// NotifyPerformanceStandbyStateChange signals on notifyPerfStandbyCh that
// Vault's performance standby state changed. The send is non-blocking:
// when no receiver is ready the notification is dropped and a warning is
// logged. It always returns nil.
func (c *ConsulBackend) NotifyPerformanceStandbyStateChange() error {
	select {
	case c.notifyPerfStandbyCh <- notifyEvent{}:
	default:
		// NOTE: If this occurs Vault's performance standby status could
		// be out of sync with Consul until a later timer-driven pass
		// picks up the change.
		c.logger.Warn("concurrent state change notify dropped")
	}

	return nil
}
612
// NotifySealedStateChange signals on notifySealedCh that Vault's sealed
// state changed. The send is non-blocking: when no receiver is ready the
// notification is dropped and a warning is logged. It always returns nil.
func (c *ConsulBackend) NotifySealedStateChange() error {
	select {
	case c.notifySealedCh <- notifyEvent{}:
	default:
		// NOTE: If this occurs Vault's sealed status could be out of
		// sync with Consul until checkTimer expires.
		c.logger.Warn("concurrent sealed state change notify dropped")
	}

	return nil
}
624
// checkDuration returns the delay to use before the next TTL check. It is
// derived from the configured check timeout, checkMinBuffer and
// checkJitterFactor via lib.DurationMinusBuffer.
func (c *ConsulBackend) checkDuration() time.Duration {
	return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
}
628
629func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (err error) {
630	if err := c.setRedirectAddr(redirectAddr); err != nil {
631		return err
632	}
633
634	// 'server' command will wait for the below goroutine to complete
635	waitGroup.Add(1)
636
637	go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr, activeFunc, sealedFunc, perfStandbyFunc)
638
639	return nil
640}
641
642func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) {
643	// This defer statement should be executed last. So push it first.
644	defer waitGroup.Done()
645
646	// Fire the reconcileTimer immediately upon starting the event demuxer
647	reconcileTimer := time.NewTimer(0)
648	defer reconcileTimer.Stop()
649
650	// Schedule the first check.  Consul TTL checks are passing by
651	// default, checkTimer does not need to be run immediately.
652	checkTimer := time.NewTimer(c.checkDuration())
653	defer checkTimer.Stop()
654
655	// Use a reactor pattern to handle and dispatch events to singleton
656	// goroutine handlers for execution.  It is not acceptable to drop
657	// inbound events from Notify*().
658	//
659	// goroutines are dispatched if the demuxer can acquire a lock (via
660	// an atomic CAS incr) on the handler.  Handlers are responsible for
661	// deregistering themselves (atomic CAS decr).  Handlers and the
662	// demuxer share a lock to synchronize information at the beginning
663	// and end of a handler's life (or after a handler wakes up from
664	// sleeping during a back-off/retry).
665	var shutdown bool
666	var registeredServiceID string
667	checkLock := new(int32)
668	serviceRegLock := new(int32)
669
670	for !shutdown {
671		select {
672		case <-c.notifyActiveCh:
673			// Run reconcile immediately upon active state change notification
674			reconcileTimer.Reset(0)
675		case <-c.notifySealedCh:
676			// Run check timer immediately upon a seal state change notification
677			checkTimer.Reset(0)
678		case <-c.notifyPerfStandbyCh:
679			// Run check timer immediately upon a seal state change notification
680			checkTimer.Reset(0)
681		case <-reconcileTimer.C:
682			// Unconditionally rearm the reconcileTimer
683			reconcileTimer.Reset(reconcileTimeout - lib.RandomStagger(reconcileTimeout/checkJitterFactor))
684
685			// Abort if service discovery is disabled or a
686			// reconcile handler is already active
687			if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) {
688				// Enter handler with serviceRegLock held
689				go func() {
690					defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0)
691					for !shutdown {
692						serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc, perfStandbyFunc)
693						if err != nil {
694							if c.logger.IsWarn() {
695								c.logger.Warn("reconcile unable to talk with Consul backend", "error", err)
696							}
697							time.Sleep(consulRetryInterval)
698							continue
699						}
700
701						c.serviceLock.Lock()
702						defer c.serviceLock.Unlock()
703
704						registeredServiceID = serviceID
705						return
706					}
707				}()
708			}
709		case <-checkTimer.C:
710			checkTimer.Reset(c.checkDuration())
711			// Abort if service discovery is disabled or a
712			// reconcile handler is active
713			if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) {
714				// Enter handler with checkLock held
715				go func() {
716					defer atomic.CompareAndSwapInt32(checkLock, 1, 0)
717					for !shutdown {
718						sealed := sealedFunc()
719						if err := c.runCheck(sealed); err != nil {
720							if c.logger.IsWarn() {
721								c.logger.Warn("check unable to talk with Consul backend", "error", err)
722							}
723							time.Sleep(consulRetryInterval)
724							continue
725						}
726						return
727					}
728				}()
729			}
730		case <-shutdownCh:
731			c.logger.Info("shutting down consul backend")
732			shutdown = true
733		}
734	}
735
736	c.serviceLock.RLock()
737	defer c.serviceLock.RUnlock()
738	if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil {
739		if c.logger.IsWarn() {
740			c.logger.Warn("service deregistration failed", "error", err)
741		}
742	}
743}
744
745// checkID returns the ID used for a Consul Check.  Assume at least a read
746// lock is held.
747func (c *ConsulBackend) checkID() string {
748	return fmt.Sprintf("%s:vault-sealed-check", c.serviceID())
749}
750
751// serviceID returns the Vault ServiceID for use in Consul.  Assume at least
752// a read lock is held.
753func (c *ConsulBackend) serviceID() string {
754	return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort)
755}
756
757// reconcileConsul queries the state of Vault Core and Consul and fixes up
758// Consul's state according to what's in Vault.  reconcileConsul is called
759// without any locks held and can be run concurrently, therefore no changes
760// to ConsulBackend can be made in this method (i.e. wtb const receiver for
761// compiler enforced safety).
762func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (serviceID string, err error) {
763	// Query vault Core for its current state
764	active := activeFunc()
765	sealed := sealedFunc()
766	perfStandby := perfStandbyFunc()
767
768	agent := c.client.Agent()
769	catalog := c.client.Catalog()
770
771	serviceID = c.serviceID()
772
773	// Get the current state of Vault from Consul
774	var currentVaultService *api.CatalogService
775	if services, _, err := catalog.Service(c.serviceName, "", &api.QueryOptions{AllowStale: true}); err == nil {
776		for _, service := range services {
777			if serviceID == service.ServiceID {
778				currentVaultService = service
779				break
780			}
781		}
782	}
783
784	tags := c.fetchServiceTags(active, perfStandby)
785
786	var reregister bool
787
788	switch {
789	case currentVaultService == nil, registeredServiceID == "":
790		reregister = true
791	default:
792		switch {
793		case !strutil.EquivalentSlices(currentVaultService.ServiceTags, tags):
794			reregister = true
795		}
796	}
797
798	if !reregister {
799		// When re-registration is not required, return a valid serviceID
800		// to avoid registration in the next cycle.
801		return serviceID, nil
802	}
803
804	// If service address was set explicitly in configuration, use that
805	// as the service-specific address instead of the HA redirect address.
806	var serviceAddress string
807	if c.serviceAddress == nil {
808		serviceAddress = c.redirectHost
809	} else {
810		serviceAddress = *c.serviceAddress
811	}
812
813	service := &api.AgentServiceRegistration{
814		ID:                serviceID,
815		Name:              c.serviceName,
816		Tags:              tags,
817		Port:              int(c.redirectPort),
818		Address:           serviceAddress,
819		EnableTagOverride: false,
820	}
821
822	checkStatus := api.HealthCritical
823	if !sealed {
824		checkStatus = api.HealthPassing
825	}
826
827	sealedCheck := &api.AgentCheckRegistration{
828		ID:        c.checkID(),
829		Name:      "Vault Sealed Status",
830		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
831		ServiceID: serviceID,
832		AgentServiceCheck: api.AgentServiceCheck{
833			TTL:    c.checkTimeout.String(),
834			Status: checkStatus,
835		},
836	}
837
838	if err := agent.ServiceRegister(service); err != nil {
839		return "", errwrap.Wrapf(`service registration failed: {{err}}`, err)
840	}
841
842	if err := agent.CheckRegister(sealedCheck); err != nil {
843		return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err)
844	}
845
846	return serviceID, nil
847}
848
849// runCheck immediately pushes a TTL check.
850func (c *ConsulBackend) runCheck(sealed bool) error {
851	// Run a TTL check
852	agent := c.client.Agent()
853	if !sealed {
854		return agent.PassTTL(c.checkID(), "Vault Unsealed")
855	} else {
856		return agent.FailTTL(c.checkID(), "Vault Sealed")
857	}
858}
859
860// fetchServiceTags returns all of the relevant tags for Consul.
861func (c *ConsulBackend) fetchServiceTags(active bool, perfStandby bool) []string {
862	activeTag := "standby"
863	if active {
864		activeTag = "active"
865	}
866
867	result := append(c.serviceTags, activeTag)
868
869	if perfStandby {
870		result = append(c.serviceTags, "performance-standby")
871	}
872
873	return result
874}
875
876func (c *ConsulBackend) setRedirectAddr(addr string) (err error) {
877	if addr == "" {
878		return fmt.Errorf("redirect address must not be empty")
879	}
880
881	url, err := url.Parse(addr)
882	if err != nil {
883		return errwrap.Wrapf(fmt.Sprintf("failed to parse redirect URL %q: {{err}}", addr), err)
884	}
885
886	var portStr string
887	c.redirectHost, portStr, err = net.SplitHostPort(url.Host)
888	if err != nil {
889		if url.Scheme == "http" {
890			portStr = "80"
891		} else if url.Scheme == "https" {
892			portStr = "443"
893		} else if url.Scheme == "unix" {
894			portStr = "-1"
895			c.redirectHost = url.Path
896		} else {
897			return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in redirect address "%v": {{err}}`, url.Host), err)
898		}
899	}
900	c.redirectPort, err = strconv.ParseInt(portStr, 10, 0)
901	if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 {
902		return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
903	}
904
905	return nil
906}
907