1package consul
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io/ioutil"
8	"net"
9	"net/http"
10	"net/url"
11	"regexp"
12	"strconv"
13	"strings"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"golang.org/x/net/http2"
19
20	log "github.com/hashicorp/go-hclog"
21
22	"crypto/tls"
23	"crypto/x509"
24
25	metrics "github.com/armon/go-metrics"
26	"github.com/hashicorp/consul/api"
27	"github.com/hashicorp/errwrap"
28	multierror "github.com/hashicorp/go-multierror"
29	"github.com/hashicorp/vault/sdk/helper/consts"
30	"github.com/hashicorp/vault/sdk/helper/parseutil"
31	"github.com/hashicorp/vault/sdk/helper/strutil"
32	"github.com/hashicorp/vault/sdk/helper/tlsutil"
33	"github.com/hashicorp/vault/sdk/physical"
34)
35
const (
	// checkJitterFactor specifies the jitter factor used to stagger checks
	checkJitterFactor = 16

	// checkMinBuffer provides a guarantee that a check will not
	// be executed too close to the TTL check timeout
	checkMinBuffer = 100 * time.Millisecond

	// consulRetryInterval specifies the retry duration to use when an
	// API call to the Consul agent fails.
	consulRetryInterval = 1 * time.Second

	// defaultCheckTimeout is the timeout of TTL checks used when the
	// check_timeout configuration key is not set.
	defaultCheckTimeout = 5 * time.Second

	// DefaultServiceName is the default Consul service name used when
	// advertising a Vault instance.
	DefaultServiceName = "vault"

	// reconcileTimeout is how often Vault should query Consul to detect
	// and fix any state drift.
	reconcileTimeout = 60 * time.Second

	// consistencyModeDefault is the configuration value used to tell
	// consul to use default consistency.
	consistencyModeDefault = "default"

	// consistencyModeStrong is the configuration value used to tell
	// consul to use strong consistency.
	consistencyModeStrong = "strong"
)
67
// notifyEvent is the empty signal value sent on the ConsulBackend notify
// channels; it carries no data.
type notifyEvent struct{}
69
70// Verify ConsulBackend satisfies the correct interfaces
71var _ physical.Backend = (*ConsulBackend)(nil)
72var _ physical.HABackend = (*ConsulBackend)(nil)
73var _ physical.Lock = (*ConsulLock)(nil)
74var _ physical.Transactional = (*ConsulBackend)(nil)
75var _ physical.ServiceDiscovery = (*ConsulBackend)(nil)
76
77var (
78	hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
79)
80
// ConsulBackend is a physical backend that stores data at specific
// prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
//
// The fields are populated by NewConsulBackend from its configuration map,
// except redirectHost and redirectPort, which are filled in by
// setRedirectAddr.
type ConsulBackend struct {
	// path is the key prefix prepended to every key handed to Consul.
	// logger, client and kv (client.KV()) are the handles used for all
	// operations; permitPool is acquired around each KV call.
	// serviceLock is taken around updates to the registered service ID in
	// the event demuxer.
	// redirectHost/redirectPort, serviceName, serviceTags and serviceAddress
	// describe the service registered with the Consul agent; a nil
	// serviceAddress means redirectHost is advertised instead.
	// disableRegistration turns off the reconcile and check handlers.
	// checkTimeout is the TTL given to the sealed-status check.
	// consistencyMode is either consistencyModeDefault or
	// consistencyModeStrong.
	path                string
	logger              log.Logger
	client              *api.Client
	kv                  *api.KV
	permitPool          *physical.PermitPool
	serviceLock         sync.RWMutex
	redirectHost        string
	redirectPort        int64
	serviceName         string
	serviceTags         []string
	serviceAddress      *string
	disableRegistration bool
	checkTimeout        time.Duration
	consistencyMode     string

	// Unbuffered channels written by the Notify*StateChange methods and
	// drained by runEventDemuxer.
	notifyActiveCh      chan notifyEvent
	notifySealedCh      chan notifyEvent
	notifyPerfStandbyCh chan notifyEvent

	// sessionTTL and lockWaitTime are passed through to api.LockOptions
	// when a lock is created in LockWith.
	sessionTTL   string
	lockWaitTime time.Duration
}
107
108// NewConsulBackend constructs a Consul backend using the given API client
109// and the prefix in the KV store.
110func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
111	// Get the path in Consul
112	path, ok := conf["path"]
113	if !ok {
114		path = "vault/"
115	}
116	if logger.IsDebug() {
117		logger.Debug("config path set", "path", path)
118	}
119
120	// Ensure path is suffixed but not prefixed
121	if !strings.HasSuffix(path, "/") {
122		logger.Warn("appending trailing forward slash to path")
123		path += "/"
124	}
125	if strings.HasPrefix(path, "/") {
126		logger.Warn("trimming path of its forward slash")
127		path = strings.TrimPrefix(path, "/")
128	}
129
130	// Allow admins to disable consul integration
131	disableReg, ok := conf["disable_registration"]
132	var disableRegistration bool
133	if ok && disableReg != "" {
134		b, err := parseutil.ParseBool(disableReg)
135		if err != nil {
136			return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err)
137		}
138		disableRegistration = b
139	}
140	if logger.IsDebug() {
141		logger.Debug("config disable_registration set", "disable_registration", disableRegistration)
142	}
143
144	// Get the service name to advertise in Consul
145	service, ok := conf["service"]
146	if !ok {
147		service = DefaultServiceName
148	}
149	if !hostnameRegex.MatchString(service) {
150		return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes")
151	}
152	if logger.IsDebug() {
153		logger.Debug("config service set", "service", service)
154	}
155
156	// Get the additional tags to attach to the registered service name
157	tags := conf["service_tags"]
158	if logger.IsDebug() {
159		logger.Debug("config service_tags set", "service_tags", tags)
160	}
161
162	// Get the service-specific address to override the use of the HA redirect address
163	var serviceAddr *string
164	serviceAddrStr, ok := conf["service_address"]
165	if ok {
166		serviceAddr = &serviceAddrStr
167	}
168	if logger.IsDebug() {
169		logger.Debug("config service_address set", "service_address", serviceAddr)
170	}
171
172	checkTimeout := defaultCheckTimeout
173	checkTimeoutStr, ok := conf["check_timeout"]
174	if ok {
175		d, err := parseutil.ParseDurationSecond(checkTimeoutStr)
176		if err != nil {
177			return nil, err
178		}
179
180		min, _ := DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
181		if min < checkMinBuffer {
182			return nil, fmt.Errorf("consul check_timeout must be greater than %v", min)
183		}
184
185		checkTimeout = d
186		if logger.IsDebug() {
187			logger.Debug("config check_timeout set", "check_timeout", d)
188		}
189	}
190
191	sessionTTL := api.DefaultLockSessionTTL
192	sessionTTLStr, ok := conf["session_ttl"]
193	if ok {
194		_, err := parseutil.ParseDurationSecond(sessionTTLStr)
195		if err != nil {
196			return nil, errwrap.Wrapf("invalid session_ttl: {{err}}", err)
197		}
198		sessionTTL = sessionTTLStr
199		if logger.IsDebug() {
200			logger.Debug("config session_ttl set", "session_ttl", sessionTTL)
201		}
202	}
203
204	lockWaitTime := api.DefaultLockWaitTime
205	lockWaitTimeRaw, ok := conf["lock_wait_time"]
206	if ok {
207		d, err := parseutil.ParseDurationSecond(lockWaitTimeRaw)
208		if err != nil {
209			return nil, errwrap.Wrapf("invalid lock_wait_time: {{err}}", err)
210		}
211		lockWaitTime = d
212		if logger.IsDebug() {
213			logger.Debug("config lock_wait_time set", "lock_wait_time", d)
214		}
215	}
216
217	// Configure the client
218	consulConf := api.DefaultConfig()
219	// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
220	consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
221
222	if addr, ok := conf["address"]; ok {
223		consulConf.Address = addr
224		if logger.IsDebug() {
225			logger.Debug("config address set", "address", addr)
226		}
227
228		// Copied from the Consul API module; set the Scheme based on
229		// the protocol field if address looks ike a URL.
230		// This can enable the TLS configuration below.
231		parts := strings.SplitN(addr, "://", 2)
232		if len(parts) == 2 {
233			if parts[0] == "http" || parts[0] == "https" {
234				consulConf.Scheme = parts[0]
235				consulConf.Address = parts[1]
236				if logger.IsDebug() {
237					logger.Debug("config address parsed", "scheme", parts[0])
238					logger.Debug("config scheme parsed", "address", parts[1])
239				}
240			} // allow "unix:" or whatever else consul supports in the future
241		}
242	}
243	if scheme, ok := conf["scheme"]; ok {
244		consulConf.Scheme = scheme
245		if logger.IsDebug() {
246			logger.Debug("config scheme set", "scheme", scheme)
247		}
248	}
249	if token, ok := conf["token"]; ok {
250		consulConf.Token = token
251		logger.Debug("config token set")
252	}
253
254	if consulConf.Scheme == "https" {
255		// Use the parsed Address instead of the raw conf['address']
256		tlsClientConfig, err := setupTLSConfig(conf, consulConf.Address)
257		if err != nil {
258			return nil, err
259		}
260
261		consulConf.Transport.TLSClientConfig = tlsClientConfig
262		if err := http2.ConfigureTransport(consulConf.Transport); err != nil {
263			return nil, err
264		}
265		logger.Debug("configured TLS")
266	}
267
268	consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
269	client, err := api.NewClient(consulConf)
270	if err != nil {
271		return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
272	}
273
274	maxParStr, ok := conf["max_parallel"]
275	var maxParInt int
276	if ok {
277		maxParInt, err = strconv.Atoi(maxParStr)
278		if err != nil {
279			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
280		}
281		if logger.IsDebug() {
282			logger.Debug("max_parallel set", "max_parallel", maxParInt)
283		}
284	}
285
286	consistencyMode, ok := conf["consistency_mode"]
287	if ok {
288		switch consistencyMode {
289		case consistencyModeDefault, consistencyModeStrong:
290		default:
291			return nil, fmt.Errorf("invalid consistency_mode value: %q", consistencyMode)
292		}
293	} else {
294		consistencyMode = consistencyModeDefault
295	}
296
297	// Setup the backend
298	c := &ConsulBackend{
299		path:                path,
300		logger:              logger,
301		client:              client,
302		kv:                  client.KV(),
303		permitPool:          physical.NewPermitPool(maxParInt),
304		serviceName:         service,
305		serviceTags:         strutil.ParseDedupLowercaseAndSortStrings(tags, ","),
306		serviceAddress:      serviceAddr,
307		checkTimeout:        checkTimeout,
308		disableRegistration: disableRegistration,
309		consistencyMode:     consistencyMode,
310		notifyActiveCh:      make(chan notifyEvent),
311		notifySealedCh:      make(chan notifyEvent),
312		notifyPerfStandbyCh: make(chan notifyEvent),
313		sessionTTL:          sessionTTL,
314		lockWaitTime:        lockWaitTime,
315	}
316	return c, nil
317}
318
319func setupTLSConfig(conf map[string]string, address string) (*tls.Config, error) {
320	serverName, _, err := net.SplitHostPort(address)
321	switch {
322	case err == nil:
323	case strings.Contains(err.Error(), "missing port"):
324		serverName = conf["address"]
325	default:
326		return nil, err
327	}
328
329	insecureSkipVerify := false
330	tlsSkipVerify, ok := conf["tls_skip_verify"]
331
332	if ok && tlsSkipVerify != "" {
333		b, err := parseutil.ParseBool(tlsSkipVerify)
334		if err != nil {
335			return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err)
336		}
337		insecureSkipVerify = b
338	}
339
340	tlsMinVersionStr, ok := conf["tls_min_version"]
341	if !ok {
342		// Set the default value
343		tlsMinVersionStr = "tls12"
344	}
345
346	tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr]
347	if !ok {
348		return nil, fmt.Errorf("invalid 'tls_min_version'")
349	}
350
351	tlsClientConfig := &tls.Config{
352		MinVersion:         tlsMinVersion,
353		InsecureSkipVerify: insecureSkipVerify,
354		ServerName:         serverName,
355	}
356
357	_, okCert := conf["tls_cert_file"]
358	_, okKey := conf["tls_key_file"]
359
360	if okCert && okKey {
361		tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"])
362		if err != nil {
363			return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err)
364		}
365
366		tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
367	}
368
369	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
370		caPool := x509.NewCertPool()
371
372		data, err := ioutil.ReadFile(tlsCaFile)
373		if err != nil {
374			return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err)
375		}
376
377		if !caPool.AppendCertsFromPEM(data) {
378			return nil, fmt.Errorf("failed to parse CA certificate")
379		}
380
381		tlsClientConfig.RootCAs = caPool
382	}
383
384	return tlsClientConfig, nil
385}
386
387// Used to run multiple entries via a transaction
388func (c *ConsulBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
389	if len(txns) == 0 {
390		return nil
391	}
392
393	ops := make([]*api.KVTxnOp, 0, len(txns))
394
395	for _, op := range txns {
396		cop := &api.KVTxnOp{
397			Key: c.path + op.Entry.Key,
398		}
399		switch op.Operation {
400		case physical.DeleteOperation:
401			cop.Verb = api.KVDelete
402		case physical.PutOperation:
403			cop.Verb = api.KVSet
404			cop.Value = op.Entry.Value
405		default:
406			return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
407		}
408
409		ops = append(ops, cop)
410	}
411
412	c.permitPool.Acquire()
413	defer c.permitPool.Release()
414
415	queryOpts := &api.QueryOptions{}
416	queryOpts = queryOpts.WithContext(ctx)
417
418	ok, resp, _, err := c.kv.Txn(ops, queryOpts)
419	if err != nil {
420		if strings.Contains(err.Error(), "is too large") {
421			return errwrap.Wrapf(fmt.Sprintf("%s: {{err}}", physical.ErrValueTooLarge), err)
422		}
423		return err
424	}
425	if ok && len(resp.Errors) == 0 {
426		return nil
427	}
428
429	var retErr *multierror.Error
430	for _, res := range resp.Errors {
431		retErr = multierror.Append(retErr, errors.New(res.What))
432	}
433
434	return retErr
435}
436
437// Put is used to insert or update an entry
438func (c *ConsulBackend) Put(ctx context.Context, entry *physical.Entry) error {
439	defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
440
441	c.permitPool.Acquire()
442	defer c.permitPool.Release()
443
444	pair := &api.KVPair{
445		Key:   c.path + entry.Key,
446		Value: entry.Value,
447	}
448
449	writeOpts := &api.WriteOptions{}
450	writeOpts = writeOpts.WithContext(ctx)
451
452	_, err := c.kv.Put(pair, writeOpts)
453	if err != nil {
454		if strings.Contains(err.Error(), "Value exceeds") {
455			return errwrap.Wrapf(fmt.Sprintf("%s: {{err}}", physical.ErrValueTooLarge), err)
456		}
457		return err
458	}
459	return nil
460}
461
462// Get is used to fetch an entry
463func (c *ConsulBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
464	defer metrics.MeasureSince([]string{"consul", "get"}, time.Now())
465
466	c.permitPool.Acquire()
467	defer c.permitPool.Release()
468
469	queryOpts := &api.QueryOptions{}
470	queryOpts = queryOpts.WithContext(ctx)
471
472	if c.consistencyMode == consistencyModeStrong {
473		queryOpts.RequireConsistent = true
474	}
475
476	pair, _, err := c.kv.Get(c.path+key, queryOpts)
477	if err != nil {
478		return nil, err
479	}
480	if pair == nil {
481		return nil, nil
482	}
483	ent := &physical.Entry{
484		Key:   key,
485		Value: pair.Value,
486	}
487	return ent, nil
488}
489
490// Delete is used to permanently delete an entry
491func (c *ConsulBackend) Delete(ctx context.Context, key string) error {
492	defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now())
493
494	c.permitPool.Acquire()
495	defer c.permitPool.Release()
496
497	writeOpts := &api.WriteOptions{}
498	writeOpts = writeOpts.WithContext(ctx)
499
500	_, err := c.kv.Delete(c.path+key, writeOpts)
501	return err
502}
503
504// List is used to list all the keys under a given
505// prefix, up to the next prefix.
506func (c *ConsulBackend) List(ctx context.Context, prefix string) ([]string, error) {
507	defer metrics.MeasureSince([]string{"consul", "list"}, time.Now())
508	scan := c.path + prefix
509
510	// The TrimPrefix call below will not work correctly if we have "//" at the
511	// end. This can happen in cases where you are e.g. listing the root of a
512	// prefix in a logical backend via "/" instead of ""
513	if strings.HasSuffix(scan, "//") {
514		scan = scan[:len(scan)-1]
515	}
516
517	c.permitPool.Acquire()
518	defer c.permitPool.Release()
519
520	queryOpts := &api.QueryOptions{}
521	queryOpts = queryOpts.WithContext(ctx)
522
523	out, _, err := c.kv.Keys(scan, "/", queryOpts)
524	for idx, val := range out {
525		out[idx] = strings.TrimPrefix(val, scan)
526	}
527
528	return out, err
529}
530
531// Lock is used for mutual exclusion based on the given key.
532func (c *ConsulBackend) LockWith(key, value string) (physical.Lock, error) {
533	// Create the lock
534	opts := &api.LockOptions{
535		Key:            c.path + key,
536		Value:          []byte(value),
537		SessionName:    "Vault Lock",
538		MonitorRetries: 5,
539		SessionTTL:     c.sessionTTL,
540		LockWaitTime:   c.lockWaitTime,
541	}
542	lock, err := c.client.LockOpts(opts)
543	if err != nil {
544		return nil, errwrap.Wrapf("failed to create lock: {{err}}", err)
545	}
546	cl := &ConsulLock{
547		client:          c.client,
548		key:             c.path + key,
549		lock:            lock,
550		consistencyMode: c.consistencyMode,
551	}
552	return cl, nil
553}
554
// HAEnabled indicates whether the HA functionality should be exposed.
// The Consul backend does not make this configurable: it always returns
// true.
func (c *ConsulBackend) HAEnabled() bool {
	return true
}
560
561// DetectHostAddr is used to detect the host address by asking the Consul agent
562func (c *ConsulBackend) DetectHostAddr() (string, error) {
563	agent := c.client.Agent()
564	self, err := agent.Self()
565	if err != nil {
566		return "", err
567	}
568	addr, ok := self["Member"]["Addr"].(string)
569	if !ok {
570		return "", fmt.Errorf("unable to convert an address to string")
571	}
572	return addr, nil
573}
574
// ConsulLock is used to provide the Lock interface backed by Consul.
// Instances are created by ConsulBackend.LockWith.
type ConsulLock struct {
	// client is used by Value to read the lock key directly; key is the
	// full Consul key (backend path included); lock is the underlying
	// Consul lock that Lock and Unlock delegate to; consistencyMode
	// selects whether Value performs a consistent read.
	client          *api.Client
	key             string
	lock            *api.Lock
	consistencyMode string
}
582
// Lock delegates to the underlying Consul lock's Lock method, passing
// stopCh through, and returns its channel and error unchanged.
// NOTE(review): presumably the returned channel signals loss of the lock
// and stopCh aborts the attempt — confirm against the Consul api.Lock docs.
func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	return c.lock.Lock(stopCh)
}
586
// Unlock delegates to the underlying Consul lock's Unlock method and
// returns its error unchanged.
func (c *ConsulLock) Unlock() error {
	return c.lock.Unlock()
}
590
591func (c *ConsulLock) Value() (bool, string, error) {
592	kv := c.client.KV()
593
594	var queryOptions *api.QueryOptions
595	if c.consistencyMode == consistencyModeStrong {
596		queryOptions = &api.QueryOptions{
597			RequireConsistent: true,
598		}
599	}
600
601	pair, _, err := kv.Get(c.key, queryOptions)
602	if err != nil {
603		return false, "", err
604	}
605	if pair == nil {
606		return false, "", nil
607	}
608	held := pair.Session != ""
609	value := string(pair.Value)
610	return held, value, nil
611}
612
// NotifyActiveStateChange signals the event demuxer that Vault's active
// state changed. The send is non-blocking: if nothing is receiving on
// notifyActiveCh at that moment the notification is dropped and a warning
// is logged. It always returns nil.
func (c *ConsulBackend) NotifyActiveStateChange() error {
	select {
	case c.notifyActiveCh <- notifyEvent{}:
	default:
		// NOTE: If this occurs Vault's active status could be out of
		// sync with Consul until reconcileTimer expires.
		c.logger.Warn("concurrent state change notify dropped")
	}

	return nil
}
624
// NotifyPerformanceStandbyStateChange signals the event demuxer that
// Vault's performance standby state changed. The send is non-blocking: if
// nothing is receiving on notifyPerfStandbyCh at that moment the
// notification is dropped and a warning is logged. It always returns nil.
func (c *ConsulBackend) NotifyPerformanceStandbyStateChange() error {
	select {
	case c.notifyPerfStandbyCh <- notifyEvent{}:
	default:
		// NOTE: If this occurs Vault's performance standby status could
		// be out of sync with Consul until the demuxer's timers next
		// fire on their own.
		c.logger.Warn("concurrent state change notify dropped")
	}

	return nil
}
636
// NotifySealedStateChange signals the event demuxer that Vault's sealed
// state changed. The send is non-blocking: if nothing is receiving on
// notifySealedCh at that moment the notification is dropped and a warning
// is logged. It always returns nil.
func (c *ConsulBackend) NotifySealedStateChange() error {
	select {
	case c.notifySealedCh <- notifyEvent{}:
	default:
		// NOTE: If this occurs Vault's sealed status could be out of
		// sync with Consul until checkTimer expires.
		c.logger.Warn("concurrent sealed state change notify dropped")
	}

	return nil
}
648
// checkDuration returns the delay used to arm the demuxer's checkTimer. It
// is derived from the configured check timeout via DurationMinusBuffer,
// using checkMinBuffer and checkJitterFactor.
func (c *ConsulBackend) checkDuration() time.Duration {
	return DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
}
652
653func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (err error) {
654	if err := c.setRedirectAddr(redirectAddr); err != nil {
655		return err
656	}
657
658	// 'server' command will wait for the below goroutine to complete
659	waitGroup.Add(1)
660
661	go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr, activeFunc, sealedFunc, perfStandbyFunc)
662
663	return nil
664}
665
// runEventDemuxer is the long-running loop behind service discovery. It
// waits on the three notify channels, two timers and shutdownCh, and
// dispatches at most one reconcile handler and one check handler goroutine
// at a time. When shutdownCh fires it leaves the loop and deregisters the
// service from the Consul agent. waitGroup.Done is called on return.
//
// redirectAddr is accepted but not referenced in the body.
func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) {
	// This defer statement should be executed last. So push it first.
	defer waitGroup.Done()

	// Fire the reconcileTimer immediately upon starting the event demuxer
	reconcileTimer := time.NewTimer(0)
	defer reconcileTimer.Stop()

	// Schedule the first check.  Consul TTL checks are passing by
	// default, checkTimer does not need to be run immediately.
	checkTimer := time.NewTimer(c.checkDuration())
	defer checkTimer.Stop()

	// Use a reactor pattern to handle and dispatch events to singleton
	// goroutine handlers for execution.  It is not acceptable to drop
	// inbound events from Notify*().
	//
	// goroutines are dispatched if the demuxer can acquire a lock (via
	// an atomic CAS incr) on the handler.  Handlers are responsible for
	// deregistering themselves (atomic CAS decr).  Handlers and the
	// demuxer share a lock to synchronize information at the beginning
	// and end of a handler's life (or after a handler wakes up from
	// sleeping during a back-off/retry).
	//
	// NOTE(review): shutdown is a plain bool written by this loop and read
	// by the handler goroutines below without synchronization; the handlers
	// are also not waited on before this function returns.
	var shutdown bool
	var registeredServiceID string
	checkLock := new(int32)
	serviceRegLock := new(int32)

	for !shutdown {
		select {
		case <-c.notifyActiveCh:
			// Run reconcile immediately upon active state change notification
			reconcileTimer.Reset(0)
		case <-c.notifySealedCh:
			// Run check timer immediately upon a seal state change notification
			checkTimer.Reset(0)
		case <-c.notifyPerfStandbyCh:
			// Run check timer immediately upon a performance standby
			// state change notification
			checkTimer.Reset(0)
		case <-reconcileTimer.C:
			// Unconditionally rearm the reconcileTimer
			reconcileTimer.Reset(reconcileTimeout - RandomStagger(reconcileTimeout/checkJitterFactor))

			// Abort if service discovery is disabled or a
			// reconcile handler is already active
			if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) {
				// Enter handler with serviceRegLock held
				go func() {
					defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0)
					// Retry until reconcile succeeds or shutdown is observed.
					for !shutdown {
						serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc, perfStandbyFunc)
						if err != nil {
							if c.logger.IsWarn() {
								c.logger.Warn("reconcile unable to talk with Consul backend", "error", err)
							}
							time.Sleep(consulRetryInterval)
							continue
						}

						// The deferred unlock runs when the goroutine
						// returns just below.
						c.serviceLock.Lock()
						defer c.serviceLock.Unlock()

						registeredServiceID = serviceID
						return
					}
				}()
			}
		case <-checkTimer.C:
			checkTimer.Reset(c.checkDuration())
			// Abort if service discovery is disabled or a
			// check handler is already active
			if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) {
				// Enter handler with checkLock held
				go func() {
					defer atomic.CompareAndSwapInt32(checkLock, 1, 0)
					// Retry until the TTL update succeeds or shutdown is
					// observed; the sealed state is re-read on every try.
					for !shutdown {
						sealed := sealedFunc()
						if err := c.runCheck(sealed); err != nil {
							if c.logger.IsWarn() {
								c.logger.Warn("check unable to talk with Consul backend", "error", err)
							}
							time.Sleep(consulRetryInterval)
							continue
						}
						return
					}
				}()
			}
		case <-shutdownCh:
			c.logger.Info("shutting down consul backend")
			shutdown = true
		}
	}

	// Deregister whatever service ID the last successful reconcile stored.
	// NOTE(review): this runs even when registration is disabled or nothing
	// was registered, in which case registeredServiceID is empty.
	c.serviceLock.RLock()
	defer c.serviceLock.RUnlock()
	if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil {
		if c.logger.IsWarn() {
			c.logger.Warn("service deregistration failed", "error", err)
		}
	}
}
768
769// checkID returns the ID used for a Consul Check.  Assume at least a read
770// lock is held.
771func (c *ConsulBackend) checkID() string {
772	return fmt.Sprintf("%s:vault-sealed-check", c.serviceID())
773}
774
775// serviceID returns the Vault ServiceID for use in Consul.  Assume at least
776// a read lock is held.
777func (c *ConsulBackend) serviceID() string {
778	return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort)
779}
780
781// reconcileConsul queries the state of Vault Core and Consul and fixes up
782// Consul's state according to what's in Vault.  reconcileConsul is called
783// without any locks held and can be run concurrently, therefore no changes
784// to ConsulBackend can be made in this method (i.e. wtb const receiver for
785// compiler enforced safety).
786func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (serviceID string, err error) {
787	// Query vault Core for its current state
788	active := activeFunc()
789	sealed := sealedFunc()
790	perfStandby := perfStandbyFunc()
791
792	agent := c.client.Agent()
793	catalog := c.client.Catalog()
794
795	serviceID = c.serviceID()
796
797	// Get the current state of Vault from Consul
798	var currentVaultService *api.CatalogService
799	if services, _, err := catalog.Service(c.serviceName, "", &api.QueryOptions{AllowStale: true}); err == nil {
800		for _, service := range services {
801			if serviceID == service.ServiceID {
802				currentVaultService = service
803				break
804			}
805		}
806	}
807
808	tags := c.fetchServiceTags(active, perfStandby)
809
810	var reregister bool
811
812	switch {
813	case currentVaultService == nil, registeredServiceID == "":
814		reregister = true
815	default:
816		switch {
817		case !strutil.EquivalentSlices(currentVaultService.ServiceTags, tags):
818			reregister = true
819		}
820	}
821
822	if !reregister {
823		// When re-registration is not required, return a valid serviceID
824		// to avoid registration in the next cycle.
825		return serviceID, nil
826	}
827
828	// If service address was set explicitly in configuration, use that
829	// as the service-specific address instead of the HA redirect address.
830	var serviceAddress string
831	if c.serviceAddress == nil {
832		serviceAddress = c.redirectHost
833	} else {
834		serviceAddress = *c.serviceAddress
835	}
836
837	service := &api.AgentServiceRegistration{
838		ID:                serviceID,
839		Name:              c.serviceName,
840		Tags:              tags,
841		Port:              int(c.redirectPort),
842		Address:           serviceAddress,
843		EnableTagOverride: false,
844	}
845
846	checkStatus := api.HealthCritical
847	if !sealed {
848		checkStatus = api.HealthPassing
849	}
850
851	sealedCheck := &api.AgentCheckRegistration{
852		ID:        c.checkID(),
853		Name:      "Vault Sealed Status",
854		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
855		ServiceID: serviceID,
856		AgentServiceCheck: api.AgentServiceCheck{
857			TTL:    c.checkTimeout.String(),
858			Status: checkStatus,
859		},
860	}
861
862	if err := agent.ServiceRegister(service); err != nil {
863		return "", errwrap.Wrapf(`service registration failed: {{err}}`, err)
864	}
865
866	if err := agent.CheckRegister(sealedCheck); err != nil {
867		return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err)
868	}
869
870	return serviceID, nil
871}
872
873// runCheck immediately pushes a TTL check.
874func (c *ConsulBackend) runCheck(sealed bool) error {
875	// Run a TTL check
876	agent := c.client.Agent()
877	if !sealed {
878		return agent.PassTTL(c.checkID(), "Vault Unsealed")
879	} else {
880		return agent.FailTTL(c.checkID(), "Vault Sealed")
881	}
882}
883
884// fetchServiceTags returns all of the relevant tags for Consul.
885func (c *ConsulBackend) fetchServiceTags(active bool, perfStandby bool) []string {
886	activeTag := "standby"
887	if active {
888		activeTag = "active"
889	}
890
891	result := append(c.serviceTags, activeTag)
892
893	if perfStandby {
894		result = append(c.serviceTags, "performance-standby")
895	}
896
897	return result
898}
899
900func (c *ConsulBackend) setRedirectAddr(addr string) (err error) {
901	if addr == "" {
902		return fmt.Errorf("redirect address must not be empty")
903	}
904
905	url, err := url.Parse(addr)
906	if err != nil {
907		return errwrap.Wrapf(fmt.Sprintf("failed to parse redirect URL %q: {{err}}", addr), err)
908	}
909
910	var portStr string
911	c.redirectHost, portStr, err = net.SplitHostPort(url.Host)
912	if err != nil {
913		if url.Scheme == "http" {
914			portStr = "80"
915		} else if url.Scheme == "https" {
916			portStr = "443"
917		} else if url.Scheme == "unix" {
918			portStr = "-1"
919			c.redirectHost = url.Path
920		} else {
921			return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in redirect address "%v": {{err}}`, url.Host), err)
922		}
923	}
924	c.redirectPort, err = strconv.ParseInt(portStr, 10, 0)
925	if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 {
926		return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
927	}
928
929	return nil
930}
931