1// Package zk is a native Go client library for the ZooKeeper orchestration service.
2package zk
3
4/*
5TODO:
6* make sure a ping response comes back in a reasonable time
7
8Possible watcher events:
9* Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err}
10*/
11
12import (
13	"context"
14	"crypto/rand"
15	"encoding/binary"
16	"errors"
17	"fmt"
18	"io"
19	"net"
20	"strings"
21	"sync"
22	"sync/atomic"
23	"time"
24)
25
var (
	// ErrNoServer indicates that an operation cannot be completed because
	// attempts to connect to every server in the list failed.
	ErrNoServer = errors.New("zk: could not connect to a server")

	// ErrInvalidPath indicates that an operation was attempted on an invalid
	// path (for example, an empty path).
	ErrInvalidPath = errors.New("zk: invalid path")
)
33
34// DefaultLogger uses the stdlib log package for logging.
35var DefaultLogger Logger = defaultLogger{}
36
const (
	// bufferSize (1.5 MiB) is the default size of the packet encode/decode buffers.
	bufferSize = 1536 << 10
	// eventChanSize is the buffer depth of the event channel returned by Connect.
	eventChanSize = 6
	// sendChanSize is the buffer depth of the outgoing request queue.
	sendChanSize = 16
	// protectedPrefix marks node names generated by CreateProtectedEphemeralSequential.
	protectedPrefix = "_c_"
)
43
// watchType identifies which kind of server-side watch a local watcher
// channel was registered for.
type watchType int

// Watch kinds tracked per path. They are declared with type watchType so
// they cannot be silently mixed with unrelated integers.
const (
	watchTypeData  watchType = iota // registered by GetW
	watchTypeExist                  // NOTE(review): presumably registered by an exists-with-watch call outside this chunk
	watchTypeChild                  // registered by ChildrenW
)

// watchPathType is the key of Conn.watchers: a watched path together with
// the kind of watch set on it.
type watchPathType struct {
	path  string
	wType watchType
}
56
// Dialer opens a network connection to a ZooKeeper server. Its signature
// matches net.DialTimeout, which Connect uses by default; supply a custom
// one with WithDialer.
type Dialer func(network, address string, timeout time.Duration) (net.Conn, error)

