1package zookeeper
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"fmt"
8	"io/ioutil"
9	"net"
10	"path/filepath"
11	"sort"
12	"strings"
13	"sync"
14	"time"
15
16	"github.com/hashicorp/errwrap"
17	log "github.com/hashicorp/go-hclog"
18	"github.com/hashicorp/vault/sdk/helper/parseutil"
19	"github.com/hashicorp/vault/sdk/physical"
20
21	metrics "github.com/armon/go-metrics"
22	"github.com/hashicorp/vault/sdk/helper/tlsutil"
23	"github.com/samuel/go-zookeeper/zk"
24)
25
// ZKNodeFilePrefix is prepended to the final path element of every stored
// entry ("file"), while directory znodes never carry it. The key "a/b" is
// therefore stored at ".../a/_b", distinct from the directory znode
// ".../a/b" that holds the children of a longer key such as "a/b/c".
// Without this separation an entry could not be deleted while its key was
// also a full prefix of another key.
const ZKNodeFilePrefix = "_"
33
34// Verify ZooKeeperBackend satisfies the correct interfaces
35var _ physical.Backend = (*ZooKeeperBackend)(nil)
36var _ physical.HABackend = (*ZooKeeperBackend)(nil)
37var _ physical.Lock = (*ZooKeeperHALock)(nil)
38
39// ZooKeeperBackend is a physical backend that stores data at specific
40// prefix within ZooKeeper. It is used in production situations as
41// it allows Vault to run on multiple machines in a highly-available manner.
42type ZooKeeperBackend struct {
43	path   string
44	client *zk.Conn
45	acl    []zk.ACL
46	logger log.Logger
47}
48
// NewZooKeeperBackend constructs a ZooKeeper backend from the given
// configuration map and connects to ZooKeeper.
//
// Recognised keys:
//   - path: root znode (default "vault/"), normalised to "/<path>/".
//   - address: comma-separated ZooKeeper servers (default "localhost:2181").
//   - znode_owner: "schema:owner" ACL identity for created znodes
//     (default "world:anyone").
//   - auth_info: "schema:auth" credentials registered via AddAuth after
//     connecting.
//   - tls_*: TLS options, see createClient and customTLSDial.
//
// Malformed znode_owner/auth_info values are rejected before connecting. If
// ZooKeeper rejects auth_info, the connection is closed before returning.
func NewZooKeeperBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	// Get the path in ZooKeeper
	path, ok := conf["path"]
	if !ok {
		path = "vault/"
	}

	// Ensure path is suffixed and prefixed (zk requires prefix /)
	if !strings.HasSuffix(path, "/") {
		path += "/"
	}
	if !strings.HasPrefix(path, "/") {
		path = "/" + path
	}

	// Configure the client, default to localhost instance
	machines, ok := conf["address"]
	if !ok {
		machines = "localhost:2181"
	}

	// ACL identity applied to every znode we create.
	aclSchema, aclOwner := "world", "anyone"
	if raw, ok := conf["znode_owner"]; ok {
		schema, owner, err := parseSchemaPair("znode_owner", raw, "owner")
		if err != nil {
			return nil, err
		}
		aclSchema, aclOwner = schema, owner
	}

	acl := []zk.ACL{
		{
			Perms:  zk.PermAll,
			Scheme: aclSchema,
			ID:     aclOwner,
		},
	}

	// Optional authentication info passed to ZooKeeper's AddAuth.
	var authSchema, authValue string
	useAddAuth := false
	if raw, ok := conf["auth_info"]; ok {
		schema, value, err := parseSchemaPair("auth_info", raw, "auth")
		if err != nil {
			return nil, err
		}
		authSchema, authValue, useAddAuth = schema, value, true
	}

	// We have all of the configuration in hand - let's try and connect to ZK
	client, _, err := createClient(conf, machines, time.Second)
	if err != nil {
		return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
	}

	// ZK AddAuth API if the user asked for it
	if useAddAuth {
		if err := client.AddAuth(authSchema, []byte(authValue)); err != nil {
			// The backend is never returned, so nobody else could close the
			// connection; release it here instead of leaking it.
			client.Close()
			return nil, errwrap.Wrapf("ZooKeeper rejected authentication information provided at auth_info: {{err}}", err)
		}
	}

	// Setup the backend
	c := &ZooKeeperBackend{
		path:   path,
		client: client,
		acl:    acl,
		logger: logger,
	}
	return c, nil
}

