// Package zk is a native Go client library for the ZooKeeper orchestration service.
package zk

/*
TODO:
* make sure a ping response comes back in a reasonable time

Possible watcher events:
* Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err}
*/

import (
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// ErrNoServer indicates that an operation cannot be completed
// because attempts to connect to all servers in the list failed.
var ErrNoServer = errors.New("zk: could not connect to a server")

// ErrInvalidPath indicates that an operation was attempted on
// an invalid path (e.g. an empty path).
var ErrInvalidPath = errors.New("zk: invalid path")

// DefaultLogger uses the stdlib log package for logging.
var DefaultLogger Logger = defaultLogger{}

const (
	bufferSize      = 1536 * 1024
	eventChanSize   = 6
	sendChanSize    = 16
	protectedPrefix = "_c_"
)
type watchType int

const (
	watchTypeData = iota
	watchTypeExist
	watchTypeChild
)

type watchPathType struct {
	path  string
	wType watchType
}

// Dialer is a function used to establish TCP connections to ZooKeeper servers.
// It mirrors net.DialTimeout, which is the default.
type Dialer func(network, address string, timeout time.Duration) (net.Conn, error)

// Logger is an interface that can be implemented to provide custom log output.
type Logger interface {
	Printf(string, ...interface{})
}
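
// Illustrative sketch (not part of the package API): any type with a matching
// Printf method satisfies Logger, for example a prefixing logger built on the
// standard library:
//
//	type prefixLogger struct{ prefix string }
//
//	func (p prefixLogger) Printf(format string, args ...interface{}) {
//		log.Printf(p.prefix+format, args...)
//	}
//
// It can then be installed with the WithLogger option or Conn.SetLogger.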

type authCreds struct {
	scheme string
	auth   []byte
}

// Conn is the client connection to a ZooKeeper ensemble and tracks all state
// related to that connection.
type Conn struct {
	lastZxid         int64
	sessionID        int64
	state            State // must be 32-bit aligned
	xid              uint32
	sessionTimeoutMs int32 // session timeout in milliseconds
	passwd           []byte

	dialer         Dialer
	hostProvider   HostProvider
	serverMu       sync.Mutex // protects server
	server         string     // remember the address/port of the current server
	conn           net.Conn
	eventChan      chan Event
	eventCallback  EventCallback // may be nil
	shouldQuit     chan struct{}
	pingInterval   time.Duration
	recvTimeout    time.Duration
	connectTimeout time.Duration
	maxBufferSize  int

	creds   []authCreds
	credsMu sync.Mutex // protects creds

	sendChan     chan *request
	requests     map[int32]*request // Xid -> pending request
	requestsLock sync.Mutex
	watchers     map[watchPathType][]chan Event
	watchersLock sync.Mutex
	closeChan    chan struct{} // channel to tell send loop to stop

	// Debug (used by unit tests)
	reconnectLatch   chan struct{}
	setWatchLimit    int
	setWatchCallback func([]*setWatchesRequest)
	// Debug (for recurring re-auth hang)
	debugCloseRecvLoop bool
	debugReauthDone    chan struct{}

	logger  Logger
	logInfo bool // true if information messages are logged; false if only errors are logged

	buf []byte
}

// connOption represents a connection option.
type connOption func(c *Conn)

type request struct {
	xid        int32
	opcode     int32
	pkt        interface{}
	recvStruct interface{}
	recvChan   chan response

	// Because sending and receiving happen in separate goroutines, there's a
	// possible race condition when creating watches from outside the read
	// loop. We must ensure that a watcher gets added to the list synchronously
	// with the response from the server on any request that creates a watch.
	// To avoid hard-coding the watch logic for each opcode in the recv loop,
	// the caller can use recvFunc to run some code synchronously after a
	// response is received.
	recvFunc func(*request, *responseHeader, error)
}

type response struct {
	zxid int64
	err  error
}

// Event is a ZooKeeper notification delivered on the connection's event
// channel and to registered watchers.
type Event struct {
	Type   EventType
	State  State
	Path   string // For non-session events, the path of the watched node.
	Err    error
	Server string // For connection events
}

// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java HostProvider interface:
// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
type HostProvider interface {
	// Init is called first, with the servers specified in the connection string.
	Init(servers []string) error
	// Len returns the number of servers.
	Len() int
	// Next returns the next server to connect to. retryStart will be true if we've looped through
	// all known servers without Connected() being called.
	Next() (server string, retryStart bool)
	// Connected notifies the HostProvider of a successful connection.
	Connected()
}
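
// A minimal sketch of a custom HostProvider (illustrative only; the connection
// uses DNSHostProvider unless WithHostProvider is given). It hands out servers
// round-robin and reports retryStart once a full pass has happened without a
// successful connection:
//
//	type staticHostProvider struct {
//		servers []string
//		next    int // index of the next server to hand out
//		fails   int // servers tried since the last successful connection
//	}
//
//	func (p *staticHostProvider) Init(servers []string) error {
//		if len(servers) == 0 {
//			return errors.New("zk: empty server list")
//		}
//		p.servers = servers
//		return nil
//	}
//
//	func (p *staticHostProvider) Len() int { return len(p.servers) }
//
//	func (p *staticHostProvider) Next() (string, bool) {
//		server := p.servers[p.next]
//		p.next = (p.next + 1) % len(p.servers)
//		p.fails++
//		return server, p.fails > len(p.servers)
//	}
//
//	func (p *staticHostProvider) Connected() { p.fails = 0 }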

// ConnectWithDialer establishes a new connection to a pool of zookeeper servers
// using a custom Dialer. See Connect for further information about session timeout.
//
// Deprecated: provided for compatibility; use Connect with the WithDialer option instead.
func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
	return Connect(servers, sessionTimeout, WithDialer(dialer))
}

// Connect establishes a new connection to a pool of zookeeper
// servers. The provided session timeout sets the amount of time for which
// a session is considered valid after losing connection to a server. Within
// the session timeout it's possible to reestablish a connection to a different
// server and keep the same session. This means any ephemeral nodes and
// watches are maintained.
func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
	if len(servers) == 0 {
		return nil, nil, errors.New("zk: server list must not be empty")
	}

	srvs := make([]string, len(servers))

	for i, addr := range servers {
		if strings.Contains(addr, ":") {
			srvs[i] = addr
		} else {
			srvs[i] = addr + ":" + strconv.Itoa(DefaultPort)
		}
	}

	// Randomize the order of the servers to avoid creating hotspots
	stringShuffle(srvs)

	ec := make(chan Event, eventChanSize)
	conn := &Conn{
		dialer:         net.DialTimeout,
		hostProvider:   &DNSHostProvider{},
		conn:           nil,
		state:          StateDisconnected,
		eventChan:      ec,
		shouldQuit:     make(chan struct{}),
		connectTimeout: 1 * time.Second,
		sendChan:       make(chan *request, sendChanSize),
		requests:       make(map[int32]*request),
		watchers:       make(map[watchPathType][]chan Event),
		passwd:         emptyPassword,
		logger:         DefaultLogger,
		logInfo:        true, // default is true for backwards compatibility
		buf:            make([]byte, bufferSize),
	}

	// Set provided options.
	for _, option := range options {
		option(conn)
	}

	if err := conn.hostProvider.Init(srvs); err != nil {
		return nil, nil, err
	}

	conn.setTimeouts(int32(sessionTimeout / time.Millisecond))

	go func() {
		conn.loop()
		conn.flushRequests(ErrClosing)
		conn.invalidateWatches(ErrClosing)
		close(conn.eventChan)
	}()
	return conn, ec, nil
}
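
// A minimal usage sketch (the address and timeout below are illustrative and
// assume a local server on the default port):
//
//	c, events, err := zk.Connect([]string{"127.0.0.1:2181"}, 10*time.Second)
//	if err != nil {
//		// handle error
//	}
//	defer c.Close()
//
//	// Drain session events so the buffered event channel never fills up.
//	go func() {
//		for ev := range events {
//			log.Printf("zk session event: %+v", ev)
//		}
//	}()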

// WithDialer returns a connection option specifying a non-default Dialer.
func WithDialer(dialer Dialer) connOption {
	return func(c *Conn) {
		c.dialer = dialer
	}
}
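
// For example, to bind outgoing connections to a specific local address
// (a sketch; the local IP is purely illustrative):
//
//	dialer := func(network, address string, timeout time.Duration) (net.Conn, error) {
//		d := net.Dialer{
//			Timeout:   timeout,
//			LocalAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.5")},
//		}
//		return d.Dial(network, address)
//	}
//	c, _, err := zk.Connect(servers, 10*time.Second, zk.WithDialer(dialer))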

// WithHostProvider returns a connection option specifying a non-default HostProvider.
func WithHostProvider(hostProvider HostProvider) connOption {
	return func(c *Conn) {
		c.hostProvider = hostProvider
	}
}

// WithLogger returns a connection option specifying a non-default Logger.
func WithLogger(logger Logger) connOption {
	return func(c *Conn) {
		c.logger = logger
	}
}

// WithLogInfo returns a connection option specifying whether or not information messages
// should be logged.
func WithLogInfo(logInfo bool) connOption {
	return func(c *Conn) {
		c.logInfo = logInfo
	}
}

// EventCallback is a function that is called when an Event occurs.
type EventCallback func(Event)

// WithEventCallback returns a connection option that specifies an event
// callback.
// The callback must not block - doing so would delay the ZK goroutines.
func WithEventCallback(cb EventCallback) connOption {
	return func(c *Conn) {
		c.eventCallback = cb
	}
}
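
// A sketch of a callback that stays non-blocking by handing events off to a
// buffered channel and dropping them when it is full (names are illustrative):
//
//	events := make(chan zk.Event, 128)
//	cb := func(ev zk.Event) {
//		select {
//		case events <- ev:
//		default: // never block the zk goroutines
//		}
//	}
//	c, _, err := zk.Connect(servers, 10*time.Second, zk.WithEventCallback(cb))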

// WithMaxBufferSize sets the maximum buffer size used to read and decode
// packets received from the Zookeeper server. The standard Zookeeper client for
// Java defaults to a limit of 1mb. For backwards compatibility, this Go client
// defaults to unbounded unless overridden via this option. A value that is zero
// or negative indicates that no limit is enforced.
//
// This is meant to prevent resource exhaustion in the face of potentially
// malicious data in ZK. It should generally match the server setting (which
// also defaults to 1mb) so that clients and servers agree on the limits for
// things like the size of data in an individual znode and the total size of a
// transaction.
//
// For production systems, this should be set to a reasonable value (ideally
// that matches the server configuration). For ops tooling, it is handy to use a
// much larger limit, in order to do things like clean up problematic state in
// the ZK tree. For example, if a single znode has a huge number of children, it
// is possible for the response to a "list children" operation to exceed this
// buffer size and cause errors in clients. The only way to subsequently clean
// up the tree (by removing superfluous children) is to use a client configured
// with a larger buffer size that can successfully query for all of the child
// names and then remove them. (Note there are other tools that can list all of
// the child names without an increased buffer size in the client, but they work
// by inspecting the servers' transaction logs to enumerate children instead of
// sending an online request to a server.)
func WithMaxBufferSize(maxBufferSize int) connOption {
	return func(c *Conn) {
		c.maxBufferSize = maxBufferSize
	}
}
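
// For example, to match the Java client and server default of 1mb (the value
// is illustrative; use whatever your ensemble is configured with):
//
//	c, _, err := zk.Connect(servers, 10*time.Second, zk.WithMaxBufferSize(1024*1024))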

// WithMaxConnBufferSize sets the maximum buffer size used to send and encode
// packets to the Zookeeper server. The standard Zookeeper client for Java defaults
// to a limit of 1mb. This option should be used for non-standard server setups
// where a znode is bigger than the default 1mb.
func WithMaxConnBufferSize(maxBufferSize int) connOption {
	return func(c *Conn) {
		c.buf = make([]byte, maxBufferSize)
	}
}

// Close submits a close request to ZooKeeper and signals the connection to
// stop sending and receiving packets.
func (c *Conn) Close() {
	close(c.shouldQuit)

	select {
	case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil):
	case <-time.After(time.Second):
	}
}

// State returns the current state of the connection.
func (c *Conn) State() State {
	return State(atomic.LoadInt32((*int32)(&c.state)))
}

// SessionID returns the current session id of the connection.
func (c *Conn) SessionID() int64 {
	return atomic.LoadInt64(&c.sessionID)
}

// SetLogger sets the logger to be used for printing errors.
// Logger is an interface provided by this package.
func (c *Conn) SetLogger(l Logger) {
	c.logger = l
}

func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
	c.sessionTimeoutMs = sessionTimeoutMs
	sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond
	c.recvTimeout = sessionTimeout * 2 / 3
	c.pingInterval = c.recvTimeout / 2
}
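
// As a worked example of the values derived above: with a negotiated session
// timeout of 30s, recvTimeout becomes 20s (two thirds of the session timeout)
// and pingInterval becomes 10s (half of recvTimeout), leaving headroom before
// the 30s session would expire.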

func (c *Conn) setState(state State) {
	atomic.StoreInt32((*int32)(&c.state), int32(state))
	c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()})
}

func (c *Conn) sendEvent(evt Event) {
	if c.eventCallback != nil {
		c.eventCallback(evt)
	}

	select {
	case c.eventChan <- evt:
	default:
		// panic("zk: event channel full - it must be monitored and never allowed to be full")
	}
}

func (c *Conn) connect() error {
	var retryStart bool
	for {
		c.serverMu.Lock()
		c.server, retryStart = c.hostProvider.Next()
		c.serverMu.Unlock()
		c.setState(StateConnecting)
		if retryStart {
			c.flushUnsentRequests(ErrNoServer)
			select {
			case <-time.After(time.Second):
				// pass
			case <-c.shouldQuit:
				c.setState(StateDisconnected)
				c.flushUnsentRequests(ErrClosing)
				return ErrClosing
			}
		}

		zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
		if err == nil {
			c.conn = zkConn
			c.setState(StateConnected)
			if c.logInfo {
				c.logger.Printf("Connected to %s", c.Server())
			}
			return nil
		}

		c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err)
	}
}

func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) {
	shouldCancel := func() bool {
		select {
		case <-c.shouldQuit:
			return true
		case <-c.closeChan:
			return true
		default:
			return false
		}
	}

	c.credsMu.Lock()
	defer c.credsMu.Unlock()

	defer close(reauthReadyChan)

	if c.logInfo {
		c.logger.Printf("Re-submitting `%d` credentials after reconnect",
			len(c.creds))
	}

	for _, cred := range c.creds {
		if shouldCancel() {
			c.logger.Printf("Cancel re-submitting credentials")
			return
		}
		resChan, err := c.sendRequest(
			opSetAuth,
			&setAuthRequest{Type: 0,
				Scheme: cred.scheme,
				Auth:   cred.auth,
			},
			&setAuthResponse{},
			nil)

		if err != nil {
			c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", err)
			// FIXME(prozlach): let's ignore errors for now
			continue
		}

		var res response
		select {
		case res = <-resChan:
		case <-c.closeChan:
			c.logger.Printf("Recv closed, cancel re-submitting credentials")
			return
		case <-c.shouldQuit:
			c.logger.Printf("Should quit, cancel re-submitting credentials")
			return
		}
		if res.err != nil {
			c.logger.Printf("Credential re-submit failed: %s", res.err)
			// FIXME(prozlach): let's ignore errors for now
			continue
		}
	}
}

func (c *Conn) sendRequest(
	opcode int32,
	req interface{},
	res interface{},
	recvFunc func(*request, *responseHeader, error),
) (
	<-chan response,
	error,
) {
	rq := &request{
		xid:        c.nextXid(),
		opcode:     opcode,
		pkt:        req,
		recvStruct: res,
		recvChan:   make(chan response, 1),
		recvFunc:   recvFunc,
	}

	if err := c.sendData(rq); err != nil {
		return nil, err
	}

	return rq.recvChan, nil
}

func (c *Conn) loop() {
	for {
		if err := c.connect(); err != nil {
			// c.Close() was called
			return
		}

		err := c.authenticate()
		switch {
		case err == ErrSessionExpired:
			c.logger.Printf("Authentication failed: %s", err)
			c.invalidateWatches(err)
		case err != nil && c.conn != nil:
			c.logger.Printf("Authentication failed: %s", err)
			c.conn.Close()
		case err == nil:
			if c.logInfo {
				c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
			}
			c.hostProvider.Connected()        // mark success
			c.closeChan = make(chan struct{}) // channel to tell send loop to stop
			reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted

			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				<-reauthChan
				if c.debugCloseRecvLoop {
					close(c.debugReauthDone)
				}
				err := c.sendLoop()
				if err != nil || c.logInfo {
					c.logger.Printf("Send loop terminated: err=%v", err)
				}
				c.conn.Close() // causes recv loop to EOF/exit
				wg.Done()
			}()

			wg.Add(1)
			go func() {
				var err error
				if c.debugCloseRecvLoop {
					err = errors.New("DEBUG: close recv loop")
				} else {
					err = c.recvLoop(c.conn)
				}
				if err != io.EOF || c.logInfo {
					c.logger.Printf("Recv loop terminated: err=%v", err)
				}
				if err == nil {
					panic("zk: recvLoop should never return nil error")
				}
				close(c.closeChan) // tell send loop to exit
				wg.Done()
			}()

			c.resendZkAuth(reauthChan)

			c.sendSetWatches()
			wg.Wait()
		}

		c.setState(StateDisconnected)

		select {
		case <-c.shouldQuit:
			c.flushRequests(ErrClosing)
			return
		default:
		}

		if err != ErrSessionExpired {
			err = ErrConnectionClosed
		}
		c.flushRequests(err)

		if c.reconnectLatch != nil {
			select {
			case <-c.shouldQuit:
				return
			case <-c.reconnectLatch:
			}
		}
	}
}

func (c *Conn) flushUnsentRequests(err error) {
	for {
		select {
		default:
			return
		case req := <-c.sendChan:
			req.recvChan <- response{-1, err}
		}
	}
}

// Send error to all pending requests and clear request map
func (c *Conn) flushRequests(err error) {
	c.requestsLock.Lock()
	for _, req := range c.requests {
		req.recvChan <- response{-1, err}
	}
	c.requests = make(map[int32]*request)
	c.requestsLock.Unlock()
}

// Send error to all watchers and clear watchers map
func (c *Conn) invalidateWatches(err error) {
	c.watchersLock.Lock()
	defer c.watchersLock.Unlock()

	if len(c.watchers) > 0 {
		for pathType, watchers := range c.watchers {
			ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err}
			for _, ch := range watchers {
				ch <- ev
				close(ch)
			}
		}
		c.watchers = make(map[watchPathType][]chan Event)
	}
}

func (c *Conn) sendSetWatches() {
	c.watchersLock.Lock()
	defer c.watchersLock.Unlock()

	if len(c.watchers) == 0 {
		return
	}

	// NB: A ZK server, by default, rejects packets >1mb. So, if we have too
	// many watches to reset, we need to break this up into multiple packets
	// to avoid hitting that limit. Mirroring the Java client behavior: we are
	// conservative in that we limit requests to 128kb (since the server limit
	// is actually configurable and could conceivably be configured smaller
	// than the default of 1mb).
	limit := 128 * 1024
	if c.setWatchLimit > 0 {
		limit = c.setWatchLimit
	}

	var reqs []*setWatchesRequest
	var req *setWatchesRequest
	var sizeSoFar int

	n := 0
	for pathType, watchers := range c.watchers {
		if len(watchers) == 0 {
			continue
		}
		addlLen := 4 + len(pathType.path)
		if req == nil || sizeSoFar+addlLen > limit {
			if req != nil {
				// add to set of requests that we'll send
				reqs = append(reqs, req)
			}
			sizeSoFar = 28 // fixed overhead of a set-watches packet
			req = &setWatchesRequest{
				RelativeZxid: c.lastZxid,
				DataWatches:  make([]string, 0),
				ExistWatches: make([]string, 0),
				ChildWatches: make([]string, 0),
			}
		}
		sizeSoFar += addlLen
		switch pathType.wType {
		case watchTypeData:
			req.DataWatches = append(req.DataWatches, pathType.path)
		case watchTypeExist:
			req.ExistWatches = append(req.ExistWatches, pathType.path)
		case watchTypeChild:
			req.ChildWatches = append(req.ChildWatches, pathType.path)
		}
		n++
	}
	if n == 0 {
		return
	}
	if req != nil { // don't forget any trailing packet we were building
		reqs = append(reqs, req)
	}

	if c.setWatchCallback != nil {
		c.setWatchCallback(reqs)
	}

	go func() {
		res := &setWatchesResponse{}
		// TODO: Pipeline these so we queue all of them up before waiting on any
		// response. That will require some investigation to make sure there
		// aren't failure modes where a blocking write to the channel of requests
		// could hang indefinitely and cause this goroutine to leak...
		for _, req := range reqs {
			_, err := c.request(opSetWatches, req, res, nil)
			if err != nil {
				c.logger.Printf("Failed to set previous watches: %s", err.Error())
				break
			}
		}
	}()
}

func (c *Conn) authenticate() error {
	buf := make([]byte, 256)

	// Encode and send a connect request.
	n, err := encodePacket(buf[4:], &connectRequest{
		ProtocolVersion: protocolVersion,
		LastZxidSeen:    c.lastZxid,
		TimeOut:         c.sessionTimeoutMs,
		SessionID:       c.SessionID(),
		Passwd:          c.passwd,
	})
	if err != nil {
		return err
	}

	binary.BigEndian.PutUint32(buf[:4], uint32(n))

	c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10))
	_, err = c.conn.Write(buf[:n+4])
	c.conn.SetWriteDeadline(time.Time{})
	if err != nil {
		return err
	}

	// Receive and decode a connect response.
	c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10))
	_, err = io.ReadFull(c.conn, buf[:4])
	c.conn.SetReadDeadline(time.Time{})
	if err != nil {
		return err
	}

	blen := int(binary.BigEndian.Uint32(buf[:4]))
	if cap(buf) < blen {
		buf = make([]byte, blen)
	}

	_, err = io.ReadFull(c.conn, buf[:blen])
	if err != nil {
		return err
	}

	r := connectResponse{}
	_, err = decodePacket(buf[:blen], &r)
	if err != nil {
		return err
	}
	if r.SessionID == 0 {
		atomic.StoreInt64(&c.sessionID, int64(0))
		c.passwd = emptyPassword
		c.lastZxid = 0
		c.setState(StateExpired)
		return ErrSessionExpired
	}

	atomic.StoreInt64(&c.sessionID, r.SessionID)
	c.setTimeouts(r.TimeOut)
	c.passwd = r.Passwd
	c.setState(StateHasSession)

	return nil
}

func (c *Conn) sendData(req *request) error {
	header := &requestHeader{req.xid, req.opcode}
	n, err := encodePacket(c.buf[4:], header)
	if err != nil {
		req.recvChan <- response{-1, err}
		return nil
	}

	n2, err := encodePacket(c.buf[4+n:], req.pkt)
	if err != nil {
		req.recvChan <- response{-1, err}
		return nil
	}

	n += n2

	binary.BigEndian.PutUint32(c.buf[:4], uint32(n))

	c.requestsLock.Lock()
	select {
	case <-c.closeChan:
		req.recvChan <- response{-1, ErrConnectionClosed}
		c.requestsLock.Unlock()
		return ErrConnectionClosed
	default:
	}
	c.requests[req.xid] = req
	c.requestsLock.Unlock()

	c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
	_, err = c.conn.Write(c.buf[:n+4])
	c.conn.SetWriteDeadline(time.Time{})
	if err != nil {
		req.recvChan <- response{-1, err}
		c.conn.Close()
		return err
	}

	return nil
}

func (c *Conn) sendLoop() error {
	pingTicker := time.NewTicker(c.pingInterval)
	defer pingTicker.Stop()

	for {
		select {
		case req := <-c.sendChan:
			if err := c.sendData(req); err != nil {
				return err
			}
		case <-pingTicker.C:
			n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
			if err != nil {
				panic("zk: opPing should never fail to serialize")
			}

			binary.BigEndian.PutUint32(c.buf[:4], uint32(n))

			c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
			_, err = c.conn.Write(c.buf[:n+4])
			c.conn.SetWriteDeadline(time.Time{})
			if err != nil {
				c.conn.Close()
				return err
			}
		case <-c.closeChan:
			return nil
		}
	}
}

func (c *Conn) recvLoop(conn net.Conn) error {
	sz := bufferSize
	if c.maxBufferSize > 0 && sz > c.maxBufferSize {
		sz = c.maxBufferSize
	}
	buf := make([]byte, sz)
	for {
		// packet length
		conn.SetReadDeadline(time.Now().Add(c.recvTimeout))
		_, err := io.ReadFull(conn, buf[:4])
		if err != nil {
			return err
		}

		blen := int(binary.BigEndian.Uint32(buf[:4]))
		if cap(buf) < blen {
			if c.maxBufferSize > 0 && blen > c.maxBufferSize {
				return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize)
			}
			buf = make([]byte, blen)
		}

		_, err = io.ReadFull(conn, buf[:blen])
		conn.SetReadDeadline(time.Time{})
		if err != nil {
			return err
		}

		res := responseHeader{}
		_, err = decodePacket(buf[:16], &res)
		if err != nil {
			return err
		}

		if res.Xid == -1 {
			res := &watcherEvent{}
			_, err := decodePacket(buf[16:blen], res)
			if err != nil {
				return err
			}
			ev := Event{
				Type:  res.Type,
				State: res.State,
				Path:  res.Path,
				Err:   nil,
			}
			c.sendEvent(ev)
			wTypes := make([]watchType, 0, 2)
			switch res.Type {
			case EventNodeCreated:
				wTypes = append(wTypes, watchTypeExist)
			case EventNodeDeleted, EventNodeDataChanged:
				wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild)
			case EventNodeChildrenChanged:
				wTypes = append(wTypes, watchTypeChild)
			}
			c.watchersLock.Lock()
			for _, t := range wTypes {
				wpt := watchPathType{res.Path, t}
				if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 {
					for _, ch := range watchers {
						ch <- ev
						close(ch)
					}
					delete(c.watchers, wpt)
				}
			}
			c.watchersLock.Unlock()
		} else if res.Xid == -2 {
			// Ping response. Ignore.
		} else if res.Xid < 0 {
			c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid)
		} else {
			if res.Zxid > 0 {
				c.lastZxid = res.Zxid
			}

			c.requestsLock.Lock()
			req, ok := c.requests[res.Xid]
			if ok {
				delete(c.requests, res.Xid)
			}
			c.requestsLock.Unlock()

			if !ok {
				c.logger.Printf("Response for unknown request with xid %d", res.Xid)
			} else {
				if res.Err != 0 {
					err = res.Err.toError()
				} else {
					_, err = decodePacket(buf[16:blen], req.recvStruct)
				}
				if req.recvFunc != nil {
					req.recvFunc(req, &res, err)
				}
				req.recvChan <- response{res.Zxid, err}
				if req.opcode == opClose {
					return io.EOF
				}
			}
		}
	}
}

func (c *Conn) nextXid() int32 {
	return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff)
}

func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event {
	c.watchersLock.Lock()
	defer c.watchersLock.Unlock()

	ch := make(chan Event, 1)
	wpt := watchPathType{path, watchType}
	c.watchers[wpt] = append(c.watchers[wpt], ch)
	return ch
}

func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response {
	rq := &request{
		xid:        c.nextXid(),
		opcode:     opcode,
		pkt:        req,
		recvStruct: res,
		recvChan:   make(chan response, 1),
		recvFunc:   recvFunc,
	}
	c.sendChan <- rq
	return rq.recvChan
}

func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) {
	r := <-c.queueRequest(opcode, req, res, recvFunc)
	return r.zxid, r.err
}

// AddAuth adds an authentication credential to the connection. The credential
// is also remembered and re-submitted automatically after a reconnect.
func (c *Conn) AddAuth(scheme string, auth []byte) error {
	_, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)

	if err != nil {
		return err
	}

	// Remember authdata so that it can be re-submitted on reconnect
	//
	// FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
	// as two different entries, which will be re-submitted on reconnect. Some
	// research is needed on how ZK treats these cases and
	// then maybe switch to something like "map[username] = password" to allow
	// only a single password for a given user, with users being unique.
	obj := authCreds{
		scheme: scheme,
		auth:   auth,
	}

	c.credsMu.Lock()
	c.creds = append(c.creds, obj)
	c.credsMu.Unlock()

	return nil
}
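
// A typical use is the "digest" scheme, where auth is "user:password"
// (the credentials below are illustrative):
//
//	if err := c.AddAuth("digest", []byte("user:secret")); err != nil {
//		// handle error
//	}
//
// Nodes can then be protected so that only this identity may access them,
// e.g. by creating them with a digest ACL (DigestACL, defined elsewhere in
// this package).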

// Children returns the names of the children of the znode at the given path,
// along with the znode's Stat.
func (c *Conn) Children(path string) ([]string, *Stat, error) {
	if err := validatePath(path, false); err != nil {
		return nil, nil, err
	}

	res := &getChildren2Response{}
	_, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil)
	return res.Children, &res.Stat, err
}

// ChildrenW is like Children but also sets a child watch on the znode and
// returns a channel that receives a single Event when the watch fires.
func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) {
	if err := validatePath(path, false); err != nil {
		return nil, nil, nil, err
	}

	var ech <-chan Event
	res := &getChildren2Response{}
	_, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
		if err == nil {
			ech = c.addWatcher(path, watchTypeChild)
		}
	})
	if err != nil {
		return nil, nil, nil, err
	}
	return res.Children, &res.Stat, ech, err
}
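
// Watches set by the *W calls are one-shot: the returned channel receives a
// single Event and is then closed, so a typical consumer re-arms the watch in
// a loop (sketch; the path is illustrative and error handling is elided):
//
//	for {
//		children, _, ch, err := c.ChildrenW("/myapp/workers")
//		if err != nil {
//			break
//		}
//		log.Printf("current workers: %v", children)
//		<-ch // wait for a change (or for the watch to be invalidated)
//	}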

// Get returns the data stored in the znode at the given path, along with its Stat.
func (c *Conn) Get(path string) ([]byte, *Stat, error) {
	if err := validatePath(path, false); err != nil {
		return nil, nil, err
	}

	res := &getDataResponse{}
	_, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil)
	return res.Data, &res.Stat, err
}

// GetW returns the contents of a znode and sets a data watch, returning a
// channel that receives a single Event when the watch fires.
func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) {
	if err := validatePath(path, false); err != nil {
		return nil, nil, nil, err
	}

	var ech <-chan Event
	res := &getDataResponse{}
	_, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
		if err == nil {
			ech = c.addWatcher(path, watchTypeData)
		}
	})
	if err != nil {
		return nil, nil, nil, err
	}
	return res.Data, &res.Stat, ech, err
}

// Set writes data to the znode at the given path if the given version matches
// the znode's current version (use -1 to skip the version check).
func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {
	if err := validatePath(path, false); err != nil {
		return nil, err
	}

	res := &setDataResponse{}
	_, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil)
	return &res.Stat, err
}

// Create creates a znode at the given path with the given data, flags, and ACL,
// and returns the path of the created node (which may differ from the requested
// path when FlagSequence is set).
func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) {
	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
		return "", err
	}

	res := &createResponse{}
	_, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
	return res.Path, err
}
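
// A sketch of creating a permanent node and an ephemeral-sequential child
// (paths and data are illustrative; WorldACL and PermAll are defined elsewhere
// in this package):
//
//	_, err := c.Create("/myapp", []byte("config"), 0, zk.WorldACL(zk.PermAll))
//	name, err := c.Create("/myapp/lock-", nil, zk.FlagEphemeral|zk.FlagSequence, zk.WorldACL(zk.PermAll))
//	// name includes the server-assigned sequence suffix, e.g. "/myapp/lock-0000000001"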

// CreateProtectedEphemeralSequential fixes a race condition if the server crashes
// after it creates the node. On reconnect the session may still be valid so the
// ephemeral node still exists. Therefore, on reconnect we need to check if a node
// with a GUID generated on create exists.
func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) {
	if err := validatePath(path, true); err != nil {
		return "", err
	}

	var guid [16]byte
	_, err := io.ReadFull(rand.Reader, guid[:16])
	if err != nil {
		return "", err
	}
	guidStr := fmt.Sprintf("%x", guid)

	parts := strings.Split(path, "/")
	parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1])
	rootPath := strings.Join(parts[:len(parts)-1], "/")
	protectedPath := strings.Join(parts, "/")

	var newPath string
	for i := 0; i < 3; i++ {
		newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl)
		switch err {
		case ErrSessionExpired:
			// No need to search for the node since it can't exist. Just try again.
		case ErrConnectionClosed:
			children, _, err := c.Children(rootPath)
			if err != nil {
				return "", err
			}
			for _, p := range children {
				parts := strings.Split(p, "/")
				if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) {
					if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr {
						return rootPath + "/" + p, nil
					}
				}
			}
		case nil:
			return newPath, nil
		default:
			return "", err
		}
	}
	return "", err
}

// Delete removes the znode at the given path if the given version matches the
// znode's current version (use -1 to skip the version check).
func (c *Conn) Delete(path string, version int32) error {
	if err := validatePath(path, false); err != nil {
		return err
	}

	_, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil)
	return err
}

// Exists reports whether a znode exists at the given path and, if it does,
// returns its Stat.
func (c *Conn) Exists(path string) (bool, *Stat, error) {
	if err := validatePath(path, false); err != nil {
		return false, nil, err
	}

	res := &existsResponse{}
	_, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
	exists := true
	if err == ErrNoNode {
		exists = false
		err = nil
	}
	return exists, &res.Stat, err
}

// ExistsW is like Exists but also sets a watch: a data watch if the node
// exists, or an existence watch if it does not.
func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
	if err := validatePath(path, false); err != nil {
		return false, nil, nil, err
	}

	var ech <-chan Event
	res := &existsResponse{}
	_, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
		if err == nil {
			ech = c.addWatcher(path, watchTypeData)
		} else if err == ErrNoNode {
			ech = c.addWatcher(path, watchTypeExist)
		}
	})
	exists := true
	if err == ErrNoNode {
		exists = false
		err = nil
	}
	if err != nil {
		return false, nil, nil, err
	}
	return exists, &res.Stat, ech, err
}

// GetACL returns the ACL and Stat of the znode at the given path.
func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
	if err := validatePath(path, false); err != nil {
		return nil, nil, err
	}

	res := &getAclResponse{}
	_, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
	return res.Acl, &res.Stat, err
}

// SetACL sets the ACL of the znode at the given path if the given version
// matches the znode's current ACL version.
func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
	if err := validatePath(path, false); err != nil {
		return nil, err
	}

	res := &setAclResponse{}
	_, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
	return &res.Stat, err
}

// Sync asks the connected server to sync up with the leader, so that
// subsequent reads through this connection reflect writes committed before
// the call.
func (c *Conn) Sync(path string) (string, error) {
	if err := validatePath(path, false); err != nil {
		return "", err
	}

	res := &syncResponse{}
	_, err := c.request(opSync, &syncRequest{Path: path}, res, nil)
	return res.Path, err
}

// MultiResponse is the result of a single operation within a Multi call.
type MultiResponse struct {
	Stat   *Stat
	String string
	Error  error
}

// Multi executes multiple ZooKeeper operations as a single atomic transaction:
// either all of them succeed or none of them are applied. The provided ops
// must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or
// *CheckVersionRequest.
func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
	req := &multiRequest{
		Ops:        make([]multiRequestOp, 0, len(ops)),
		DoneHeader: multiHeader{Type: -1, Done: true, Err: -1},
	}
	for _, op := range ops {
		var opCode int32
		switch op.(type) {
		case *CreateRequest:
			opCode = opCreate
		case *SetDataRequest:
			opCode = opSetData
		case *DeleteRequest:
			opCode = opDelete
		case *CheckVersionRequest:
			opCode = opCheck
		default:
			return nil, fmt.Errorf("unknown operation type %T", op)
		}
		req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op})
	}
	res := &multiResponse{}
	_, err := c.request(opMulti, req, res, nil)
	mr := make([]MultiResponse, len(res.Ops))
	for i, op := range res.Ops {
		mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()}
	}
	return mr, err
}
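
// A sketch of an atomic create-and-update transaction (paths, data, and ACL
// are illustrative):
//
//	ops := []interface{}{
//		&zk.CreateRequest{Path: "/myapp/node", Data: []byte("a"), Acl: zk.WorldACL(zk.PermAll)},
//		&zk.SetDataRequest{Path: "/myapp/other", Data: []byte("b"), Version: -1},
//	}
//	results, err := c.Multi(ops...)
//	// If any op fails, none are applied; per-op outcomes are in results.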

// Server returns the current or last-connected server name.
func (c *Conn) Server() string {
	c.serverMu.Lock()
	defer c.serverMu.Unlock()
	return c.server
}