// Logger is an interface that can be implemented to provide custom log output.
type Logger interface {
	Printf(string, ...interface{})
}
63
// authCreds is one credential passed to AddAuth (an auth scheme plus the raw
// auth bytes). Conn keeps these so they can be re-submitted on reconnect.
type authCreds struct {
	scheme string
	auth   []byte
}
68
// Conn is a client connection to a ZooKeeper ensemble, created by Connect.
//
// Notes on selected fields:
//   - lastZxid is the highest zxid seen in a reply (updated by recvLoop). It
//     is sent in the connect request and as RelativeZxid when watches are
//     re-registered, and it is reset to 0 when the session expires.
//   - sessionID and state are read and written with sync/atomic (see
//     SessionID and State). NOTE(review): lastZxid and sessionID sit first,
//     presumably to keep them 8-byte aligned for 64-bit atomics on 32-bit
//     platforms; keep them first.
//   - xid is incremented atomically by nextXid to produce request ids.
//   - passwd is the session password returned in the connect response.
//   - requests maps outstanding xids to their requests (guarded by
//     requestsLock); watchers maps (path, watch kind) to registered watcher
//     channels (guarded by watchersLock).
//   - buf is the scratch buffer used to encode outgoing packets (requests
//     and pings), so its size bounds the largest request that can be sent.
type Conn struct {
	lastZxid         int64
	sessionID        int64
	state            State // must be 32-bit aligned
	xid              uint32
	sessionTimeoutMs int32 // session timeout in milliseconds
	passwd           []byte

	dialer         Dialer
	hostProvider   HostProvider
	serverMu       sync.Mutex // protects server
	server         string     // remember the address/port of the current server
	conn           net.Conn
	eventChan      chan Event
	eventCallback  EventCallback // may be nil
	shouldQuit     chan struct{}
	shouldQuitOnce sync.Once
	pingInterval   time.Duration
	recvTimeout    time.Duration
	connectTimeout time.Duration
	maxBufferSize  int

	creds   []authCreds
	credsMu sync.Mutex // protects creds

	sendChan     chan *request
	requests     map[int32]*request // Xid -> pending request
	requestsLock sync.Mutex
	watchers     map[watchPathType][]chan Event
	watchersLock sync.Mutex
	closeChan    chan struct{} // channel to tell send loop stop

	// Debug (used by unit tests)
	reconnectLatch   chan struct{}
	setWatchLimit    int
	setWatchCallback func([]*setWatchesRequest)

	// Debug (for recurring re-auth hang)
	debugCloseRecvLoop bool
	resendZkAuthFn     func(context.Context, *Conn) error

	logger  Logger
	logInfo bool // true if information messages are logged; false if only errors are logged

	buf []byte
}
115
// connOption represents a connection option passed to Connect. Each option
// mutates the new Conn before the host provider is initialized and before the
// connection goroutine is started.
type connOption func(c *Conn)
118
// request is one client operation travelling through the send loop.
//
// xid identifies the reply, opcode selects the operation, pkt is the payload
// encoded after the request header, recvStruct is where the reply body is
// decoded, and recvChan receives the outcome.
type request struct {
	xid        int32
	opcode     int32
	pkt        interface{}
	recvStruct interface{}
	recvChan   chan response

	// Because sending and receiving happen in separate go routines, there's
	// a possible race condition when creating watches from outside the read
	// loop. We must ensure that a watcher gets added to the list synchronously
	// with the response from the server on any request that creates a watch.
	// In order to not hard code the watch logic for each opcode in the recv
	// loop, the caller can use recvFunc to run code synchronously in the
	// receive loop right after a response arrives, before recvChan is
	// signalled.
	recvFunc func(*request, *responseHeader, error)
}
135
// response is the outcome delivered on request.recvChan: the zxid from the
// reply header (-1 when the request failed locally) and the error, if any.
type response struct {
	zxid int64
	err  error
}
140
// Event describes a session state change (Type EventSession), a watch
// notification received from the server, or the invalidation of a watch
// (Type EventNotWatching, with Err set).
type Event struct {
	Type   EventType
	State  State
	Path   string // For non-session events, the path of the watched node.
	Err    error
	Server string // For connection events
}
148
// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java equivalent:
// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
//
// Conn calls Init once from Connect, Next before every dial attempt, and
// Connected after each successful session handshake.
type HostProvider interface {
	// Init is called first, with the servers specified in the connection string.
	Init(servers []string) error
	// Len returns the number of servers.
	Len() int
	// Next returns the next server to connect to. retryStart will be true if we've looped through
	// all known servers without Connected() being called.
	Next() (server string, retryStart bool)
	// Notify the HostProvider of a successful connection.
	Connected()
}
163
164// ConnectWithDialer establishes a new connection to a pool of zookeeper servers
165// using a custom Dialer. See Connect for further information about session timeout.
166// This method is deprecated and provided for compatibility: use the WithDialer option instead.
167func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
168	return Connect(servers, sessionTimeout, WithDialer(dialer))
169}
170
// Connect establishes a new connection to a pool of zookeeper
// servers. The provided session timeout sets the amount of time for which
// a session is considered valid after losing connection to a server. Within
// the session timeout it's possible to reestablish a connection to a different
// server and keep the same session. This means any ephemeral nodes and
// watches are maintained.
//
// Connect returns immediately: dialing and authentication happen on a
// background goroutine. Session state changes are reported on the returned
// event channel, which is closed once that goroutine stops (after Close).
// An error is returned only for an empty server list or when the host
// provider rejects the servers.
func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
	if len(servers) == 0 {
		return nil, nil, errors.New("zk: server list must not be empty")
	}

	srvs := FormatServers(servers)

	// Randomize the order of the servers to avoid creating hotspots
	stringShuffle(srvs)

	ec := make(chan Event, eventChanSize)
	conn := &Conn{
		dialer:         net.DialTimeout,
		hostProvider:   &DNSHostProvider{},
		conn:           nil,
		state:          StateDisconnected,
		eventChan:      ec,
		shouldQuit:     make(chan struct{}),
		connectTimeout: 1 * time.Second,
		sendChan:       make(chan *request, sendChanSize),
		requests:       make(map[int32]*request),
		watchers:       make(map[watchPathType][]chan Event),
		passwd:         emptyPassword,
		logger:         DefaultLogger,
		logInfo:        true, // default is true for backwards compatibility
		buf:            make([]byte, bufferSize),
		resendZkAuthFn: resendZkAuth,
	}

	// Set provided options.
	for _, option := range options {
		option(conn)
	}

	if err := conn.hostProvider.Init(srvs); err != nil {
		return nil, nil, err
	}

	conn.setTimeouts(int32(sessionTimeout / time.Millisecond))
	// TODO: This context should be passed in by the caller to be the connection lifecycle context.
	ctx := context.Background()

	// Run the connection loop; once it exits (Close was called), fail every
	// outstanding request and watch, then close the event channel.
	go func() {
		conn.loop(ctx)
		conn.flushRequests(ErrClosing)
		conn.invalidateWatches(ErrClosing)
		close(conn.eventChan)
	}()
	return conn, ec, nil
}
227
228// WithDialer returns a connection option specifying a non-default Dialer.
229func WithDialer(dialer Dialer) connOption {
230	return func(c *Conn) {
231		c.dialer = dialer
232	}
233}
234
235// WithHostProvider returns a connection option specifying a non-default HostProvider.
236func WithHostProvider(hostProvider HostProvider) connOption {
237	return func(c *Conn) {
238		c.hostProvider = hostProvider
239	}
240}
241
242// WithLogger returns a connection option specifying a non-default Logger
243func WithLogger(logger Logger) connOption {
244	return func(c *Conn) {
245		c.logger = logger
246	}
247}
248
249// WithLogInfo returns a connection option specifying whether or not information messages
250// shoud be logged.
251func WithLogInfo(logInfo bool) connOption {
252	return func(c *Conn) {
253		c.logInfo = logInfo
254	}
255}
256
// EventCallback is a function that is called when an Event occurs. It is
// invoked synchronously on the Conn's internal goroutines, before the event is
// offered to the event channel, so it must not block.
type EventCallback func(Event)
259
260// WithEventCallback returns a connection option that specifies an event
261// callback.
262// The callback must not block - doing so would delay the ZK go routines.
263func WithEventCallback(cb EventCallback) connOption {
264	return func(c *Conn) {
265		c.eventCallback = cb
266	}
267}
268
269// WithMaxBufferSize sets the maximum buffer size used to read and decode
270// packets received from the Zookeeper server. The standard Zookeeper client for
271// Java defaults to a limit of 1mb. For backwards compatibility, this Go client
272// defaults to unbounded unless overridden via this option. A value that is zero
273// or negative indicates that no limit is enforced.
274//
275// This is meant to prevent resource exhaustion in the face of potentially
276// malicious data in ZK. It should generally match the server setting (which
277// also defaults ot 1mb) so that clients and servers agree on the limits for
278// things like the size of data in an individual znode and the total size of a
279// transaction.
280//
281// For production systems, this should be set to a reasonable value (ideally
282// that matches the server configuration). For ops tooling, it is handy to use a
283// much larger limit, in order to do things like clean-up problematic state in
284// the ZK tree. For example, if a single znode has a huge number of children, it
285// is possible for the response to a "list children" operation to exceed this
286// buffer size and cause errors in clients. The only way to subsequently clean
287// up the tree (by removing superfluous children) is to use a client configured
288// with a larger buffer size that can successfully query for all of the child
289// names and then remove them. (Note there are other tools that can list all of
290// the child names without an increased buffer size in the client, but they work
291// by inspecting the servers' transaction logs to enumerate children instead of
292// sending an online request to a server.
293func WithMaxBufferSize(maxBufferSize int) connOption {
294	return func(c *Conn) {
295		c.maxBufferSize = maxBufferSize
296	}
297}
298
299// WithMaxConnBufferSize sets maximum buffer size used to send and encode
300// packets to Zookeeper server. The standard Zookeepeer client for java defaults
301// to a limit of 1mb. This option should be used for non-standard server setup
302// where znode is bigger than default 1mb.
303func WithMaxConnBufferSize(maxBufferSize int) connOption {
304	return func(c *Conn) {
305		c.buf = make([]byte, maxBufferSize)
306	}
307}
308
309// Close will submit a close request with ZK and signal the connection to stop
310// sending and receiving packets.
311func (c *Conn) Close() {
312	c.shouldQuitOnce.Do(func() {
313		close(c.shouldQuit)
314
315		select {
316		case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil):
317		case <-time.After(time.Second):
318		}
319	})
320}
321
322// State returns the current state of the connection.
323func (c *Conn) State() State {
324	return State(atomic.LoadInt32((*int32)(&c.state)))
325}
326
327// SessionID returns the current session id of the connection.
328func (c *Conn) SessionID() int64 {
329	return atomic.LoadInt64(&c.sessionID)
330}
331
332// SetLogger sets the logger to be used for printing errors.
333// Logger is an interface provided by this package.
334func (c *Conn) SetLogger(l Logger) {
335	c.logger = l
336}
337
338func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
339	c.sessionTimeoutMs = sessionTimeoutMs
340	sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond
341	c.recvTimeout = sessionTimeout * 2 / 3
342	c.pingInterval = c.recvTimeout / 2
343}
344
345func (c *Conn) setState(state State) {
346	atomic.StoreInt32((*int32)(&c.state), int32(state))
347	c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()})
348}
349
350func (c *Conn) sendEvent(evt Event) {
351	if c.eventCallback != nil {
352		c.eventCallback(evt)
353	}
354
355	select {
356	case c.eventChan <- evt:
357	default:
358		// panic("zk: event channel full - it must be monitored and never allowed to be full")
359	}
360}
361
362func (c *Conn) connect() error {
363	var retryStart bool
364	for {
365		c.serverMu.Lock()
366		c.server, retryStart = c.hostProvider.Next()
367		c.serverMu.Unlock()
368
369		c.setState(StateConnecting)
370
371		if retryStart {
372			c.flushUnsentRequests(ErrNoServer)
373			select {
374			case <-time.After(time.Second):
375				// pass
376			case <-c.shouldQuit:
377				c.setState(StateDisconnected)
378				c.flushUnsentRequests(ErrClosing)
379				return ErrClosing
380			}
381		}
382
383		zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
384		if err == nil {
385			c.conn = zkConn
386			c.setState(StateConnected)
387			if c.logInfo {
388				c.logger.Printf("connected to %s", c.Server())
389			}
390			return nil
391		}
392
393		c.logger.Printf("failed to connect to %s: %v", c.Server(), err)
394	}
395}
396
397func (c *Conn) sendRequest(
398	opcode int32,
399	req interface{},
400	res interface{},
401	recvFunc func(*request, *responseHeader, error),
402) (
403	<-chan response,
404	error,
405) {
406	rq := &request{
407		xid:        c.nextXid(),
408		opcode:     opcode,
409		pkt:        req,
410		recvStruct: res,
411		recvChan:   make(chan response, 1),
412		recvFunc:   recvFunc,
413	}
414
415	if err := c.sendData(rq); err != nil {
416		return nil, err
417	}
418
419	return rq.recvChan, nil
420}
421
// loop drives the connection lifecycle; Connect runs it on its own goroutine.
// Each iteration dials a server (connect) and performs the session handshake
// (authenticate). On success it runs the send and receive loops until both
// have stopped. It then marks the connection disconnected, fails all
// in-flight requests and starts over.
//
// It returns when connect reports that Close was called, or when shouldQuit
// is observed after a disconnect. ctx is only passed through to
// resendZkAuthFn.
func (c *Conn) loop(ctx context.Context) {
	for {
		if err := c.connect(); err != nil {
			// c.Close() was called
			return
		}

		err := c.authenticate()
		switch {
		case err == ErrSessionExpired:
			// The session is gone, so every watch registered under it is void.
			// NOTE(review): c.conn is not closed on this path before the next
			// connect() overwrites it; confirm this does not leak the socket.
			c.logger.Printf("authentication failed: %s", err)
			c.invalidateWatches(err)
		case err != nil && c.conn != nil:
			c.logger.Printf("authentication failed: %s", err)
			c.conn.Close()
		case err == nil:
			if c.logInfo {
				c.logger.Printf("authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
			}
			c.hostProvider.Connected()        // mark success
			c.closeChan = make(chan struct{}) // channel to tell send loop stop

			var wg sync.WaitGroup

			// Send side: re-submit credentials recorded by AddAuth, then run the
			// send loop. Closing the connection on exit makes the receive side
			// fail and exit as well.
			wg.Add(1)
			go func() {
				defer c.conn.Close() // causes recv loop to EOF/exit
				defer wg.Done()

				if err := c.resendZkAuthFn(ctx, c); err != nil {
					c.logger.Printf("error in resending auth creds: %v", err)
					return
				}

				if err := c.sendLoop(); err != nil || c.logInfo {
					c.logger.Printf("send loop terminated: %v", err)
				}
			}()

			// Receive side: closing closeChan on exit stops the send loop.
			// NOTE(review): defers run LIFO, so wg.Done fires before
			// close(c.closeChan); wg.Wait below may return just before the
			// channel is closed.
			wg.Add(1)
			go func() {
				defer close(c.closeChan) // tell send loop to exit
				defer wg.Done()

				var err error
				if c.debugCloseRecvLoop {
					err = errors.New("DEBUG: close recv loop")
				} else {
					err = c.recvLoop(c.conn)
				}
				if err != io.EOF || c.logInfo {
					c.logger.Printf("recv loop terminated: %v", err)
				}
				if err == nil {
					panic("zk: recvLoop should never return nil error")
				}
			}()

			// Re-register existing watches with the server, then wait for both
			// loops to finish.
			c.sendSetWatches()
			wg.Wait()
		}

		c.setState(StateDisconnected)

		select {
		case <-c.shouldQuit:
			c.flushRequests(ErrClosing)
			return
		default:
		}

		// In-flight requests fail with ErrSessionExpired if the handshake said
		// so, otherwise with ErrConnectionClosed.
		if err != ErrSessionExpired {
			err = ErrConnectionClosed
		}
		c.flushRequests(err)

		// Test hook: block the reconnect until the latch fires (or Close).
		if c.reconnectLatch != nil {
			select {
			case <-c.shouldQuit:
				return
			case <-c.reconnectLatch:
			}
		}
	}
}
507
508func (c *Conn) flushUnsentRequests(err error) {
509	for {
510		select {
511		default:
512			return
513		case req := <-c.sendChan:
514			req.recvChan <- response{-1, err}
515		}
516	}
517}
518
519// Send error to all pending requests and clear request map
520func (c *Conn) flushRequests(err error) {
521	c.requestsLock.Lock()
522	for _, req := range c.requests {
523		req.recvChan <- response{-1, err}
524	}
525	c.requests = make(map[int32]*request)
526	c.requestsLock.Unlock()
527}
528
529// Send error to all watchers and clear watchers map
530func (c *Conn) invalidateWatches(err error) {
531	c.watchersLock.Lock()
532	defer c.watchersLock.Unlock()
533
534	if len(c.watchers) >= 0 {
535		for pathType, watchers := range c.watchers {
536			ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err}
537			for _, ch := range watchers {
538				ch <- ev
539				close(ch)
540			}
541		}
542		c.watchers = make(map[watchPathType][]chan Event)
543	}
544}
545
// sendSetWatches re-registers every locally known watch with the server after
// a (re)connect. Watched paths are batched into one or more
// setWatchesRequests, all using lastZxid as RelativeZxid. Each batch's
// estimated size is kept under a limit (128 KiB, or setWatchLimit when a test
// sets it). The batches are sent one at a time from a separate goroutine,
// which stops and logs at the first failure. Returns immediately when no
// watches exist.
func (c *Conn) sendSetWatches() {
	c.watchersLock.Lock()
	defer c.watchersLock.Unlock()

	if len(c.watchers) == 0 {
		return
	}

	// NB: A ZK server, by default, rejects packets >1mb. So, if we have too
	// many watches to reset, we need to break this up into multiple packets
	// to avoid hitting that limit. Mirroring the Java client behavior: we are
	// conservative in that we limit requests to 128kb (since server limit is
	// is actually configurable and could conceivably be configured smaller
	// than default of 1mb).
	limit := 128 * 1024
	if c.setWatchLimit > 0 {
		limit = c.setWatchLimit
	}

	var reqs []*setWatchesRequest
	var req *setWatchesRequest
	var sizeSoFar int

	// n counts the watches actually batched (entries with no channels are skipped).
	n := 0
	for pathType, watchers := range c.watchers {
		if len(watchers) == 0 {
			continue
		}
		// Presumably a 4-byte length prefix plus the path bytes per entry.
		addlLen := 4 + len(pathType.path)
		if req == nil || sizeSoFar+addlLen > limit {
			if req != nil {
				// add to set of requests that we'll send
				reqs = append(reqs, req)
			}
			sizeSoFar = 28 // fixed overhead of a set-watches packet
			req = &setWatchesRequest{
				RelativeZxid: c.lastZxid,
				DataWatches:  make([]string, 0),
				ExistWatches: make([]string, 0),
				ChildWatches: make([]string, 0),
			}
		}
		sizeSoFar += addlLen
		switch pathType.wType {
		case watchTypeData:
			req.DataWatches = append(req.DataWatches, pathType.path)
		case watchTypeExist:
			req.ExistWatches = append(req.ExistWatches, pathType.path)
		case watchTypeChild:
			req.ChildWatches = append(req.ChildWatches, pathType.path)
		}
		n++
	}
	if n == 0 {
		return
	}
	if req != nil { // don't forget any trailing packet we were building
		reqs = append(reqs, req)
	}

	// Test hook: lets unit tests inspect how the watches were batched.
	if c.setWatchCallback != nil {
		c.setWatchCallback(reqs)
	}

	// Sent asynchronously: this runs while watchersLock is held, and the
	// requests go through the normal send path.
	go func() {
		res := &setWatchesResponse{}
		// TODO: Pipeline these so queue all of them up before waiting on any
		// response. That will require some investigation to make sure there
		// aren't failure modes where a blocking write to the channel of requests
		// could hang indefinitely and cause this goroutine to leak...
		for _, req := range reqs {
			_, err := c.request(opSetWatches, req, res, nil)
			if err != nil {
				c.logger.Printf("Failed to set previous watches: %v", err)
				break
			}
		}
	}()
}
625
626func (c *Conn) authenticate() error {
627	buf := make([]byte, 256)
628
629	// Encode and send a connect request.
630	n, err := encodePacket(buf[4:], &connectRequest{
631		ProtocolVersion: protocolVersion,
632		LastZxidSeen:    c.lastZxid,
633		TimeOut:         c.sessionTimeoutMs,
634		SessionID:       c.SessionID(),
635		Passwd:          c.passwd,
636	})
637	if err != nil {
638		return err
639	}
640
641	binary.BigEndian.PutUint32(buf[:4], uint32(n))
642
643	c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10))
644	_, err = c.conn.Write(buf[:n+4])
645	c.conn.SetWriteDeadline(time.Time{})
646	if err != nil {
647		return err
648	}
649
650	// Receive and decode a connect response.
651	c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10))
652	_, err = io.ReadFull(c.conn, buf[:4])
653	c.conn.SetReadDeadline(time.Time{})
654	if err != nil {
655		return err
656	}
657
658	blen := int(binary.BigEndian.Uint32(buf[:4]))
659	if cap(buf) < blen {
660		buf = make([]byte, blen)
661	}
662
663	_, err = io.ReadFull(c.conn, buf[:blen])
664	if err != nil {
665		return err
666	}
667
668	r := connectResponse{}
669	_, err = decodePacket(buf[:blen], &r)
670	if err != nil {
671		return err
672	}
673	if r.SessionID == 0 {
674		atomic.StoreInt64(&c.sessionID, int64(0))
675		c.passwd = emptyPassword
676		c.lastZxid = 0
677		c.setState(StateExpired)
678		return ErrSessionExpired
679	}
680
681	atomic.StoreInt64(&c.sessionID, r.SessionID)
682	c.setTimeouts(r.TimeOut)
683	c.passwd = r.Passwd
684	c.setState(StateHasSession)
685
686	return nil
687}
688
689func (c *Conn) sendData(req *request) error {
690	header := &requestHeader{req.xid, req.opcode}
691	n, err := encodePacket(c.buf[4:], header)
692	if err != nil {
693		req.recvChan <- response{-1, err}
694		return nil
695	}
696
697	n2, err := encodePacket(c.buf[4+n:], req.pkt)
698	if err != nil {
699		req.recvChan <- response{-1, err}
700		return nil
701	}
702
703	n += n2
704
705	binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
706
707	c.requestsLock.Lock()
708	select {
709	case <-c.closeChan:
710		req.recvChan <- response{-1, ErrConnectionClosed}
711		c.requestsLock.Unlock()
712		return ErrConnectionClosed
713	default:
714	}
715	c.requests[req.xid] = req
716	c.requestsLock.Unlock()
717
718	c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
719	_, err = c.conn.Write(c.buf[:n+4])
720	c.conn.SetWriteDeadline(time.Time{})
721	if err != nil {
722		req.recvChan <- response{-1, err}
723		c.conn.Close()
724		return err
725	}
726
727	return nil
728}
729
730func (c *Conn) sendLoop() error {
731	pingTicker := time.NewTicker(c.pingInterval)
732	defer pingTicker.Stop()
733
734	for {
735		select {
736		case req := <-c.sendChan:
737			if err := c.sendData(req); err != nil {
738				return err
739			}
740		case <-pingTicker.C:
741			n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
742			if err != nil {
743				panic("zk: opPing should never fail to serialize")
744			}
745
746			binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
747
748			c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
749			_, err = c.conn.Write(c.buf[:n+4])
750			c.conn.SetWriteDeadline(time.Time{})
751			if err != nil {
752				c.conn.Close()
753				return err
754			}
755		case <-c.closeChan:
756			return nil
757		}
758	}
759}
760
// recvLoop reads length-prefixed packets from conn until an error occurs.
// Each packet is dispatched by the xid in its 16-byte reply header:
//   - xid -1: a watch notification. It is published via sendEvent, then
//     delivered to (and removes) the matching registered watchers.
//   - xid -2: a ping reply; ignored.
//   - other negative xids: logged and ignored.
//   - otherwise: the reply to the pending request with that xid. The server
//     error is converted, or the body is decoded into recvStruct. recvFunc
//     runs synchronously, then the result is sent on recvChan.
//
// Reading each packet (prefix and body) is bounded by recvTimeout. The
// receive buffer starts at bufferSize, capped at maxBufferSize when that is
// positive, and grows on demand. A packet larger than a positive
// maxBufferSize ends the loop with an error.
//
// It never returns nil. After delivering the reply to an opClose request it
// returns io.EOF; otherwise it returns the read/decode error that stopped it.
func (c *Conn) recvLoop(conn net.Conn) error {
	sz := bufferSize
	if c.maxBufferSize > 0 && sz > c.maxBufferSize {
		sz = c.maxBufferSize
	}
	buf := make([]byte, sz)
	for {
		// package length
		if err := conn.SetReadDeadline(time.Now().Add(c.recvTimeout)); err != nil {
			c.logger.Printf("failed to set connection deadline: %v", err)
		}
		_, err := io.ReadFull(conn, buf[:4])
		if err != nil {
			return fmt.Errorf("failed to read from connection: %v", err)
		}

		blen := int(binary.BigEndian.Uint32(buf[:4]))
		if cap(buf) < blen {
			if c.maxBufferSize > 0 && blen > c.maxBufferSize {
				return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize)
			}
			buf = make([]byte, blen)
		}

		_, err = io.ReadFull(conn, buf[:blen])
		conn.SetReadDeadline(time.Time{})
		if err != nil {
			return err
		}

		// NOTE(review): blen is not checked against the 16-byte header size; a
		// shorter packet would decode stale bytes left from a previous packet.
		res := responseHeader{}
		_, err = decodePacket(buf[:16], &res)
		if err != nil {
			return err
		}

		if res.Xid == -1 {
			res := &watcherEvent{}
			_, err := decodePacket(buf[16:blen], res)
			if err != nil {
				return err
			}
			ev := Event{
				Type:  res.Type,
				State: res.State,
				Path:  res.Path,
				Err:   nil,
			}
			c.sendEvent(ev)
			// Watch kinds fired by this event type. Each fired watcher receives
			// the event once, its channel is closed, and it is unregistered.
			// NOTE(review): child watches are also fired on EventNodeDataChanged;
			// confirm this mapping is intended.
			wTypes := make([]watchType, 0, 2)
			switch res.Type {
			case EventNodeCreated:
				wTypes = append(wTypes, watchTypeExist)
			case EventNodeDeleted, EventNodeDataChanged:
				wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild)
			case EventNodeChildrenChanged:
				wTypes = append(wTypes, watchTypeChild)
			}
			c.watchersLock.Lock()
			for _, t := range wTypes {
				wpt := watchPathType{res.Path, t}
				if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 {
					for _, ch := range watchers {
						ch <- ev
						close(ch)
					}
					delete(c.watchers, wpt)
				}
			}
			c.watchersLock.Unlock()
		} else if res.Xid == -2 {
			// Ping response. Ignore.
		} else if res.Xid < 0 {
			c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid)
		} else {
			if res.Zxid > 0 {
				c.lastZxid = res.Zxid
			}

			c.requestsLock.Lock()
			req, ok := c.requests[res.Xid]
			if ok {
				delete(c.requests, res.Xid)
			}
			c.requestsLock.Unlock()

			if !ok {
				c.logger.Printf("Response for unknown request with xid %d", res.Xid)
			} else {
				if res.Err != 0 {
					err = res.Err.toError()
				} else {
					_, err = decodePacket(buf[16:blen], req.recvStruct)
				}
				// recvFunc runs before the caller is woken so that, e.g., a
				// watcher is registered before the caller sees the reply.
				if req.recvFunc != nil {
					req.recvFunc(req, &res, err)
				}
				req.recvChan <- response{res.Zxid, err}
				if req.opcode == opClose {
					return io.EOF
				}
			}
		}
	}
}
866
867func (c *Conn) nextXid() int32 {
868	return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff)
869}
870
871func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event {
872	c.watchersLock.Lock()
873	defer c.watchersLock.Unlock()
874
875	ch := make(chan Event, 1)
876	wpt := watchPathType{path, watchType}
877	c.watchers[wpt] = append(c.watchers[wpt], ch)
878	return ch
879}
880
// queueRequest builds a request with a fresh xid, hands it to the send loop
// via sendChan, and returns the channel on which its response (or an error)
// will be delivered. The channel has two slots. Presumably this is because in
// the tie case below both queueRequest and a later flush or reply may each
// write one response.
func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response {
	rq := &request{
		xid:        c.nextXid(),
		opcode:     opcode,
		pkt:        req,
		recvStruct: res,
		recvChan:   make(chan response, 2),
		recvFunc:   recvFunc,
	}

	switch opcode {
	case opClose:
		// always attempt to send close ops.
		// (Close closes shouldQuit before queueing opClose, so shouldQuit
		// cannot be consulted here; the enqueue is bounded by a timeout.)
		select {
		case c.sendChan <- rq:
		case <-time.After(c.connectTimeout * 2):
			c.logger.Printf("gave up trying to send opClose to server")
			rq.recvChan <- response{-1, ErrConnectionClosed}
		}
	default:
		// otherwise avoid deadlocks for dumb clients who aren't aware that
		// the ZK connection is closed yet.
		select {
		case <-c.shouldQuit:
			rq.recvChan <- response{-1, ErrConnectionClosed}
		case c.sendChan <- rq:
			// check for a tie
			select {
			case <-c.shouldQuit:
				// maybe the caller gets this, maybe not- we tried.
				rq.recvChan <- response{-1, ErrConnectionClosed}
			default:
			}
		}
	}
	return rq.recvChan
}
918
919func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) {
920	r := <-c.queueRequest(opcode, req, res, recvFunc)
921	select {
922	case <-c.shouldQuit:
923		// queueRequest() can be racy, double-check for the race here and avoid
924		// a potential data-race. otherwise the client of this func may try to
925		// access `res` fields concurrently w/ the async response processor.
926		// NOTE: callers of this func should check for (at least) ErrConnectionClosed
927		// and avoid accessing fields of the response object if such error is present.
928		return -1, ErrConnectionClosed
929	default:
930		return r.zxid, r.err
931	}
932}
933
934func (c *Conn) AddAuth(scheme string, auth []byte) error {
935	_, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
936
937	if err != nil {
938		return err
939	}
940
941	// Remember authdata so that it can be re-submitted on reconnect
942	//
943	// FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
944	// as two different entries, which will be re-submitted on reconnet. Some
945	// research is needed on how ZK treats these cases and
946	// then maybe switch to something like "map[username] = password" to allow
947	// only single password for given user with users being unique.
948	obj := authCreds{
949		scheme: scheme,
950		auth:   auth,
951	}
952
953	c.credsMu.Lock()
954	c.creds = append(c.creds, obj)
955	c.credsMu.Unlock()
956
957	return nil
958}
959
960func (c *Conn) Children(path string) ([]string, *Stat, error) {
961	if err := validatePath(path, false); err != nil {
962		return nil, nil, err
963	}
964
965	res := &getChildren2Response{}
966	_, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil)
967	if err == ErrConnectionClosed {
968		return nil, nil, err
969	}
970	return res.Children, &res.Stat, err
971}
972
973func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) {
974	if err := validatePath(path, false); err != nil {
975		return nil, nil, nil, err
976	}
977
978	var ech <-chan Event
979	res := &getChildren2Response{}
980	_, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
981		if err == nil {
982			ech = c.addWatcher(path, watchTypeChild)
983		}
984	})
985	if err != nil {
986		return nil, nil, nil, err
987	}
988	return res.Children, &res.Stat, ech, err
989}
990
991func (c *Conn) Get(path string) ([]byte, *Stat, error) {
992	if err := validatePath(path, false); err != nil {
993		return nil, nil, err
994	}
995
996	res := &getDataResponse{}
997	_, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil)
998	if err == ErrConnectionClosed {
999		return nil, nil, err
1000	}
1001	return res.Data, &res.Stat, err
1002}
1003
1004// GetW returns the contents of a znode and sets a watch
1005func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) {
1006	if err := validatePath(path, false); err != nil {
1007		return nil, nil, nil, err
1008	}
1009
1010	var ech <-chan Event
1011	res := &getDataResponse{}
1012	_, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
1013		if err == nil {
1014			ech = c.addWatcher(path, watchTypeData)
1015		}
1016	})
1017	if err != nil {
1018		return nil, nil, nil, err
1019	}
1020	return res.Data, &res.Stat, ech, err
1021}
1022
1023func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {
1024	if err := validatePath(path, false); err != nil {
1025		return nil, err
1026	}
1027
1028	res := &setDataResponse{}
1029	_, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil)
1030	if err == ErrConnectionClosed {
1031		return nil, err
1032	}
1033	return &res.Stat, err
1034}
1035
1036func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) {
1037	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
1038		return "", err
1039	}
1040
1041	res := &createResponse{}
1042	_, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
1043	if err == ErrConnectionClosed {
1044		return "", err
1045	}
1046	return res.Path, err
1047}
1048
1049func (c *Conn) CreateContainer(path string, data []byte, flags int32, acl []ACL) (string, error) {
1050	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
1051		return "", err
1052	}
1053	if flags&FlagTTL != FlagTTL {
1054		return "", ErrInvalidFlags
1055	}
1056
1057	res := &createResponse{}
1058	_, err := c.request(opCreateContainer, &CreateContainerRequest{path, data, acl, flags}, res, nil)
1059	return res.Path, err
1060}
1061
1062func (c *Conn) CreateTTL(path string, data []byte, flags int32, acl []ACL, ttl time.Duration) (string, error) {
1063	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
1064		return "", err
1065	}
1066	if flags&FlagTTL != FlagTTL {
1067		return "", ErrInvalidFlags
1068	}
1069
1070	res := &createResponse{}
1071	_, err := c.request(opCreateTTL, &CreateTTLRequest{path, data, acl, flags, ttl.Milliseconds()}, res, nil)
1072	return res.Path, err
1073}
1074
1075// CreateProtectedEphemeralSequential fixes a race condition if the server crashes
1076// after it creates the node. On reconnect the session may still be valid so the
1077// ephemeral node still exists. Therefore, on reconnect we need to check if a node
1078// with a GUID generated on create exists.
1079func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) {
1080	if err := validatePath(path, true); err != nil {
1081		return "", err
1082	}
1083
1084	var guid [16]byte
1085	_, err := io.ReadFull(rand.Reader, guid[:16])
1086	if err != nil {
1087		return "", err
1088	}
1089	guidStr := fmt.Sprintf("%x", guid)
1090
1091	parts := strings.Split(path, "/")
1092	parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1])
1093	rootPath := strings.Join(parts[:len(parts)-1], "/")
1094	protectedPath := strings.Join(parts, "/")
1095
1096	var newPath string
1097	for i := 0; i < 3; i++ {
1098		newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl)
1099		switch err {
1100		case ErrSessionExpired:
1101			// No need to search for the node since it can't exist. Just try again.
1102		case ErrConnectionClosed:
1103			children, _, err := c.Children(rootPath)
1104			if err != nil {
1105				return "", err
1106			}
1107			for _, p := range children {
1108				parts := strings.Split(p, "/")
1109				if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) {
1110					if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr {
1111						return rootPath + "/" + p, nil
1112					}
1113				}
1114			}
1115		case nil:
1116			return newPath, nil
1117		default:
1118			return "", err
1119		}
1120	}
1121	return "", err
1122}
1123
1124func (c *Conn) Delete(path string, version int32) error {
1125	if err := validatePath(path, false); err != nil {
1126		return err
1127	}
1128
1129	_, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil)
1130	return err
1131}
1132
1133func (c *Conn) Exists(path string) (bool, *Stat, error) {
1134	if err := validatePath(path, false); err != nil {
1135		return false, nil, err
1136	}
1137
1138	res := &existsResponse{}
1139	_, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
1140	if err == ErrConnectionClosed {
1141		return false, nil, err
1142	}
1143	exists := true
1144	if err == ErrNoNode {
1145		exists = false
1146		err = nil
1147	}
1148	return exists, &res.Stat, err
1149}
1150
1151func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
1152	if err := validatePath(path, false); err != nil {
1153		return false, nil, nil, err
1154	}
1155
1156	var ech <-chan Event
1157	res := &existsResponse{}
1158	_, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
1159		if err == nil {
1160			ech = c.addWatcher(path, watchTypeData)
1161		} else if err == ErrNoNode {
1162			ech = c.addWatcher(path, watchTypeExist)
1163		}
1164	})
1165	exists := true
1166	if err == ErrNoNode {
1167		exists = false
1168		err = nil
1169	}
1170	if err != nil {
1171		return false, nil, nil, err
1172	}
1173	return exists, &res.Stat, ech, err
1174}
1175
1176func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
1177	if err := validatePath(path, false); err != nil {
1178		return nil, nil, err
1179	}
1180
1181	res := &getAclResponse{}
1182	_, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
1183	if err == ErrConnectionClosed {
1184		return nil, nil, err
1185	}
1186	return res.Acl, &res.Stat, err
1187}
1188func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
1189	if err := validatePath(path, false); err != nil {
1190		return nil, err
1191	}
1192
1193	res := &setAclResponse{}
1194	_, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
1195	if err == ErrConnectionClosed {
1196		return nil, err
1197	}
1198	return &res.Stat, err
1199}
1200
1201func (c *Conn) Sync(path string) (string, error) {
1202	if err := validatePath(path, false); err != nil {
1203		return "", err
1204	}
1205
1206	res := &syncResponse{}
1207	_, err := c.request(opSync, &syncRequest{Path: path}, res, nil)
1208	if err == ErrConnectionClosed {
1209		return "", err
1210	}
1211	return res.Path, err
1212}
1213
// MultiResponse is the per-operation result returned by Conn.Multi, one entry
// per operation in the server's reply.
type MultiResponse struct {
	// Stat is the Stat reported for the operation, if any.
	Stat *Stat
	// String is the string result of the operation (presumably the created
	// path for create operations — confirm against the response decoder).
	String string
	// Error is the operation's error code converted to an error value.
	Error error
}
1219
1220// Multi executes multiple ZooKeeper operations or none of them. The provided
1221// ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or
1222// *CheckVersionRequest.
1223func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
1224	req := &multiRequest{
1225		Ops:        make([]multiRequestOp, 0, len(ops)),
1226		DoneHeader: multiHeader{Type: -1, Done: true, Err: -1},
1227	}
1228	for _, op := range ops {
1229		var opCode int32
1230		switch op.(type) {
1231		case *CreateRequest:
1232			opCode = opCreate
1233		case *SetDataRequest:
1234			opCode = opSetData
1235		case *DeleteRequest:
1236			opCode = opDelete
1237		case *CheckVersionRequest:
1238			opCode = opCheck
1239		default:
1240			return nil, fmt.Errorf("unknown operation type %T", op)
1241		}
1242		req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op})
1243	}
1244	res := &multiResponse{}
1245	_, err := c.request(opMulti, req, res, nil)
1246	if err == ErrConnectionClosed {
1247		return nil, err
1248	}
1249	mr := make([]MultiResponse, len(res.Ops))
1250	for i, op := range res.Ops {
1251		mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()}
1252	}
1253	return mr, err
1254}
1255
1256// IncrementalReconfig is the zookeeper reconfiguration api that allows adding and removing servers
1257// by lists of members. For more info refer to the ZK documentation.
1258//
1259// An optional version allows for conditional reconfigurations, -1 ignores the condition.
1260//
1261// Returns the new configuration znode stat.
1262func (c *Conn) IncrementalReconfig(joining, leaving []string, version int64) (*Stat, error) {
1263	// TODO: validate the shape of the member string to give early feedback.
1264	request := &reconfigRequest{
1265		JoiningServers: []byte(strings.Join(joining, ",")),
1266		LeavingServers: []byte(strings.Join(leaving, ",")),
1267		CurConfigId:    version,
1268	}
1269
1270	return c.internalReconfig(request)
1271}
1272
1273// Reconfig is the non-incremental update functionality for Zookeeper where the list provided
1274// is the entire new member list. For more info refer to the ZK documentation.
1275//
1276// An optional version allows for conditional reconfigurations, -1 ignores the condition.
1277//
1278// Returns the new configuration znode stat.
1279func (c *Conn) Reconfig(members []string, version int64) (*Stat, error) {
1280	request := &reconfigRequest{
1281		NewMembers:  []byte(strings.Join(members, ",")),
1282		CurConfigId: version,
1283	}
1284
1285	return c.internalReconfig(request)
1286}
1287
1288func (c *Conn) internalReconfig(request *reconfigRequest) (*Stat, error) {
1289	response := &reconfigReponse{}
1290	_, err := c.request(opReconfig, request, response, nil)
1291	return &response.Stat, err
1292}
1293
1294// Server returns the current or last-connected server name.
1295func (c *Conn) Server() string {
1296	c.serverMu.Lock()
1297	defer c.serverMu.Unlock()
1298	return c.server
1299}
1300
1301func resendZkAuth(ctx context.Context, c *Conn) error {
1302	shouldCancel := func() bool {
1303		select {
1304		case <-c.shouldQuit:
1305			return true
1306		case <-c.closeChan:
1307			return true
1308		default:
1309			return false
1310		}
1311	}
1312
1313	c.credsMu.Lock()
1314	defer c.credsMu.Unlock()
1315
1316	if c.logInfo {
1317		c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds))
1318	}
1319
1320	for _, cred := range c.creds {
1321		// return early before attempting to send request.
1322		if shouldCancel() {
1323			return nil
1324		}
1325		// do not use the public API for auth since it depends on the send/recv loops
1326		// that are waiting for this to return
1327		resChan, err := c.sendRequest(
1328			opSetAuth,
1329			&setAuthRequest{Type: 0,
1330				Scheme: cred.scheme,
1331				Auth:   cred.auth,
1332			},
1333			&setAuthResponse{},
1334			nil, /* recvFunc*/
1335		)
1336		if err != nil {
1337			return fmt.Errorf("failed to send auth request: %v", err)
1338		}
1339
1340		var res response
1341		select {
1342		case res = <-resChan:
1343		case <-c.closeChan:
1344			c.logger.Printf("recv closed, cancel re-submitting credentials")
1345			return nil
1346		case <-c.shouldQuit:
1347			c.logger.Printf("should quit, cancel re-submitting credentials")
1348			return nil
1349		case <-ctx.Done():
1350			return ctx.Err()
1351		}
1352		if res.err != nil {
1353			return fmt.Errorf("failed conneciton setAuth request: %v", res.err)
1354		}
1355	}
1356
1357	return nil
1358}
1359