// parseSchemaPair splits the "schema:<valueName>" setting raw, read from
// conf[key], at its first ':'. Both halves must be non-empty, so values such
// as "digest", ":MyUser" or ":" are rejected with an error describing the
// expected format.
func parseSchemaPair(key, raw, valueName string) (string, string, error) {
	parts := strings.SplitN(raw, ":", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", "", fmt.Errorf("%s expected format is 'schema:%s'", key, valueName)
	}
	return parts[0], parts[1], nil
}
150
// caseInsenstiveContains reports whether val occurs anywhere inside superset,
// ignoring letter case. Both strings are upper-cased before the search.
func caseInsenstiveContains(superset, val string) bool {
	haystack := strings.ToUpper(superset)
	needle := strings.ToUpper(val)
	return strings.Index(haystack, needle) >= 0
}
154
// createClient connects to the ZooKeeper servers listed in machines
// (comma-separated). Whitespace around each entry is trimmed and empty
// entries are ignored. Otherwise "a:2181," would yield an empty server
// address, and zk.Connect is handed an empty list when nothing remains.
//
// The 'tls_enabled' config value (default false) selects whether the
// connection is made over TLS via customTLSDial. The returned event channel
// is the one produced by zk.Connect.
func createClient(conf map[string]string, machines string, timeout time.Duration) (*zk.Conn, <-chan zk.Event, error) {
	// 'tls_enabled' defaults to false
	isTLSEnabled := false
	if raw, ok := conf["tls_enabled"]; ok && raw != "" {
		parsed, err := parseutil.ParseBool(raw)
		if err != nil {
			return nil, nil, errwrap.Wrapf("failed parsing tls_enabled parameter: {{err}}", err)
		}
		isTLSEnabled = parsed
	}

	servers := splitServerList(machines)

	if !isTLSEnabled {
		return zk.Connect(servers, timeout)
	}

	// Create a custom Dialer with cert configuration for TLS handshake.
	return zk.Connect(servers, timeout, zk.WithDialer(customTLSDial(conf, machines)))
}

// splitServerList splits a comma-separated server list, trimming surrounding
// whitespace from each entry and dropping entries that end up empty.
func splitServerList(machines string) []string {
	parts := strings.Split(machines, ",")
	servers := make([]string, 0, len(parts))
	for _, part := range parts {
		if part = strings.TrimSpace(part); part != "" {
			servers = append(servers, part)
		}
	}
	return servers
}
178
// customTLSDial returns a zk.Dialer that opens a TCP connection and performs
// a TLS client handshake on it. The TLS settings are re-read from conf on
// every dial (Vault config file properties):
//  1. tls_skip_verify: skip server certificate and host name verification.
//  2. tls_min_version: minimum supported/acceptable TLS version (default "tls12").
//  3. tls_cert_file: client certificate file absolute path (used only together
//     with tls_key_file).
//  4. tls_key_file: client key file absolute path (used only together with
//     tls_cert_file).
//  5. tls_ca_file: CA file absolute path used to verify the server.
//  6. tls_verify_ip: if true, the server's IP is verified in the certificate
//     when tls_skip_verify is false; otherwise a reverse-DNS name that also
//     appears in machines is used as the expected server name.
//
// The dialer's timeout bounds both the TCP connect and the TLS handshake. On
// any failure after the TCP connect, the socket is closed.
func customTLSDial(conf map[string]string, machines string) zk.Dialer {
	return func(network, addr string, timeout time.Duration) (net.Conn, error) {
		tlsClientConfig, err := zkTLSClientConfig(conf, machines, addr)
		if err != nil {
			return nil, err
		}

		if network != "tcp" {
			return nil, fmt.Errorf("unsupported network %q", network)
		}

		tcpConn, err := net.DialTimeout("tcp", addr, timeout)
		if err != nil {
			return nil, err
		}

		conn := tls.Client(tcpConn, tlsClientConfig)

		// net.DialTimeout only bounds the TCP connect; apply the same budget to
		// the TLS handshake so an unresponsive peer cannot block this dial
		// indefinitely. Only applied when a positive timeout is given.
		if timeout > 0 {
			if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
				conn.Close()
				return nil, err
			}
		}
		if err := conn.Handshake(); err != nil {
			// The connection is never handed to the caller, so release it here.
			conn.Close()
			return nil, fmt.Errorf("Handshake failed with Zookeeper : %v", err)
		}
		// Clear the handshake deadline so it does not apply to later reads and
		// writes on the returned connection.
		if err := conn.SetDeadline(time.Time{}); err != nil {
			conn.Close()
			return nil, err
		}
		return conn, nil
	}
}

