1package consul
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"math/rand"
8	"net"
9	"net/http"
10	"net/url"
11	"regexp"
12	"strconv"
13	"strings"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"github.com/hashicorp/consul/api"
19	log "github.com/hashicorp/go-hclog"
20	"github.com/hashicorp/vault/sdk/helper/consts"
21	"github.com/hashicorp/vault/sdk/helper/parseutil"
22	"github.com/hashicorp/vault/sdk/helper/strutil"
23	"github.com/hashicorp/vault/sdk/helper/tlsutil"
24	sr "github.com/hashicorp/vault/serviceregistration"
25	"github.com/hashicorp/vault/vault/diagnose"
26	atomicB "go.uber.org/atomic"
27	"golang.org/x/net/http2"
28)
29
const (
	// Timing of the TTL check, reconcile loop and Consul retries.

	// checkJitterFactor bounds the random jitter applied to the check and
	// reconcile timers: at most 1/checkJitterFactor of the interval.
	checkJitterFactor = 16

	// checkMinBuffer guarantees that a check is not executed too close to
	// the TTL check timeout.
	checkMinBuffer = 100 * time.Millisecond

	// defaultCheckTimeout is the TTL of the sealed-status check when
	// check_timeout is not configured.
	defaultCheckTimeout = 5 * time.Second

	// consulRetryInterval is how long a handler waits before retrying a
	// failed call to the Consul agent.
	consulRetryInterval = time.Second

	// reconcileTimeout is how often Vault queries Consul to detect and fix
	// any state drift.
	reconcileTimeout = time.Minute

	// Names used when registering with Consul.

	// DefaultServiceName is the default Consul service name used when
	// advertising a Vault instance.
	DefaultServiceName = "vault"

	// metaExternalSource is the "external-source" metadata value attached to
	// the registered service, usable by the Consul UI.
	metaExternalSource = "vault"
)

