1package zookeeper
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"fmt"
8	"io/ioutil"
9	"net"
10	"path/filepath"
11	"sort"
12	"strings"
13	"sync"
14	"time"
15
16	log "github.com/hashicorp/go-hclog"
17	"github.com/hashicorp/vault/sdk/helper/parseutil"
18	"github.com/hashicorp/vault/sdk/physical"
19
20	metrics "github.com/armon/go-metrics"
21	"github.com/hashicorp/vault/sdk/helper/tlsutil"
22	"github.com/samuel/go-zookeeper/zk"
23)
24
const (
	// ZKNodeFilePrefix is prepended to the final path element of every
	// "file" (data node) stored in ZooKeeper so that file names never
	// collide with directory entries. Without it, a file could not be
	// deleted when its path is a full prefix of another key.
	ZKNodeFilePrefix = "_"
)
32
// Compile-time assertions that the types in this file satisfy the
// physical interfaces they are used as: the backend is both a plain
// Backend and an HABackend, and the HA lock is a physical.Lock.
var (
	_ physical.Backend   = (*ZooKeeperBackend)(nil)
	_ physical.HABackend = (*ZooKeeperBackend)(nil)
	_ physical.Lock      = (*ZooKeeperHALock)(nil)
)
39
// ZooKeeperBackend is a physical backend that stores data under a specific
// prefix within ZooKeeper. It is used in production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
type ZooKeeperBackend struct {
	// path is the znode prefix under which every key is stored. The
	// constructor normalizes it to start and end with "/".
	path string
	// client is the ZooKeeper connection used for all operations.
	client *zk.Conn
	// acl is the access control list applied to every node this backend
	// creates.
	acl []zk.ACL
	// logger is handed to the HA locks created by LockWith.
	logger log.Logger
}
49
50// NewZooKeeperBackend constructs a ZooKeeper backend using the given API client
51// and the prefix in the KV store.
52func NewZooKeeperBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
53	// Get the path in ZooKeeper
54	path, ok := conf["path"]
55	if !ok {
56		path = "vault/"
57	}
58
59	// Ensure path is suffixed and prefixed (zk requires prefix /)
60	if !strings.HasSuffix(path, "/") {
61		path += "/"
62	}
63	if !strings.HasPrefix(path, "/") {
64		path = "/" + path
65	}
66
67	// Configure the client, default to localhost instance
68	var machines string
69	machines, ok = conf["address"]
70	if !ok {
71		machines = "localhost:2181"
72	}
73
74	// zNode owner and schema.
75	var owner string
76	var schema string
77	var schemaAndOwner string
78	schemaAndOwner, ok = conf["znode_owner"]
79	if !ok {
80		owner = "anyone"
81		schema = "world"
82	} else {
83		parsedSchemaAndOwner := strings.SplitN(schemaAndOwner, ":", 2)
84		if len(parsedSchemaAndOwner) != 2 {
85			return nil, fmt.Errorf("znode_owner expected format is 'schema:owner'")
86		} else {
87			schema = parsedSchemaAndOwner[0]
88			owner = parsedSchemaAndOwner[1]
89
90			// znode_owner is in config and structured correctly - but does it make any sense?
91			// Either 'owner' or 'schema' was set but not both - this seems like a failed attempt
92			// (e.g. ':MyUser' which omit the schema, or ':' omitting both)
93			if owner == "" || schema == "" {
94				return nil, fmt.Errorf("znode_owner expected format is 'schema:auth'")
95			}
96		}
97	}
98
99	acl := []zk.ACL{
100		{
101			Perms:  zk.PermAll,
102			Scheme: schema,
103			ID:     owner,
104		},
105	}
106
107	// Authentication info
108	var schemaAndUser string
109	var useAddAuth bool
110	schemaAndUser, useAddAuth = conf["auth_info"]
111	if useAddAuth {
112		parsedSchemaAndUser := strings.SplitN(schemaAndUser, ":", 2)
113		if len(parsedSchemaAndUser) != 2 {
114			return nil, fmt.Errorf("auth_info expected format is 'schema:auth'")
115		} else {
116			schema = parsedSchemaAndUser[0]
117			owner = parsedSchemaAndUser[1]
118
119			// auth_info is in config and structured correctly - but does it make any sense?
120			// Either 'owner' or 'schema' was set but not both - this seems like a failed attempt
121			// (e.g. ':MyUser' which omit the schema, or ':' omitting both)
122			if owner == "" || schema == "" {
123				return nil, fmt.Errorf("auth_info expected format is 'schema:auth'")
124			}
125		}
126	}
127
128	// We have all of the configuration in hand - let's try and connect to ZK
129	client, _, err := createClient(conf, machines, time.Second)
130	if err != nil {
131		return nil, fmt.Errorf("client setup failed: %w", err)
132	}
133
134	// ZK AddAuth API if the user asked for it
135	if useAddAuth {
136		err = client.AddAuth(schema, []byte(owner))
137		if err != nil {
138			return nil, fmt.Errorf("ZooKeeper rejected authentication information provided at auth_info: %w", err)
139		}
140	}
141
142	// Setup the backend
143	c := &ZooKeeperBackend{
144		path:   path,
145		client: client,
146		acl:    acl,
147		logger: logger,
148	}
149	return c, nil
150}
151
// caseInsenstiveContains reports whether val occurs as a substring of
// superset when both strings are compared after upper-casing.
func caseInsenstiveContains(superset, val string) bool {
	haystack := strings.ToUpper(superset)
	needle := strings.ToUpper(val)
	return strings.Index(haystack, needle) >= 0
}
155
156// Returns a client for ZK connection. Config value 'tls_enabled' determines if TLS is enabled or not.
157func createClient(conf map[string]string, machines string, timeout time.Duration) (*zk.Conn, <-chan zk.Event, error) {
158	// 'tls_enabled' defaults to false
159	isTlsEnabled := false
160	isTlsEnabledStr, ok := conf["tls_enabled"]
161
162	if ok && isTlsEnabledStr != "" {
163		parsedBoolval, err := parseutil.ParseBool(isTlsEnabledStr)
164		if err != nil {
165			return nil, nil, fmt.Errorf("failed parsing tls_enabled parameter: %w", err)
166		}
167		isTlsEnabled = parsedBoolval
168	}
169
170	if isTlsEnabled {
171		// Create a custom Dialer with cert configuration for TLS handshake.
172		tlsDialer := customTLSDial(conf, machines)
173		options := zk.WithDialer(tlsDialer)
174		return zk.Connect(strings.Split(machines, ","), timeout, options)
175	} else {
176		return zk.Connect(strings.Split(machines, ","), timeout)
177	}
178}
179
180// Vault config file properties:
181// 1. tls_skip_verify: skip host name verification.
182// 2. tls_min_version: minimum supported/acceptable tls version
183// 3. tls_cert_file: Cert file Absolute path
184// 4. tls_key_file: Key file Absolute path
185// 5. tls_ca_file: ca file absolute path
186// 6. tls_verify_ip: If set to true, server's IP is verified in certificate if tls_skip_verify is false.
187func customTLSDial(conf map[string]string, machines string) zk.Dialer {
188	return func(network, addr string, timeout time.Duration) (net.Conn, error) {
189		// Sets the serverName. *Note* the addr field comes in as an IP address
190		serverName, _, sParseErr := net.SplitHostPort(addr)
191		if sParseErr != nil {
192			// If the address is only missing port, assign the full address anyway
193			if strings.Contains(sParseErr.Error(), "missing port") {
194				serverName = addr
195			} else {
196				return nil, fmt.Errorf("failed parsing the server address for 'serverName' setting %w", sParseErr)
197			}
198		}
199
200		insecureSkipVerify := false
201		tlsSkipVerify, ok := conf["tls_skip_verify"]
202
203		if ok && tlsSkipVerify != "" {
204			b, err := parseutil.ParseBool(tlsSkipVerify)
205			if err != nil {
206				return nil, fmt.Errorf("failed parsing tls_skip_verify parameter: %w", err)
207			}
208			insecureSkipVerify = b
209		}
210
211		if !insecureSkipVerify {
212			// If tls_verify_ip is set to false, Server's DNS name is verified in the CN/SAN of the certificate.
213			// if tls_verify_ip is true, Server's IP is verified in the CN/SAN of the certificate.
214			// These checks happen only when tls_skip_verify is set to false.
215			// This value defaults to false
216			ipSanCheck := false
217			configVal, lookupOk := conf["tls_verify_ip"]
218
219			if lookupOk && configVal != "" {
220				parsedIpSanCheck, ipSanErr := parseutil.ParseBool(configVal)
221				if ipSanErr != nil {
222					return nil, fmt.Errorf("failed parsing tls_verify_ip parameter: %w", ipSanErr)
223				}
224				ipSanCheck = parsedIpSanCheck
225			}
226			// The addr/serverName parameter to this method comes in as an IP address.
227			// Here we lookup the DNS name and assign it to serverName if ipSanCheck is set to false
228			if !ipSanCheck {
229				lookupAddressMany, lookupErr := net.LookupAddr(serverName)
230				if lookupErr == nil {
231					for _, lookupAddress := range lookupAddressMany {
232						// strip the trailing '.' from lookupAddr
233						if lookupAddress[len(lookupAddress)-1] == '.' {
234							lookupAddress = lookupAddress[:len(lookupAddress)-1]
235						}
236						// Allow serverName to be replaced only if the lookupname is part of the
237						// supplied machine names
238						// If there is no match, the serverName will continue to be an IP value.
239						if caseInsenstiveContains(machines, lookupAddress) {
240							serverName = lookupAddress
241							break
242						}
243					}
244				}
245			}
246
247		}
248
249		tlsMinVersionStr, ok := conf["tls_min_version"]
250		if !ok {
251			// Set the default value
252			tlsMinVersionStr = "tls12"
253		}
254
255		tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr]
256		if !ok {
257			return nil, fmt.Errorf("invalid 'tls_min_version'")
258		}
259
260		tlsClientConfig := &tls.Config{
261			MinVersion:         tlsMinVersion,
262			InsecureSkipVerify: insecureSkipVerify,
263			ServerName:         serverName,
264		}
265
266		_, okCert := conf["tls_cert_file"]
267		_, okKey := conf["tls_key_file"]
268
269		if okCert && okKey {
270			tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"])
271			if err != nil {
272				return nil, fmt.Errorf("client tls setup failed for ZK: %w", err)
273			}
274
275			tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
276		}
277
278		if tlsCaFile, ok := conf["tls_ca_file"]; ok {
279			caPool := x509.NewCertPool()
280
281			data, err := ioutil.ReadFile(tlsCaFile)
282			if err != nil {
283				return nil, fmt.Errorf("failed to read ZK CA file: %w", err)
284			}
285
286			if !caPool.AppendCertsFromPEM(data) {
287				return nil, fmt.Errorf("failed to parse ZK CA certificate")
288			}
289			tlsClientConfig.RootCAs = caPool
290		}
291
292		if network != "tcp" {
293			return nil, fmt.Errorf("unsupported network %q", network)
294		}
295
296		tcpConn, err := net.DialTimeout("tcp", addr, timeout)
297		if err != nil {
298			return nil, err
299		}
300		conn := tls.Client(tcpConn, tlsClientConfig)
301		if err := conn.Handshake(); err != nil {
302			return nil, fmt.Errorf("Handshake failed with Zookeeper : %v", err)
303		}
304		return conn, nil
305	}
306}
307
308// ensurePath is used to create each node in the path hierarchy.
309// We avoid calling this optimistically, and invoke it when we get
310// an error during an operation
311func (c *ZooKeeperBackend) ensurePath(path string, value []byte) error {
312	nodes := strings.Split(path, "/")
313	fullPath := ""
314	for index, node := range nodes {
315		if strings.TrimSpace(node) != "" {
316			fullPath += "/" + node
317			isLastNode := index+1 == len(nodes)
318
319			// set parent nodes to nil, leaf to value
320			// this block reduces round trips by being smart on the leaf create/set
321			if exists, _, _ := c.client.Exists(fullPath); !isLastNode && !exists {
322				if _, err := c.client.Create(fullPath, nil, int32(0), c.acl); err != nil {
323					return err
324				}
325			} else if isLastNode && !exists {
326				if _, err := c.client.Create(fullPath, value, int32(0), c.acl); err != nil {
327					return err
328				}
329			} else if isLastNode && exists {
330				if _, err := c.client.Set(fullPath, value, int32(-1)); err != nil {
331					return err
332				}
333			}
334		}
335	}
336	return nil
337}
338
339// cleanupLogicalPath is used to remove all empty nodes, beginning with deepest one,
340// aborting on first non-empty one, up to top-level node.
341func (c *ZooKeeperBackend) cleanupLogicalPath(path string) error {
342	nodes := strings.Split(path, "/")
343	for i := len(nodes) - 1; i > 0; i-- {
344		fullPath := c.path + strings.Join(nodes[:i], "/")
345
346		_, stat, err := c.client.Exists(fullPath)
347		if err != nil {
348			return fmt.Errorf("failed to acquire node data: %w", err)
349		}
350
351		if stat.DataLength > 0 && stat.NumChildren > 0 {
352			panic(fmt.Sprintf("node %q is both of data and leaf type", fullPath))
353		} else if stat.DataLength > 0 {
354			panic(fmt.Sprintf("node %q is a data node, this is either a bug or backend data is corrupted", fullPath))
355		} else if stat.NumChildren > 0 {
356			return nil
357		} else {
358			// Empty node, lets clean it up!
359			if err := c.client.Delete(fullPath, -1); err != nil && err != zk.ErrNoNode {
360				return fmt.Errorf("removal of node %q failed: %w", fullPath, err)
361			}
362		}
363	}
364	return nil
365}
366
367// nodePath returns an zk path based on the given key.
368func (c *ZooKeeperBackend) nodePath(key string) string {
369	return filepath.Join(c.path, filepath.Dir(key), ZKNodeFilePrefix+filepath.Base(key))
370}
371
372// Put is used to insert or update an entry
373func (c *ZooKeeperBackend) Put(ctx context.Context, entry *physical.Entry) error {
374	defer metrics.MeasureSince([]string{"zookeeper", "put"}, time.Now())
375
376	// Attempt to set the full path
377	fullPath := c.nodePath(entry.Key)
378	_, err := c.client.Set(fullPath, entry.Value, -1)
379
380	// If we get ErrNoNode, we need to construct the path hierarchy
381	if err == zk.ErrNoNode {
382		return c.ensurePath(fullPath, entry.Value)
383	}
384	return err
385}
386
387// Get is used to fetch an entry
388func (c *ZooKeeperBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
389	defer metrics.MeasureSince([]string{"zookeeper", "get"}, time.Now())
390
391	// Attempt to read the full path
392	fullPath := c.nodePath(key)
393	value, _, err := c.client.Get(fullPath)
394
395	// Ignore if the node does not exist
396	if err == zk.ErrNoNode {
397		err = nil
398	}
399	if err != nil {
400		return nil, err
401	}
402
403	// Handle a non-existing value
404	if value == nil {
405		return nil, nil
406	}
407	ent := &physical.Entry{
408		Key:   key,
409		Value: value,
410	}
411	return ent, nil
412}
413
414// Delete is used to permanently delete an entry
415func (c *ZooKeeperBackend) Delete(ctx context.Context, key string) error {
416	defer metrics.MeasureSince([]string{"zookeeper", "delete"}, time.Now())
417
418	if key == "" {
419		return nil
420	}
421
422	// Delete the full path
423	fullPath := c.nodePath(key)
424	err := c.client.Delete(fullPath, -1)
425
426	// Mask if the node does not exist
427	if err != nil && err != zk.ErrNoNode {
428		return fmt.Errorf("failed to remove %q: %w", fullPath, err)
429	}
430
431	err = c.cleanupLogicalPath(key)
432
433	return err
434}
435
436// List is used ot list all the keys under a given
437// prefix, up to the next prefix.
438func (c *ZooKeeperBackend) List(ctx context.Context, prefix string) ([]string, error) {
439	defer metrics.MeasureSince([]string{"zookeeper", "list"}, time.Now())
440
441	// Query the children at the full path
442	fullPath := strings.TrimSuffix(c.path+prefix, "/")
443	result, _, err := c.client.Children(fullPath)
444
445	// If the path nodes are missing, no children!
446	if err == zk.ErrNoNode {
447		return []string{}, nil
448	} else if err != nil {
449		return []string{}, err
450	}
451
452	children := []string{}
453	for _, key := range result {
454		childPath := fullPath + "/" + key
455		_, stat, err := c.client.Exists(childPath)
456		if err != nil {
457			// Node is ought to exists, so it must be something different
458			return []string{}, err
459		}
460
461		// Check if this entry is a leaf of a node,
462		// and append the slash which is what Vault depends on
463		// for iteration
464		if stat.DataLength > 0 && stat.NumChildren > 0 {
465			if childPath == c.nodePath("core/lock") {
466				// go-zookeeper Lock() breaks Vault semantics and creates a directory
467				// under the lock file; just treat it like the file Vault expects
468				children = append(children, key[1:])
469			} else {
470				panic(fmt.Sprintf("node %q is both of data and leaf type", childPath))
471			}
472		} else if stat.DataLength == 0 {
473			// No, we cannot differentiate here on number of children as node
474			// can have all it leafs removed, and it still is a node.
475			children = append(children, key+"/")
476		} else {
477			children = append(children, key[1:])
478		}
479	}
480	sort.Strings(children)
481	return children, nil
482}
483
484// LockWith is used for mutual exclusion based on the given key.
485func (c *ZooKeeperBackend) LockWith(key, value string) (physical.Lock, error) {
486	l := &ZooKeeperHALock{
487		in:     c,
488		key:    key,
489		value:  value,
490		logger: c.logger,
491	}
492	return l, nil
493}
494
495// HAEnabled indicates whether the HA functionality should be exposed.
496// Currently always returns true.
497func (c *ZooKeeperBackend) HAEnabled() bool {
498	return true
499}
500
// ZooKeeperHALock is a ZooKeeper Lock implementation for the HABackend
type ZooKeeperHALock struct {
	// in is the backend whose client, ACL and path mapping the lock uses.
	in *ZooKeeperBackend
	// key is the logical key of the lock node.
	key string
	// value is the data written to the lock node once the lock is acquired.
	value string
	// logger receives messages about failed lock releases.
	logger log.Logger

	// held records whether this instance considers the lock acquired. It is
	// read and written while localLock is held.
	held bool
	// localLock serializes Lock and unlockInternal on this instance.
	localLock sync.Mutex
	// leaderCh is the channel returned by Lock; monitorLock closes it when
	// it decides the lock has been lost.
	leaderCh chan struct{}
	// stopCh is the stop channel passed to Lock; the background release
	// retry started by Unlock stops when it fires.
	stopCh <-chan struct{}
	// zkLock is the underlying ZooKeeper lock, stored by attemptLock after
	// acquisition and released by unlockInternal.
	zkLock *zk.Lock
}
514
515func (i *ZooKeeperHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
516	i.localLock.Lock()
517	defer i.localLock.Unlock()
518	if i.held {
519		return nil, fmt.Errorf("lock already held")
520	}
521
522	// Attempt an async acquisition
523	didLock := make(chan struct{})
524	failLock := make(chan error, 1)
525	releaseCh := make(chan bool, 1)
526	lockpath := i.in.nodePath(i.key)
527	go i.attemptLock(lockpath, didLock, failLock, releaseCh)
528
529	// Wait for lock acquisition, failure, or shutdown
530	select {
531	case <-didLock:
532		releaseCh <- false
533	case err := <-failLock:
534		return nil, err
535	case <-stopCh:
536		releaseCh <- true
537		return nil, nil
538	}
539
540	// Create the leader channel
541	i.held = true
542	i.leaderCh = make(chan struct{})
543
544	// Watch for Events which could result in loss of our zkLock and close(i.leaderCh)
545	currentVal, _, lockeventCh, err := i.in.client.GetW(lockpath)
546	if err != nil {
547		return nil, fmt.Errorf("unable to watch HA lock: %w", err)
548	}
549	if i.value != string(currentVal) {
550		return nil, fmt.Errorf("lost HA lock immediately before watch")
551	}
552	go i.monitorLock(lockeventCh, i.leaderCh)
553
554	i.stopCh = stopCh
555
556	return i.leaderCh, nil
557}
558
559func (i *ZooKeeperHALock) attemptLock(lockpath string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
560	// Wait to acquire the lock in ZK
561	lock := zk.NewLock(i.in.client, lockpath, i.in.acl)
562	err := lock.Lock()
563	if err != nil {
564		failLock <- err
565		return
566	}
567	// Set node value
568	data := []byte(i.value)
569	err = i.in.ensurePath(lockpath, data)
570	if err != nil {
571		failLock <- err
572		lock.Unlock()
573		return
574	}
575	i.zkLock = lock
576
577	// Signal that lock is held
578	close(didLock)
579
580	// Handle an early abort
581	release := <-releaseCh
582	if release {
583		lock.Unlock()
584	}
585}
586
587func (i *ZooKeeperHALock) monitorLock(lockeventCh <-chan zk.Event, leaderCh chan struct{}) {
588	for {
589		select {
590		case event := <-lockeventCh:
591			// Lost connection?
592			switch event.State {
593			case zk.StateConnected:
594			case zk.StateHasSession:
595			default:
596				close(leaderCh)
597				return
598			}
599
600			// Lost lock?
601			switch event.Type {
602			case zk.EventNodeChildrenChanged:
603			case zk.EventSession:
604			default:
605				close(leaderCh)
606				return
607			}
608		}
609	}
610}
611
612func (i *ZooKeeperHALock) unlockInternal() error {
613	i.localLock.Lock()
614	defer i.localLock.Unlock()
615	if !i.held {
616		return nil
617	}
618
619	err := i.zkLock.Unlock()
620
621	if err == nil {
622		i.held = false
623		return nil
624	}
625
626	return err
627}
628
629func (i *ZooKeeperHALock) Unlock() error {
630	var err error
631
632	if err = i.unlockInternal(); err != nil {
633		i.logger.Error("failed to release distributed lock", "error", err)
634
635		go func(i *ZooKeeperHALock) {
636			attempts := 0
637			i.logger.Info("launching automated distributed lock release")
638
639			for {
640				if err := i.unlockInternal(); err == nil {
641					i.logger.Info("distributed lock released")
642					return
643				}
644
645				select {
646				case <-time.After(time.Second):
647					attempts := attempts + 1
648					if attempts >= 10 {
649						i.logger.Error("release lock max attempts reached. Lock may not be released", "error", err)
650						return
651					}
652					continue
653				case <-i.stopCh:
654					return
655				}
656			}
657		}(i)
658	}
659
660	return err
661}
662
663func (i *ZooKeeperHALock) Value() (bool, string, error) {
664	lockpath := i.in.nodePath(i.key)
665	value, _, err := i.in.client.Get(lockpath)
666	return (value != nil), string(value), err
667}
668