// zkTLSClientConfig builds the tls.Config for one dial to addr from the tls_*
// settings in conf (see customTLSDial for the recognised keys).
func zkTLSClientConfig(conf map[string]string, machines, addr string) (*tls.Config, error) {
	// Sets the serverName. *Note* the addr field comes in as an IP address
	serverName, _, err := net.SplitHostPort(addr)
	if err != nil {
		// If the address is only missing port, assign the full address anyway
		if !strings.Contains(err.Error(), "missing port") {
			return nil, errwrap.Wrapf("failed parsing the server address for 'serverName' setting {{err}}", err)
		}
		serverName = addr
	}

	insecureSkipVerify, err := tlsBoolOption(conf, "tls_skip_verify")
	if err != nil {
		return nil, err
	}

	// tls_verify_ip is only consulted (and validated) when certificates are
	// verified. It defaults to false, meaning the server is checked by DNS name.
	if !insecureSkipVerify {
		ipSanCheck, err := tlsBoolOption(conf, "tls_verify_ip")
		if err != nil {
			return nil, err
		}
		if !ipSanCheck {
			serverName = reverseLookupServerName(serverName, machines)
		}
	}

	tlsMinVersionStr, ok := conf["tls_min_version"]
	if !ok {
		// Set the default value
		tlsMinVersionStr = "tls12"
	}
	tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr]
	if !ok {
		return nil, fmt.Errorf("invalid 'tls_min_version'")
	}

	tlsClientConfig := &tls.Config{
		MinVersion:         tlsMinVersion,
		InsecureSkipVerify: insecureSkipVerify,
		ServerName:         serverName,
	}

	// A client certificate is only loaded when both files are configured.
	certFile, okCert := conf["tls_cert_file"]
	keyFile, okKey := conf["tls_key_file"]
	if okCert && okKey {
		tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return nil, errwrap.Wrapf("client tls setup failed for ZK: {{err}}", err)
		}
		tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
	}

	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
		data, err := ioutil.ReadFile(tlsCaFile)
		if err != nil {
			return nil, errwrap.Wrapf("failed to read ZK CA file: {{err}}", err)
		}
		caPool := x509.NewCertPool()
		if !caPool.AppendCertsFromPEM(data) {
			return nil, fmt.Errorf("failed to parse ZK CA certificate")
		}
		tlsClientConfig.RootCAs = caPool
	}

	return tlsClientConfig, nil
}

