1package physical
2
3import (
4	"fmt"
5	"io/ioutil"
6	"log"
7	"net"
8	"net/url"
9	"strconv"
10	"strings"
11	"sync"
12	"sync/atomic"
13	"time"
14
15	"crypto/tls"
16	"crypto/x509"
17
18	"github.com/armon/go-metrics"
19	"github.com/hashicorp/consul/api"
20	"github.com/hashicorp/consul/lib"
21	"github.com/hashicorp/errwrap"
22	"github.com/hashicorp/go-cleanhttp"
23)
24
const (
	// checkJitterFactor is the jitter divisor applied when staggering the
	// check and reconcile timers.
	checkJitterFactor = 16

	// checkMinBuffer is the safety margin that keeps a check from being
	// pushed too close to the TTL check timeout.
	checkMinBuffer = 100 * time.Millisecond

	// consulRetryInterval is how long to wait before retrying after an
	// API call to the Consul agent fails.
	consulRetryInterval = time.Second

	// defaultCheckTimeout is the TTL check timeout used when the
	// check_timeout option is not configured.
	defaultCheckTimeout = 5 * time.Second

	// defaultServiceName is the Consul service name under which a Vault
	// instance is advertised when the service option is not configured.
	defaultServiceName = "vault"

	// reconcileTimeout is the base interval at which Vault queries Consul
	// to detect and repair any state drift.
	reconcileTimeout = time.Minute
)
48
// notifyEvent is the empty message sent on the backend's notify channels to
// signal that Vault's active or sealed state has changed.
type notifyEvent struct{}
50
// ConsulBackend is a physical backend that stores data at specific
// prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
//
// Fields:
//   - path is the key prefix prepended to every key read from or written to
//     Consul's KV store.
//   - client and kv are the Consul API client and its KV handle.
//   - permitPool is acquired around each KV operation.
//   - serviceLock is taken by the event demuxer and its reconcile handler
//     around updates to, and the final read of, the registered service ID.
//   - advertiseHost and advertisePort are populated by setAdvertiseAddr and
//     used when building the service ID and registration.
//   - serviceName, disableRegistration and checkTimeout come from the
//     "service", "disable_registration" and "check_timeout" options.
//   - notifyActiveCh and notifySealedCh carry state change notifications
//     from the Notify*StateChange methods to the event demuxer.
type ConsulBackend struct {
	path                string
	logger              *log.Logger
	client              *api.Client
	kv                  *api.KV
	permitPool          *PermitPool
	serviceLock         sync.RWMutex
	advertiseHost       string
	advertisePort       int64
	serviceName         string
	disableRegistration bool
	checkTimeout        time.Duration

	notifyActiveCh chan notifyEvent
	notifySealedCh chan notifyEvent
}
70
71// newConsulBackend constructs a Consul backend using the given API client
72// and the prefix in the KV store.
73func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, error) {
74	// Get the path in Consul
75	path, ok := conf["path"]
76	if !ok {
77		path = "vault/"
78	}
79	logger.Printf("[DEBUG]: consul: config path set to %v", path)
80
81	// Ensure path is suffixed but not prefixed
82	if !strings.HasSuffix(path, "/") {
83		logger.Printf("[WARN]: consul: appending trailing forward slash to path")
84		path += "/"
85	}
86	if strings.HasPrefix(path, "/") {
87		logger.Printf("[WARN]: consul: trimming path of its forward slash")
88		path = strings.TrimPrefix(path, "/")
89	}
90
91	// Allow admins to disable consul integration
92	disableReg, ok := conf["disable_registration"]
93	var disableRegistration bool
94	if ok && disableReg != "" {
95		b, err := strconv.ParseBool(disableReg)
96		if err != nil {
97			return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err)
98		}
99		disableRegistration = b
100	}
101	logger.Printf("[DEBUG]: consul: config disable_registration set to %v", disableRegistration)
102
103	// Get the service name to advertise in Consul
104	service, ok := conf["service"]
105	if !ok {
106		service = defaultServiceName
107	}
108	logger.Printf("[DEBUG]: consul: config service set to %s", service)
109
110	checkTimeout := defaultCheckTimeout
111	checkTimeoutStr, ok := conf["check_timeout"]
112	if ok {
113		d, err := time.ParseDuration(checkTimeoutStr)
114		if err != nil {
115			return nil, err
116		}
117
118		min, _ := lib.DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
119		if min < checkMinBuffer {
120			return nil, fmt.Errorf("Consul check_timeout must be greater than %v", min)
121		}
122
123		checkTimeout = d
124		logger.Printf("[DEBUG]: consul: config check_timeout set to %v", d)
125	}
126
127	// Configure the client
128	consulConf := api.DefaultConfig()
129
130	if addr, ok := conf["address"]; ok {
131		consulConf.Address = addr
132		logger.Printf("[DEBUG]: consul: config address set to %d", addr)
133	}
134	if scheme, ok := conf["scheme"]; ok {
135		consulConf.Scheme = scheme
136		logger.Printf("[DEBUG]: consul: config scheme set to %d", scheme)
137	}
138	if token, ok := conf["token"]; ok {
139		consulConf.Token = token
140		logger.Printf("[DEBUG]: consul: config token set")
141	}
142
143	if consulConf.Scheme == "https" {
144		tlsClientConfig, err := setupTLSConfig(conf)
145		if err != nil {
146			return nil, err
147		}
148
149		transport := cleanhttp.DefaultPooledTransport()
150		transport.MaxIdleConnsPerHost = 4
151		transport.TLSClientConfig = tlsClientConfig
152		consulConf.HttpClient.Transport = transport
153		logger.Printf("[DEBUG]: consul: configured TLS")
154	}
155
156	client, err := api.NewClient(consulConf)
157	if err != nil {
158		return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
159	}
160
161	maxParStr, ok := conf["max_parallel"]
162	var maxParInt int
163	if ok {
164		maxParInt, err = strconv.Atoi(maxParStr)
165		if err != nil {
166			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
167		}
168		logger.Printf("[DEBUG]: consul: max_parallel set to %d", maxParInt)
169	}
170
171	// Setup the backend
172	c := &ConsulBackend{
173		path:                path,
174		logger:              logger,
175		client:              client,
176		kv:                  client.KV(),
177		permitPool:          NewPermitPool(maxParInt),
178		serviceName:         service,
179		checkTimeout:        checkTimeout,
180		disableRegistration: disableRegistration,
181	}
182	return c, nil
183}
184
// setupTLSConfig builds the TLS client configuration used to talk to Consul
// over HTTPS.
//
// Recognised options:
//   - "address": its host part becomes the TLS ServerName.
//   - "tls_skip_verify": boolean; when true, certificate verification is
//     disabled. An empty value is treated as unset.
//   - "tls_cert_file" and "tls_key_file": client certificate key pair,
//     loaded only when both are present.
//   - "tls_ca_file": PEM file whose certificates replace the root CA pool.
//
// An error is returned if tls_skip_verify is not a valid boolean or if any
// of the referenced files cannot be read or parsed.
func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
	// Use the host part of the address as the server name. SplitHostPort
	// copes with bracketed IPv6 literals; when the address carries no port
	// the whole value (minus any brackets) is the host.
	serverName := conf["address"]
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		serverName = host
	} else {
		serverName = strings.TrimSuffix(strings.TrimPrefix(serverName, "["), "]")
	}

	// Parse the flag as a boolean so that tls_skip_verify = "false" does
	// not silently disable certificate verification.
	insecureSkipVerify := false
	if raw, ok := conf["tls_skip_verify"]; ok && raw != "" {
		b, err := strconv.ParseBool(raw)
		if err != nil {
			return nil, fmt.Errorf("failed parsing tls_skip_verify parameter: %v", err)
		}
		insecureSkipVerify = b
	}

	tlsClientConfig := &tls.Config{
		InsecureSkipVerify: insecureSkipVerify,
		ServerName:         serverName,
	}

	certFile, okCert := conf["tls_cert_file"]
	keyFile, okKey := conf["tls_key_file"]

	if okCert && okKey {
		tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return nil, fmt.Errorf("client tls setup failed: %v", err)
		}

		tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
	}

	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
		caPool := x509.NewCertPool()

		data, err := ioutil.ReadFile(tlsCaFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA file: %v", err)
		}

		if !caPool.AppendCertsFromPEM(data) {
			return nil, fmt.Errorf("failed to parse CA certificate")
		}

		tlsClientConfig.RootCAs = caPool
	}

	return tlsClientConfig, nil
}
227
228// Put is used to insert or update an entry
229func (c *ConsulBackend) Put(entry *Entry) error {
230	defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())
231	pair := &api.KVPair{
232		Key:   c.path + entry.Key,
233		Value: entry.Value,
234	}
235
236	c.permitPool.Acquire()
237	defer c.permitPool.Release()
238
239	_, err := c.kv.Put(pair, nil)
240	return err
241}
242
243// Get is used to fetch an entry
244func (c *ConsulBackend) Get(key string) (*Entry, error) {
245	defer metrics.MeasureSince([]string{"consul", "get"}, time.Now())
246
247	c.permitPool.Acquire()
248	defer c.permitPool.Release()
249
250	pair, _, err := c.kv.Get(c.path+key, nil)
251	if err != nil {
252		return nil, err
253	}
254	if pair == nil {
255		return nil, nil
256	}
257	ent := &Entry{
258		Key:   key,
259		Value: pair.Value,
260	}
261	return ent, nil
262}
263
264// Delete is used to permanently delete an entry
265func (c *ConsulBackend) Delete(key string) error {
266	defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now())
267
268	c.permitPool.Acquire()
269	defer c.permitPool.Release()
270
271	_, err := c.kv.Delete(c.path+key, nil)
272	return err
273}
274
275// List is used to list all the keys under a given
276// prefix, up to the next prefix.
277func (c *ConsulBackend) List(prefix string) ([]string, error) {
278	defer metrics.MeasureSince([]string{"consul", "list"}, time.Now())
279	scan := c.path + prefix
280
281	// The TrimPrefix call below will not work correctly if we have "//" at the
282	// end. This can happen in cases where you are e.g. listing the root of a
283	// prefix in a logical backend via "/" instead of ""
284	if strings.HasSuffix(scan, "//") {
285		scan = scan[:len(scan)-1]
286	}
287
288	c.permitPool.Acquire()
289	defer c.permitPool.Release()
290
291	out, _, err := c.kv.Keys(scan, "/", nil)
292	for idx, val := range out {
293		out[idx] = strings.TrimPrefix(val, scan)
294	}
295
296	return out, err
297}
298
299// Lock is used for mutual exclusion based on the given key.
300func (c *ConsulBackend) LockWith(key, value string) (Lock, error) {
301	// Create the lock
302	opts := &api.LockOptions{
303		Key:            c.path + key,
304		Value:          []byte(value),
305		SessionName:    "Vault Lock",
306		MonitorRetries: 5,
307	}
308	lock, err := c.client.LockOpts(opts)
309	if err != nil {
310		return nil, fmt.Errorf("failed to create lock: %v", err)
311	}
312	cl := &ConsulLock{
313		client: c.client,
314		key:    c.path + key,
315		lock:   lock,
316	}
317	return cl, nil
318}
319
320// DetectHostAddr is used to detect the host address by asking the Consul agent
321func (c *ConsulBackend) DetectHostAddr() (string, error) {
322	agent := c.client.Agent()
323	self, err := agent.Self()
324	if err != nil {
325		return "", err
326	}
327	addr, ok := self["Member"]["Addr"].(string)
328	if !ok {
329		return "", fmt.Errorf("Unable to convert an address to string")
330	}
331	return addr, nil
332}
333
// ConsulLock is used to provide the Lock interface backed by Consul.
// client is used to read the lock's key, key is the full (prefixed) KV key
// of the lock, and lock is the underlying Consul lock that Lock and Unlock
// delegate to.
type ConsulLock struct {
	client *api.Client
	key    string
	lock   *api.Lock
}
340
341func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
342	return c.lock.Lock(stopCh)
343}
344
345func (c *ConsulLock) Unlock() error {
346	return c.lock.Unlock()
347}
348
349func (c *ConsulLock) Value() (bool, string, error) {
350	kv := c.client.KV()
351
352	pair, _, err := kv.Get(c.key, nil)
353	if err != nil {
354		return false, "", err
355	}
356	if pair == nil {
357		return false, "", nil
358	}
359	held := pair.Session != ""
360	value := string(pair.Value)
361	return held, value, nil
362}
363
364func (c *ConsulBackend) NotifyActiveStateChange() error {
365	select {
366	case c.notifyActiveCh <- notifyEvent{}:
367	default:
368		// NOTE: If this occurs Vault's active status could be out of
369		// sync with Consul until reconcileTimer expires.
370		c.logger.Printf("[WARN]: consul: Concurrent state change notify dropped")
371	}
372
373	return nil
374}
375
376func (c *ConsulBackend) NotifySealedStateChange() error {
377	select {
378	case c.notifySealedCh <- notifyEvent{}:
379	default:
380		// NOTE: If this occurs Vault's sealed status could be out of
381		// sync with Consul until checkTimer expires.
382		c.logger.Printf("[WARN]: consul: Concurrent sealed state change notify dropped")
383	}
384
385	return nil
386}
387
388func (c *ConsulBackend) checkDuration() time.Duration {
389	return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
390}
391
392func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) {
393	if err := c.setAdvertiseAddr(advertiseAddr); err != nil {
394		return err
395	}
396
397	go c.runEventDemuxer(shutdownCh, advertiseAddr, activeFunc, sealedFunc)
398
399	return nil
400}
401
402func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) {
403	// Fire the reconcileTimer immediately upon starting the event demuxer
404	reconcileTimer := time.NewTimer(0)
405	defer reconcileTimer.Stop()
406
407	// Schedule the first check.  Consul TTL checks are passing by
408	// default, checkTimer does not need to be run immediately.
409	checkTimer := time.NewTimer(c.checkDuration())
410	defer checkTimer.Stop()
411
412	// Use a reactor pattern to handle and dispatch events to singleton
413	// goroutine handlers for execution.  It is not acceptable to drop
414	// inbound events from Notify*().
415	//
416	// goroutines are dispatched if the demuxer can acquire a lock (via
417	// an atomic CAS incr) on the handler.  Handlers are responsible for
418	// deregistering themselves (atomic CAS decr).  Handlers and the
419	// demuxer share a lock to synchronize information at the beginning
420	// and end of a handler's life (or after a handler wakes up from
421	// sleeping during a back-off/retry).
422	var shutdown bool
423	var checkLock int64
424	var registeredServiceID string
425	var serviceRegLock int64
426shutdown:
427	for {
428		select {
429		case <-c.notifyActiveCh:
430			// Run reconcile immediately upon active state change notification
431			reconcileTimer.Reset(0)
432		case <-c.notifySealedCh:
433			// Run check timer immediately upon a seal state change notification
434			checkTimer.Reset(0)
435		case <-reconcileTimer.C:
436			// Unconditionally rearm the reconcileTimer
437			reconcileTimer.Reset(reconcileTimeout - lib.RandomStagger(reconcileTimeout/checkJitterFactor))
438
439			// Abort if service discovery is disabled or a
440			// reconcile handler is already active
441			if !c.disableRegistration && atomic.CompareAndSwapInt64(&serviceRegLock, 0, 1) {
442				// Enter handler with serviceRegLock held
443				go func() {
444					defer atomic.CompareAndSwapInt64(&serviceRegLock, 1, 0)
445					for !shutdown {
446						serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc)
447						if err != nil {
448							c.logger.Printf("[WARN]: consul: reconcile unable to talk with Consul backend: %v", err)
449							time.Sleep(consulRetryInterval)
450							continue
451						}
452
453						c.serviceLock.Lock()
454						defer c.serviceLock.Unlock()
455
456						registeredServiceID = serviceID
457						return
458					}
459				}()
460			}
461		case <-checkTimer.C:
462			checkTimer.Reset(c.checkDuration())
463			// Abort if service discovery is disabled or a
464			// reconcile handler is active
465			if !c.disableRegistration && atomic.CompareAndSwapInt64(&checkLock, 0, 1) {
466				// Enter handler with checkLock held
467				go func() {
468					defer atomic.CompareAndSwapInt64(&checkLock, 1, 0)
469					for !shutdown {
470						sealed := sealedFunc()
471						if err := c.runCheck(sealed); err != nil {
472							c.logger.Printf("[WARN]: consul: check unable to talk with Consul backend: %v", err)
473							time.Sleep(consulRetryInterval)
474							continue
475						}
476						return
477					}
478				}()
479			}
480		case <-shutdownCh:
481			c.logger.Printf("[INFO]: consul: Shutting down consul backend")
482			shutdown = true
483			break shutdown
484		}
485	}
486
487	c.serviceLock.RLock()
488	defer c.serviceLock.RUnlock()
489	if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil {
490		c.logger.Printf("[WARN]: consul: service deregistration failed: %v", err)
491	}
492}
493
494// checkID returns the ID used for a Consul Check.  Assume at least a read
495// lock is held.
496func (c *ConsulBackend) checkID() string {
497	return "vault-sealed-check"
498}
499
500// reconcileConsul queries the state of Vault Core and Consul and fixes up
501// Consul's state according to what's in Vault.  reconcileConsul is called
502// without any locks held and can be run concurrently, therefore no changes
503// to ConsulBackend can be made in this method (i.e. wtb const receiver for
504// compiler enforced safety).
505func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc activeFunction, sealedFunc sealedFunction) (serviceID string, err error) {
506	// Query vault Core for its current state
507	active := activeFunc()
508	sealed := sealedFunc()
509
510	agent := c.client.Agent()
511
512	// Get the current state of Vault from Consul
513	var currentVaultService *api.AgentService
514	if services, err := agent.Services(); err == nil {
515		if service, ok := services[c.serviceName]; ok {
516			currentVaultService = service
517		}
518	}
519
520	serviceID = c.serviceID()
521	tags := serviceTags(active)
522
523	var reregister bool
524	switch {
525	case currentVaultService == nil,
526		registeredServiceID == "":
527		reregister = true
528	default:
529		switch {
530		case len(currentVaultService.Tags) != 1,
531			currentVaultService.Tags[0] != tags[0]:
532			reregister = true
533		}
534	}
535
536	if !reregister {
537		return "", nil
538	}
539
540	service := &api.AgentServiceRegistration{
541		ID:                serviceID,
542		Name:              c.serviceName,
543		Tags:              tags,
544		Port:              int(c.advertisePort),
545		Address:           c.advertiseHost,
546		EnableTagOverride: false,
547	}
548
549	checkStatus := api.HealthCritical
550	if !sealed {
551		checkStatus = api.HealthPassing
552	}
553
554	sealedCheck := &api.AgentCheckRegistration{
555		ID:        c.checkID(),
556		Name:      "Vault Sealed Status",
557		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
558		ServiceID: serviceID,
559		AgentServiceCheck: api.AgentServiceCheck{
560			TTL:    c.checkTimeout.String(),
561			Status: checkStatus,
562		},
563	}
564
565	if err := agent.ServiceRegister(service); err != nil {
566		return "", errwrap.Wrapf(`service registration failed: {{err}}`, err)
567	}
568
569	if err := agent.CheckRegister(sealedCheck); err != nil {
570		return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err)
571	}
572
573	return serviceID, nil
574}
575
576// runCheck immediately pushes a TTL check.
577func (c *ConsulBackend) runCheck(sealed bool) error {
578	// Run a TTL check
579	agent := c.client.Agent()
580	if !sealed {
581		return agent.PassTTL(c.checkID(), "Vault Unsealed")
582	} else {
583		return agent.FailTTL(c.checkID(), "Vault Sealed")
584	}
585}
586
587// serviceID returns the Vault ServiceID for use in Consul.  Assume at least
588// a read lock is held.
589func (c *ConsulBackend) serviceID() string {
590	return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort)
591}
592
// serviceTags returns the tags to register in Consul: a single tag that is
// "active" when this instance is active and "standby" otherwise.
func serviceTags(active bool) []string {
	if active {
		return []string{"active"}
	}
	return []string{"standby"}
}
601
602func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) {
603	if addr == "" {
604		return fmt.Errorf("advertise address must not be empty")
605	}
606
607	url, err := url.Parse(addr)
608	if err != nil {
609		return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err)
610	}
611
612	var portStr string
613	c.advertiseHost, portStr, err = net.SplitHostPort(url.Host)
614	if err != nil {
615		if url.Scheme == "http" {
616			portStr = "80"
617		} else if url.Scheme == "https" {
618			portStr = "443"
619		} else if url.Scheme == "unix" {
620			portStr = "-1"
621			c.advertiseHost = url.Path
622		} else {
623			return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err)
624		}
625	}
626	c.advertisePort, err = strconv.ParseInt(portStr, 10, 0)
627	if err != nil || c.advertisePort < -1 || c.advertisePort > 65535 {
628		return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
629	}
630
631	return nil
632}
633