// hostnameRegex matches an RFC 1123 style hostname: dot-separated labels of
// alphanumerics and dashes, where no label starts or ends with a dash.
var hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
59
// serviceRegistration is a ServiceRegistration that advertises the state of
// Vault to Consul.
//
// It registers a Consul service for this Vault instance (tagged with its
// current role) together with a TTL check that reflects the sealed status.
type serviceRegistration struct {
	// Client is the Consul API client used for all agent and catalog calls.
	Client *api.Client

	// Configuration captured at construction time, plus redirectHost and
	// redirectPort, which are filled in by setRedirectAddr. serviceLock is
	// taken in runEventDemuxer when the registered service ID is recorded and
	// while deregistering on shutdown.
	logger              log.Logger
	serviceLock         sync.RWMutex
	redirectHost        string
	redirectPort        int64
	serviceName         string
	serviceTags         []string
	serviceAddress      *string
	disableRegistration bool
	checkTimeout        time.Duration

	// Channels the Notify*StateChange methods use to wake the event demuxer.
	// The sends are non-blocking (select with default), so a notification is
	// dropped if the demuxer is not ready to receive it.
	notifyActiveCh      chan struct{}
	notifySealedCh      chan struct{}
	notifyPerfStandbyCh chan struct{}
	notifyInitializedCh chan struct{}

	// Latest Vault state. Seeded by the constructor, updated by the Notify*
	// methods, and read by the demuxer's handlers without holding serviceLock.
	isActive      *atomicB.Bool
	isSealed      *atomicB.Bool
	isPerfStandby *atomicB.Bool
	isInitialized *atomicB.Bool
}
85
86// NewConsulServiceRegistration constructs a Consul-based ServiceRegistration.
87func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State) (sr.ServiceRegistration, error) {
88	// Allow admins to disable consul integration
89	disableReg, ok := conf["disable_registration"]
90	var disableRegistration bool
91	if ok && disableReg != "" {
92		b, err := parseutil.ParseBool(disableReg)
93		if err != nil {
94			return nil, fmt.Errorf("failed parsing disable_registration parameter: %w", err)
95		}
96		disableRegistration = b
97	}
98	if logger.IsDebug() {
99		logger.Debug("config disable_registration set", "disable_registration", disableRegistration)
100	}
101
102	// Get the service name to advertise in Consul
103	service, ok := conf["service"]
104	if !ok {
105		service = DefaultServiceName
106	}
107	if !hostnameRegex.MatchString(service) {
108		return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes")
109	}
110	if logger.IsDebug() {
111		logger.Debug("config service set", "service", service)
112	}
113
114	// Get the additional tags to attach to the registered service name
115	tags := conf["service_tags"]
116	if logger.IsDebug() {
117		logger.Debug("config service_tags set", "service_tags", tags)
118	}
119
120	// Get the service-specific address to override the use of the HA redirect address
121	var serviceAddr *string
122	serviceAddrStr, ok := conf["service_address"]
123	if ok {
124		serviceAddr = &serviceAddrStr
125	}
126	if logger.IsDebug() {
127		logger.Debug("config service_address set", "service_address", serviceAddr)
128	}
129
130	checkTimeout := defaultCheckTimeout
131	checkTimeoutStr, ok := conf["check_timeout"]
132	if ok {
133		d, err := parseutil.ParseDurationSecond(checkTimeoutStr)
134		if err != nil {
135			return nil, err
136		}
137
138		min, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
139		if min < checkMinBuffer {
140			return nil, fmt.Errorf("consul check_timeout must be greater than %v", min)
141		}
142
143		checkTimeout = d
144		if logger.IsDebug() {
145			logger.Debug("config check_timeout set", "check_timeout", d)
146		}
147	}
148
149	// Configure the client
150	consulConf := api.DefaultConfig()
151	// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
152	consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
153
154	SetupSecureTLS(context.Background(), consulConf, conf, logger, false)
155
156	consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
157	client, err := api.NewClient(consulConf)
158	if err != nil {
159		return nil, fmt.Errorf("client setup failed: %w", err)
160	}
161
162	// Setup the backend
163	c := &serviceRegistration{
164		Client: client,
165
166		logger:              logger,
167		serviceName:         service,
168		serviceTags:         strutil.ParseDedupLowercaseAndSortStrings(tags, ","),
169		serviceAddress:      serviceAddr,
170		checkTimeout:        checkTimeout,
171		disableRegistration: disableRegistration,
172
173		notifyActiveCh:      make(chan struct{}),
174		notifySealedCh:      make(chan struct{}),
175		notifyPerfStandbyCh: make(chan struct{}),
176		notifyInitializedCh: make(chan struct{}),
177
178		isActive:      atomicB.NewBool(state.IsActive),
179		isSealed:      atomicB.NewBool(state.IsSealed),
180		isPerfStandby: atomicB.NewBool(state.IsPerformanceStandby),
181		isInitialized: atomicB.NewBool(state.IsInitialized),
182	}
183	return c, nil
184}
185
186func SetupSecureTLS(ctx context.Context, consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error {
187	if addr, ok := conf["address"]; ok {
188		consulConf.Address = addr
189		if logger.IsDebug() {
190			logger.Debug("config address set", "address", addr)
191		}
192
193		// Copied from the Consul API module; set the Scheme based on
194		// the protocol field if address looks ike a URL.
195		// This can enable the TLS configuration below.
196		parts := strings.SplitN(addr, "://", 2)
197		if len(parts) == 2 {
198			if parts[0] == "http" || parts[0] == "https" {
199				consulConf.Scheme = parts[0]
200				consulConf.Address = parts[1]
201				if logger.IsDebug() {
202					logger.Debug("config address parsed", "scheme", parts[0])
203					logger.Debug("config scheme parsed", "address", parts[1])
204				}
205			} // allow "unix:" or whatever else consul supports in the future
206		}
207	}
208	if scheme, ok := conf["scheme"]; ok {
209		consulConf.Scheme = scheme
210		if logger.IsDebug() {
211			logger.Debug("config scheme set", "scheme", scheme)
212		}
213	}
214	if token, ok := conf["token"]; ok {
215		consulConf.Token = token
216		logger.Debug("config token set")
217	}
218
219	if consulConf.Scheme == "https" {
220		if isDiagnose {
221			certPath, okCert := conf["tls_cert_file"]
222			keyPath, okKey := conf["tls_key_file"]
223			if okCert && okKey {
224				warnings, err := diagnose.TLSFileChecks(certPath, keyPath)
225				for _, warning := range warnings {
226					diagnose.Warn(ctx, warning)
227				}
228				if err != nil {
229					return err
230				}
231				return nil
232			}
233			return fmt.Errorf("key or cert path: %s, %s, cannot be loaded from consul config file", certPath, keyPath)
234		}
235
236		// Use the parsed Address instead of the raw conf['address']
237		tlsClientConfig, err := tlsutil.SetupTLSConfig(conf, consulConf.Address)
238		if err != nil {
239			return err
240		}
241
242		consulConf.Transport.TLSClientConfig = tlsClientConfig
243		if err := http2.ConfigureTransport(consulConf.Transport); err != nil {
244			return err
245		}
246		logger.Debug("configured TLS")
247	} else {
248		if isDiagnose {
249			diagnose.Skipped(ctx, "HTTPS is not used, Skipping TLS verification.")
250		}
251	}
252	return nil
253}
254
255func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error {
256	go func() {
257		if err := c.runServiceRegistration(wait, shutdownCh, redirectAddr); err != nil {
258			if c.logger.IsError() {
259				c.logger.Error(fmt.Sprintf("error running service registration: %s", err))
260			}
261		}
262	}()
263	return nil
264}
265
266func (c *serviceRegistration) NotifyActiveStateChange(isActive bool) error {
267	c.isActive.Store(isActive)
268	select {
269	case c.notifyActiveCh <- struct{}{}:
270	default:
271		// NOTE: If this occurs Vault's active status could be out of
272		// sync with Consul until reconcileTimer expires.
273		c.logger.Warn("concurrent state change notify dropped")
274	}
275
276	return nil
277}
278
279func (c *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error {
280	c.isPerfStandby.Store(isStandby)
281	select {
282	case c.notifyPerfStandbyCh <- struct{}{}:
283	default:
284		// NOTE: If this occurs Vault's active status could be out of
285		// sync with Consul until reconcileTimer expires.
286		c.logger.Warn("concurrent state change notify dropped")
287	}
288
289	return nil
290}
291
292func (c *serviceRegistration) NotifySealedStateChange(isSealed bool) error {
293	c.isSealed.Store(isSealed)
294	select {
295	case c.notifySealedCh <- struct{}{}:
296	default:
297		// NOTE: If this occurs Vault's sealed status could be out of
298		// sync with Consul until checkTimer expires.
299		c.logger.Warn("concurrent sealed state change notify dropped")
300	}
301
302	return nil
303}
304
305func (c *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) error {
306	c.isInitialized.Store(isInitialized)
307	select {
308	case c.notifyInitializedCh <- struct{}{}:
309	default:
310		// NOTE: If this occurs Vault's initialized status could be out of
311		// sync with Consul until checkTimer expires.
312		c.logger.Warn("concurrent initalize state change notify dropped")
313	}
314
315	return nil
316}
317
318func (c *serviceRegistration) checkDuration() time.Duration {
319	return durationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
320}
321
322func (c *serviceRegistration) runServiceRegistration(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}, redirectAddr string) (err error) {
323	if err := c.setRedirectAddr(redirectAddr); err != nil {
324		return err
325	}
326
327	// 'server' command will wait for the below goroutine to complete
328	waitGroup.Add(1)
329
330	go c.runEventDemuxer(waitGroup, shutdownCh)
331
332	return nil
333}
334
335func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}) {
336	// This defer statement should be executed last. So push it first.
337	defer waitGroup.Done()
338
339	// Fire the reconcileTimer immediately upon starting the event demuxer
340	reconcileTimer := time.NewTimer(0)
341	defer reconcileTimer.Stop()
342
343	// Schedule the first check.  Consul TTL checks are passing by
344	// default, checkTimer does not need to be run immediately.
345	checkTimer := time.NewTimer(c.checkDuration())
346	defer checkTimer.Stop()
347
348	// Use a reactor pattern to handle and dispatch events to singleton
349	// goroutine handlers for execution.  It is not acceptable to drop
350	// inbound events from Notify*().
351	//
352	// goroutines are dispatched if the demuxer can acquire a lock (via
353	// an atomic CAS incr) on the handler.  Handlers are responsible for
354	// deregistering themselves (atomic CAS decr).  Handlers and the
355	// demuxer share a lock to synchronize information at the beginning
356	// and end of a handler's life (or after a handler wakes up from
357	// sleeping during a back-off/retry).
358	var shutdown atomicB.Bool
359	var registeredServiceID string
360	checkLock := new(int32)
361	serviceRegLock := new(int32)
362
363	for !shutdown.Load() {
364		select {
365		case <-c.notifyActiveCh:
366			// Run reconcile immediately upon active state change notification
367			reconcileTimer.Reset(0)
368		case <-c.notifySealedCh:
369			// Run check timer immediately upon a seal state change notification
370			checkTimer.Reset(0)
371		case <-c.notifyPerfStandbyCh:
372			// Run check timer immediately upon a perfstandby state change notification
373			checkTimer.Reset(0)
374		case <-c.notifyInitializedCh:
375			// Run check timer immediately upon an initialized state change notification
376			checkTimer.Reset(0)
377		case <-reconcileTimer.C:
378			// Unconditionally rearm the reconcileTimer
379			reconcileTimer.Reset(reconcileTimeout - randomStagger(reconcileTimeout/checkJitterFactor))
380
381			// Abort if service discovery is disabled or a
382			// reconcile handler is already active
383			if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) {
384				// Enter handler with serviceRegLock held
385				go func() {
386					defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0)
387					for !shutdown.Load() {
388						serviceID, err := c.reconcileConsul(registeredServiceID)
389						if err != nil {
390							if c.logger.IsWarn() {
391								c.logger.Warn("reconcile unable to talk with Consul backend", "error", err)
392							}
393							time.Sleep(consulRetryInterval)
394							continue
395						}
396
397						c.serviceLock.Lock()
398						defer c.serviceLock.Unlock()
399
400						registeredServiceID = serviceID
401						return
402					}
403				}()
404			}
405		case <-checkTimer.C:
406			checkTimer.Reset(c.checkDuration())
407			// Abort if service discovery is disabled or a
408			// reconcile handler is active
409			if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) {
410				// Enter handler with checkLock held
411				go func() {
412					defer atomic.CompareAndSwapInt32(checkLock, 1, 0)
413					for !shutdown.Load() {
414						if err := c.runCheck(c.isSealed.Load()); err != nil {
415							if c.logger.IsWarn() {
416								c.logger.Warn("check unable to talk with Consul backend", "error", err)
417							}
418							time.Sleep(consulRetryInterval)
419							continue
420						}
421						return
422					}
423				}()
424			}
425		case <-shutdownCh:
426			c.logger.Info("shutting down consul backend")
427			shutdown.Store(true)
428		}
429	}
430
431	c.serviceLock.RLock()
432	defer c.serviceLock.RUnlock()
433	if err := c.Client.Agent().ServiceDeregister(registeredServiceID); err != nil {
434		if c.logger.IsWarn() {
435			c.logger.Warn("service deregistration failed", "error", err)
436		}
437	}
438}
439
440// checkID returns the ID used for a Consul Check.  Assume at least a read
441// lock is held.
442func (c *serviceRegistration) checkID() string {
443	return fmt.Sprintf("%s:vault-sealed-check", c.serviceID())
444}
445
446// serviceID returns the Vault ServiceID for use in Consul.  Assume at least
447// a read lock is held.
448func (c *serviceRegistration) serviceID() string {
449	return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort)
450}
451
452// reconcileConsul queries the state of Vault Core and Consul and fixes up
453// Consul's state according to what's in Vault.  reconcileConsul is called
454// without any locks held and can be run concurrently, therefore no changes
455// to serviceRegistration can be made in this method (i.e. wtb const receiver for
456// compiler enforced safety).
457func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (serviceID string, err error) {
458	agent := c.Client.Agent()
459	catalog := c.Client.Catalog()
460
461	serviceID = c.serviceID()
462
463	// Get the current state of Vault from Consul
464	var currentVaultService *api.CatalogService
465	if services, _, err := catalog.Service(c.serviceName, "", &api.QueryOptions{AllowStale: true}); err == nil {
466		for _, service := range services {
467			if serviceID == service.ServiceID {
468				currentVaultService = service
469				break
470			}
471		}
472	}
473
474	tags := c.fetchServiceTags(c.isActive.Load(), c.isPerfStandby.Load(), c.isInitialized.Load())
475
476	var reregister bool
477
478	switch {
479	case currentVaultService == nil, registeredServiceID == "":
480		reregister = true
481	default:
482		switch {
483		case !strutil.EquivalentSlices(currentVaultService.ServiceTags, tags):
484			reregister = true
485		}
486	}
487
488	if !reregister {
489		// When re-registration is not required, return a valid serviceID
490		// to avoid registration in the next cycle.
491		return serviceID, nil
492	}
493
494	// If service address was set explicitly in configuration, use that
495	// as the service-specific address instead of the HA redirect address.
496	var serviceAddress string
497	if c.serviceAddress == nil {
498		serviceAddress = c.redirectHost
499	} else {
500		serviceAddress = *c.serviceAddress
501	}
502
503	service := &api.AgentServiceRegistration{
504		ID:                serviceID,
505		Name:              c.serviceName,
506		Tags:              tags,
507		Port:              int(c.redirectPort),
508		Address:           serviceAddress,
509		EnableTagOverride: false,
510		Meta: map[string]string{
511			"external-source": metaExternalSource,
512		},
513	}
514
515	checkStatus := api.HealthCritical
516	if !c.isSealed.Load() {
517		checkStatus = api.HealthPassing
518	}
519
520	sealedCheck := &api.AgentCheckRegistration{
521		ID:        c.checkID(),
522		Name:      "Vault Sealed Status",
523		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
524		ServiceID: serviceID,
525		AgentServiceCheck: api.AgentServiceCheck{
526			TTL:    c.checkTimeout.String(),
527			Status: checkStatus,
528		},
529	}
530
531	if err := agent.ServiceRegister(service); err != nil {
532		return "", fmt.Errorf(`service registration failed: %w`, err)
533	}
534
535	if err := agent.CheckRegister(sealedCheck); err != nil {
536		return serviceID, fmt.Errorf(`service check registration failed: %w`, err)
537	}
538
539	return serviceID, nil
540}
541
542// runCheck immediately pushes a TTL check.
543func (c *serviceRegistration) runCheck(sealed bool) error {
544	// Run a TTL check
545	agent := c.Client.Agent()
546	if !sealed {
547		return agent.PassTTL(c.checkID(), "Vault Unsealed")
548	} else {
549		return agent.FailTTL(c.checkID(), "Vault Sealed")
550	}
551}
552
553// fetchServiceTags returns all of the relevant tags for Consul.
554func (c *serviceRegistration) fetchServiceTags(active, perfStandby, initialized bool) []string {
555	activeTag := "standby"
556	if active {
557		activeTag = "active"
558	}
559
560	result := append(c.serviceTags, activeTag)
561
562	if perfStandby {
563		result = append(c.serviceTags, "performance-standby")
564	}
565
566	if initialized {
567		result = append(result, "initialized")
568	}
569
570	return result
571}
572
573func (c *serviceRegistration) setRedirectAddr(addr string) (err error) {
574	if addr == "" {
575		return fmt.Errorf("redirect address must not be empty")
576	}
577
578	url, err := url.Parse(addr)
579	if err != nil {
580		return fmt.Errorf("failed to parse redirect URL %q: %w", addr, err)
581	}
582
583	var portStr string
584	c.redirectHost, portStr, err = net.SplitHostPort(url.Host)
585	if err != nil {
586		if url.Scheme == "http" {
587			portStr = "80"
588		} else if url.Scheme == "https" {
589			portStr = "443"
590		} else if url.Scheme == "unix" {
591			portStr = "-1"
592			c.redirectHost = url.Path
593		} else {
594			return fmt.Errorf("failed to find a host:port in redirect address %q: %w", url.Host, err)
595		}
596	}
597	c.redirectPort, err = strconv.ParseInt(portStr, 10, 0)
598	if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 {
599		return fmt.Errorf("failed to parse valid port %q: %w", portStr, err)
600	}
601
602	return nil
603}
604
605// durationMinusBuffer returns a duration, minus a buffer and jitter
606// subtracted from the duration.  This function is used primarily for
607// servicing Consul TTL Checks in advance of the TTL.
608func durationMinusBuffer(intv time.Duration, buffer time.Duration, jitter int64) time.Duration {
609	d := intv - buffer
610	if jitter == 0 {
611		d -= randomStagger(d)
612	} else {
613		d -= randomStagger(time.Duration(int64(d) / jitter))
614	}
615	return d
616}
617
// durationMinusBufferDomain returns the bounds [lo, hi] of the durations
// produced by durationMinusBuffer for the same arguments. hi is intv-buffer.
// lo is hi reduced by hi/jitter, or equal to hi when jitter is 0. Callers use
// it to validate user-specified inputs to durationMinusBuffer.
func durationMinusBufferDomain(intv time.Duration, buffer time.Duration, jitter int64) (lo time.Duration, hi time.Duration) {
	hi = intv - buffer
	lo = hi
	if jitter != 0 {
		lo -= time.Duration(int64(hi) / jitter)
	}
	return lo, hi
}
630
// randomStagger returns a pseudo-random duration in [0, intv).
//
// Non-positive intervals yield 0. Previously, a negative intv was converted to
// a huge uint64 modulus, which produced an arbitrary large result. rand.Int63n
// also avoids the slight modulo bias of the previous implementation.
func randomStagger(intv time.Duration) time.Duration {
	if intv <= 0 {
		return 0
	}
	return time.Duration(rand.Int63n(int64(intv)))
}
638