// tlsBoolOption parses the optional boolean setting conf[key]; a missing or
// empty value yields false.
func tlsBoolOption(conf map[string]string, key string) (bool, error) {
	raw, ok := conf[key]
	if !ok || raw == "" {
		return false, nil
	}
	parsed, err := parseutil.ParseBool(raw)
	if err != nil {
		return false, errwrap.Wrapf(fmt.Sprintf("failed parsing %s parameter: {{err}}", key), err)
	}
	return parsed, nil
}

// reverseLookupServerName resolves ip back to host names. It returns the first
// name (trailing '.' removed) that also appears, case-insensitively, in the
// configured machines string. The name is only replaced when it is part of the
// supplied machine names; if the lookup fails or nothing matches, ip is
// returned unchanged. Empty names are skipped: an empty string would
// otherwise "match" any machines list.
func reverseLookupServerName(ip, machines string) string {
	names, err := net.LookupAddr(ip)
	if err != nil {
		return ip
	}
	for _, name := range names {
		name = strings.TrimSuffix(name, ".")
		if name == "" {
			continue
		}
		if caseInsenstiveContains(machines, name) {
			return name
		}
	}
	return ip
}
306
// ensurePath creates every missing znode along path, using the backend's ACL.
// Intermediate znodes are created without data. The final (last non-blank)
// element is created holding value, or has its data overwritten with value
// if it already exists. Blank path elements are skipped.
//
// We avoid calling this optimistically, and invoke it when we get an error
// during an operation. Other clients may be building the same hierarchy, so
// an intermediate znode that appears between the Exists check and Create is
// accepted. A leaf that appears concurrently is overwritten rather than
// failing the write. Errors from Exists are returned instead of being ignored.
func (c *ZooKeeperBackend) ensurePath(path string, value []byte) error {
	nodes := strings.Split(path, "/")

	// Index of the last non-blank element: the znode that receives value.
	last := -1
	for index, node := range nodes {
		if strings.TrimSpace(node) != "" {
			last = index
		}
	}

	fullPath := ""
	for index, node := range nodes {
		if strings.TrimSpace(node) == "" {
			continue
		}
		fullPath += "/" + node

		exists, _, err := c.client.Exists(fullPath)
		if err != nil {
			return err
		}

		if index != last {
			// Parent nodes carry no data.
			if !exists {
				if _, err := c.client.Create(fullPath, nil, 0, c.acl); err != nil && err != zk.ErrNodeExists {
					return err
				}
			}
			continue
		}

		// Leaf node: create it holding value, or overwrite its existing data.
		if !exists {
			_, err := c.client.Create(fullPath, value, 0, c.acl)
			if err == nil {
				continue
			}
			if err != zk.ErrNodeExists {
				return err
			}
			// Created by another client after our Exists check; overwrite it.
		}
		if _, err := c.client.Set(fullPath, value, -1); err != nil {
			return err
		}
	}
	return nil
}
337
// cleanupLogicalPath removes the directory znodes left empty after deleting
// the entry for the logical key path. It walks the key's ancestors under the
// backend root from deepest to shallowest, and stops at the first one that
// still has children. Znodes that are already gone are ignored.
//
// It panics when an ancestor directory znode carries data, because that
// indicates a bug or corrupted backend data.
func (c *ZooKeeperBackend) cleanupLogicalPath(path string) error {
	segments := strings.Split(path, "/")
	for depth := len(segments) - 1; depth >= 1; depth-- {
		dirPath := c.path + strings.Join(segments[:depth], "/")

		_, stat, err := c.client.Exists(dirPath)
		if err != nil {
			return errwrap.Wrapf("failed to acquire node data: {{err}}", err)
		}

		hasData := stat.DataLength > 0
		hasChildren := stat.NumChildren > 0
		switch {
		case hasData && hasChildren:
			panic(fmt.Sprintf("node %q is both of data and leaf type", dirPath))
		case hasData:
			panic(fmt.Sprintf("node %q is a data node, this is either a bug or backend data is corrupted", dirPath))
		case hasChildren:
			// Still in use; every ancestor contains this node, so none is empty.
			return nil
		}

		// Empty node: remove it. Someone else removing it first is fine.
		if err := c.client.Delete(dirPath, -1); err != nil && err != zk.ErrNoNode {
			return errwrap.Wrapf(fmt.Sprintf("removal of node %q failed: {{err}}", dirPath), err)
		}
	}
	return nil
}
365
// nodePath maps a logical key to its znode path. The key's directory part is
// placed under the backend root, and its final element gets ZKNodeFilePrefix.
// For example, key "core/lock" under root "/vault/" maps to
// "/vault/core/_lock".
//
// ZooKeeper paths are always "/"-separated, but filepath joins with the OS
// separator. The result is therefore passed through filepath.ToSlash, so it
// stays a valid znode path on platforms such as Windows. It is unchanged
// where the separator is already "/".
func (c *ZooKeeperBackend) nodePath(key string) string {
	joined := filepath.Join(c.path, filepath.Dir(key), ZKNodeFilePrefix+filepath.Base(key))
	return filepath.ToSlash(joined)
}
370
// Put is used to insert or update an entry. It first tries a plain Set on the
// entry's znode. Only when that reports the znode does not exist yet is the
// hierarchy built, and the value written, via ensurePath.
func (c *ZooKeeperBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"zookeeper", "put"}, time.Now())

	target := c.nodePath(entry.Key)
	if _, err := c.client.Set(target, entry.Value, -1); err != zk.ErrNoNode {
		// Success (nil) or an error we do not handle here.
		return err
	}

	// Missing znode: construct the path hierarchy and store the value.
	return c.ensurePath(target, entry.Value)
}
385
// Get is used to fetch an entry. A missing znode is not treated as an error;
// whenever no data comes back, (nil, nil) is returned.
func (c *ZooKeeperBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"zookeeper", "get"}, time.Now())

	data, _, err := c.client.Get(c.nodePath(key))
	if err != nil && err != zk.ErrNoNode {
		return nil, err
	}

	// ErrNoNode is masked; without data the key is reported as not found.
	if data == nil {
		return nil, nil
	}
	return &physical.Entry{Key: key, Value: data}, nil
}
412
// Delete is used to permanently delete an entry. An empty key is a no-op, and
// deleting a key that does not exist is not an error. Afterwards, directory
// znodes left empty by the removal are cleaned up via cleanupLogicalPath.
func (c *ZooKeeperBackend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"zookeeper", "delete"}, time.Now())

	if key == "" {
		return nil
	}

	target := c.nodePath(key)
	if err := c.client.Delete(target, -1); err != nil && err != zk.ErrNoNode {
		return errwrap.Wrapf(fmt.Sprintf("failed to remove %q: {{err}}", target), err)
	}

	return c.cleanupLogicalPath(key)
}
434
// List is used to list all the keys under a given prefix, up to the next
// prefix. Entry znodes (which carry data) are returned with ZKNodeFilePrefix
// stripped. Znodes without data are treated as directories and returned with
// a trailing "/", which is what Vault depends on for iteration. The result is
// sorted, and a missing prefix yields an empty list.
//
// Children removed concurrently, between the Children call and the per-child
// Exists check, are skipped rather than reported as directories.
func (c *ZooKeeperBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"zookeeper", "list"}, time.Now())

	// Query the children at the full path
	fullPath := strings.TrimSuffix(c.path+prefix, "/")
	result, _, err := c.client.Children(fullPath)
	switch {
	case err == zk.ErrNoNode:
		// If the path nodes are missing, no children!
		return []string{}, nil
	case err != nil:
		return []string{}, err
	}

	// The HA lock znode is the one node allowed to have both data and
	// children; compute its path once instead of for every child.
	lockPath := c.nodePath("core/lock")

	children := make([]string, 0, len(result))
	for _, key := range result {
		childPath := fullPath + "/" + key
		exists, stat, err := c.client.Exists(childPath)
		if err != nil {
			// The node ought to exist, so this is some other failure
			return []string{}, err
		}
		if !exists {
			// Deleted after Children returned; its stat is not meaningful and
			// could make it look like an empty directory.
			continue
		}

		// Check if this entry is a leaf of a node, and append the slash which
		// is what Vault depends on for iteration
		switch {
		case stat.DataLength > 0 && stat.NumChildren > 0:
			if childPath != lockPath {
				panic(fmt.Sprintf("node %q is both of data and leaf type", childPath))
			}
			// go-zookeeper Lock() breaks Vault semantics and creates a directory
			// under the lock file; just treat it like the file Vault expects
			children = append(children, strings.TrimPrefix(key, ZKNodeFilePrefix))
		case stat.DataLength == 0:
			// No, we cannot differentiate here on number of children as node
			// can have all its leaves removed, and it still is a node.
			children = append(children, key+"/")
		default:
			children = append(children, strings.TrimPrefix(key, ZKNodeFilePrefix))
		}
	}
	sort.Strings(children)
	return children, nil
}
482
483// LockWith is used for mutual exclusion based on the given key.
484func (c *ZooKeeperBackend) LockWith(key, value string) (physical.Lock, error) {
485	l := &ZooKeeperHALock{
486		in:     c,
487		key:    key,
488		value:  value,
489		logger: c.logger,
490	}
491	return l, nil
492}
493
494// HAEnabled indicates whether the HA functionality should be exposed.
495// Currently always returns true.
496func (c *ZooKeeperBackend) HAEnabled() bool {
497	return true
498}
499
500// ZooKeeperHALock is a ZooKeeper Lock implementation for the HABackend
501type ZooKeeperHALock struct {
502	in     *ZooKeeperBackend
503	key    string
504	value  string
505	logger log.Logger
506
507	held      bool
508	localLock sync.Mutex
509	leaderCh  chan struct{}
510	stopCh    <-chan struct{}
511	zkLock    *zk.Lock
512}
513
// Lock attempts to acquire the HA lock. It blocks until the lock is acquired,
// the acquisition fails, or stopCh is closed.
//
// On success it returns a channel that monitorLock closes when the lock may
// have been lost. If stopCh fires first, it returns (nil, nil), and a late
// acquisition is released by attemptLock. Calling Lock while this handle
// already holds the lock returns an error.
//
// The lock is only marked held once the watch on the lock znode is in place.
// If the watch cannot be set up, or the znode no longer holds our value, the
// acquired ZooKeeper lock is released again and the handle stays unheld. This
// lets Lock be retried instead of failing forever with "lock already held".
func (i *ZooKeeperHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	i.localLock.Lock()
	defer i.localLock.Unlock()
	if i.held {
		return nil, fmt.Errorf("lock already held")
	}

	// Attempt an async acquisition
	didLock := make(chan struct{})
	failLock := make(chan error, 1)
	releaseCh := make(chan bool, 1)
	lockpath := i.in.nodePath(i.key)
	go i.attemptLock(lockpath, didLock, failLock, releaseCh)

	// Wait for lock acquisition, failure, or shutdown
	select {
	case <-didLock:
		releaseCh <- false
	case err := <-failLock:
		return nil, err
	case <-stopCh:
		releaseCh <- true
		return nil, nil
	}

	// Watch for Events which could result in loss of our zkLock
	currentVal, _, lockeventCh, err := i.in.client.GetW(lockpath)
	if err != nil {
		i.abandonAcquiredLock()
		return nil, errwrap.Wrapf("unable to watch HA lock: {{err}}", err)
	}
	if i.value != string(currentVal) {
		i.abandonAcquiredLock()
		return nil, fmt.Errorf("lost HA lock immediately before watch")
	}

	// Create the leader channel and publish the held state.
	i.held = true
	i.leaderCh = make(chan struct{})
	i.stopCh = stopCh
	go i.monitorLock(lockeventCh, i.leaderCh)

	return i.leaderCh, nil
}

// abandonAcquiredLock releases a ZooKeeper lock that Lock obtained but could
// not turn into leadership. A release failure is only logged, because Lock is
// already returning its own error. It must be called with localLock held and
// after attemptLock has recorded i.zkLock.
func (i *ZooKeeperHALock) abandonAcquiredLock() {
	if err := i.zkLock.Unlock(); err != nil {
		i.logger.Warn("failed to release HA lock after setup failure", "error", err)
	}
}
557
558func (i *ZooKeeperHALock) attemptLock(lockpath string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
559	// Wait to acquire the lock in ZK
560	lock := zk.NewLock(i.in.client, lockpath, i.in.acl)
561	err := lock.Lock()
562	if err != nil {
563		failLock <- err
564		return
565	}
566	// Set node value
567	data := []byte(i.value)
568	err = i.in.ensurePath(lockpath, data)
569	if err != nil {
570		failLock <- err
571		lock.Unlock()
572		return
573	}
574	i.zkLock = lock
575
576	// Signal that lock is held
577	close(didLock)
578
579	// Handle an early abort
580	release := <-releaseCh
581	if release {
582		lock.Unlock()
583	}
584}
585
586func (i *ZooKeeperHALock) monitorLock(lockeventCh <-chan zk.Event, leaderCh chan struct{}) {
587	for {
588		select {
589		case event := <-lockeventCh:
590			// Lost connection?
591			switch event.State {
592			case zk.StateConnected:
593			case zk.StateHasSession:
594			default:
595				close(leaderCh)
596				return
597			}
598
599			// Lost lock?
600			switch event.Type {
601			case zk.EventNodeChildrenChanged:
602			case zk.EventSession:
603			default:
604				close(leaderCh)
605				return
606			}
607		}
608	}
609}
610
611func (i *ZooKeeperHALock) unlockInternal() error {
612	i.localLock.Lock()
613	defer i.localLock.Unlock()
614	if !i.held {
615		return nil
616	}
617
618	err := i.zkLock.Unlock()
619
620	if err == nil {
621		i.held = false
622		return nil
623	}
624
625	return err
626}
627
// Unlock releases the HA lock. If the first release fails, the error is
// logged and returned, and a background goroutine retries roughly once per
// second. The retries stop on success, after maxUnlockAttempts failed tries,
// or when the stop channel from the last successful Lock call is closed.
func (i *ZooKeeperHALock) Unlock() error {
	err := i.unlockInternal()
	if err == nil {
		return nil
	}
	i.logger.Error("failed to release distributed lock", "error", err)

	// Read stopCh under the mutex that Lock writes it under. While the lock is
	// still held, Lock cannot succeed, so the value stays valid for the retries.
	i.localLock.Lock()
	stopCh := i.stopCh
	i.localLock.Unlock()

	go func(i *ZooKeeperHALock) {
		const maxUnlockAttempts = 10
		i.logger.Info("launching automated distributed lock release")

		// attempts counts the failed tries; it is incremented for real here
		// instead of being shadowed inside the loop.
		for attempts := 1; ; attempts++ {
			retryErr := i.unlockInternal()
			if retryErr == nil {
				i.logger.Info("distributed lock released")
				return
			}

			select {
			case <-time.After(time.Second):
				if attempts >= maxUnlockAttempts {
					// Report the most recent failure, not the initial one.
					i.logger.Error("release lock max attempts reached. Lock may not be released", "error", retryErr)
					return
				}
			case <-stopCh:
				return
			}
		}
	}(i)

	return err
}
661
662func (i *ZooKeeperHALock) Value() (bool, string, error) {
663	lockpath := i.in.nodePath(i.key)
664	value, _, err := i.in.client.Get(lockpath)
665	return (value != nil), string(value), err
666}
667