// Package zk is a native Go client library for the ZooKeeper orchestration service.
package zk

/*
TODO:
* make sure a ping response comes back in a reasonable time

Possible watcher events:
* Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err}
*/

import (
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// ErrNoServer indicates that an operation cannot be completed
// because attempts to connect to all servers in the list failed.
var ErrNoServer = errors.New("zk: could not connect to a server")

// ErrInvalidPath indicates that an operation was being attempted on
// an invalid path (e.g. an empty path).
var ErrInvalidPath = errors.New("zk: invalid path")

// DefaultLogger uses the stdlib log package for logging.
var DefaultLogger Logger = defaultLogger{}

const (
	// bufferSize is the default size in bytes of the packet buffers: the
	// receive buffer allocated by recvLoop and the send buffer allocated by
	// Connect.
	bufferSize = 1536 * 1024
	// eventChanSize is the capacity of the event channel returned by Connect.
	eventChanSize = 6
	// sendChanSize is the capacity of the queue of requests waiting for the
	// send loop.
	sendChanSize = 16
	// protectedPrefix starts the node names generated by
	// CreateProtectedEphemeralSequential.
	protectedPrefix = "_c_"
)

// watchType identifies the kind of watch (data, exist or child) that a
// registered watcher channel belongs to.
type watchType int

const (
	watchTypeData = iota
	watchTypeExist
	watchTypeChild
)

// watchPathType is the key of the Conn watcher registry: a watched path
// together with the kind of watch set on it.
type watchPathType struct {
	path  string
	wType watchType
}

// Dialer opens the network connection to a server. Its signature matches
// net.DialTimeout, which Connect uses unless WithDialer overrides it.
type Dialer func(network, address string, timeout time.Duration) (net.Conn, error)

// Logger is an interface that can be implemented to provide custom log output.
60type Logger interface { 61 Printf(string, ...interface{}) 62} 63 64type authCreds struct { 65 scheme string 66 auth []byte 67} 68 69type Conn struct { 70 lastZxid int64 71 sessionID int64 72 state State // must be 32-bit aligned 73 xid uint32 74 sessionTimeoutMs int32 // session timeout in milliseconds 75 passwd []byte 76 77 dialer Dialer 78 hostProvider HostProvider 79 serverMu sync.Mutex // protects server 80 server string // remember the address/port of the current server 81 conn net.Conn 82 eventChan chan Event 83 eventCallback EventCallback // may be nil 84 shouldQuit chan struct{} 85 pingInterval time.Duration 86 recvTimeout time.Duration 87 connectTimeout time.Duration 88 maxBufferSize int 89 90 creds []authCreds 91 credsMu sync.Mutex // protects server 92 93 sendChan chan *request 94 requests map[int32]*request // Xid -> pending request 95 requestsLock sync.Mutex 96 watchers map[watchPathType][]chan Event 97 watchersLock sync.Mutex 98 closeChan chan struct{} // channel to tell send loop stop 99 100 // Debug (used by unit tests) 101 reconnectLatch chan struct{} 102 setWatchLimit int 103 setWatchCallback func([]*setWatchesRequest) 104 // Debug (for recurring re-auth hang) 105 debugCloseRecvLoop bool 106 debugReauthDone chan struct{} 107 108 logger Logger 109 logInfo bool // true if information messages are logged; false if only errors are logged 110 111 buf []byte 112} 113 114// connOption represents a connection option. 115type connOption func(c *Conn) 116 117type request struct { 118 xid int32 119 opcode int32 120 pkt interface{} 121 recvStruct interface{} 122 recvChan chan response 123 124 // Because sending and receiving happen in separate go routines, there's 125 // a possible race condition when creating watches from outside the read 126 // loop. We must ensure that a watcher gets added to the list synchronously 127 // with the response from the server on any request that creates a watch. 
128 // In order to not hard code the watch logic for each opcode in the recv 129 // loop the caller can use recvFunc to insert some synchronously code 130 // after a response. 131 recvFunc func(*request, *responseHeader, error) 132} 133 134type response struct { 135 zxid int64 136 err error 137} 138 139type Event struct { 140 Type EventType 141 State State 142 Path string // For non-session events, the path of the watched node. 143 Err error 144 Server string // For connection events 145} 146 147// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to. 148// It is an analog of the Java equivalent: 149// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup 150type HostProvider interface { 151 // Init is called first, with the servers specified in the connection string. 152 Init(servers []string) error 153 // Len returns the number of servers. 154 Len() int 155 // Next returns the next server to connect to. retryStart will be true if we've looped through 156 // all known servers without Connected() being called. 157 Next() (server string, retryStart bool) 158 // Notify the HostProvider of a successful connection. 159 Connected() 160} 161 162// ConnectWithDialer establishes a new connection to a pool of zookeeper servers 163// using a custom Dialer. See Connect for further information about session timeout. 164// This method is deprecated and provided for compatibility: use the WithDialer option instead. 165func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) { 166 return Connect(servers, sessionTimeout, WithDialer(dialer)) 167} 168 169// Connect establishes a new connection to a pool of zookeeper 170// servers. The provided session timeout sets the amount of time for which 171// a session is considered valid after losing connection to a server. 
Within 172// the session timeout it's possible to reestablish a connection to a different 173// server and keep the same session. This is means any ephemeral nodes and 174// watches are maintained. 175func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { 176 if len(servers) == 0 { 177 return nil, nil, errors.New("zk: server list must not be empty") 178 } 179 180 srvs := make([]string, len(servers)) 181 182 for i, addr := range servers { 183 if strings.Contains(addr, ":") { 184 srvs[i] = addr 185 } else { 186 srvs[i] = addr + ":" + strconv.Itoa(DefaultPort) 187 } 188 } 189 190 // Randomize the order of the servers to avoid creating hotspots 191 stringShuffle(srvs) 192 193 ec := make(chan Event, eventChanSize) 194 conn := &Conn{ 195 dialer: net.DialTimeout, 196 hostProvider: &DNSHostProvider{}, 197 conn: nil, 198 state: StateDisconnected, 199 eventChan: ec, 200 shouldQuit: make(chan struct{}), 201 connectTimeout: 1 * time.Second, 202 sendChan: make(chan *request, sendChanSize), 203 requests: make(map[int32]*request), 204 watchers: make(map[watchPathType][]chan Event), 205 passwd: emptyPassword, 206 logger: DefaultLogger, 207 logInfo: true, // default is true for backwards compatability 208 buf: make([]byte, bufferSize), 209 } 210 211 // Set provided options. 212 for _, option := range options { 213 option(conn) 214 } 215 216 if err := conn.hostProvider.Init(srvs); err != nil { 217 return nil, nil, err 218 } 219 220 conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) 221 222 go func() { 223 conn.loop() 224 conn.flushRequests(ErrClosing) 225 conn.invalidateWatches(ErrClosing) 226 close(conn.eventChan) 227 }() 228 return conn, ec, nil 229} 230 231// WithDialer returns a connection option specifying a non-default Dialer. 
232func WithDialer(dialer Dialer) connOption { 233 return func(c *Conn) { 234 c.dialer = dialer 235 } 236} 237 238// WithHostProvider returns a connection option specifying a non-default HostProvider. 239func WithHostProvider(hostProvider HostProvider) connOption { 240 return func(c *Conn) { 241 c.hostProvider = hostProvider 242 } 243} 244 245// WithLogger returns a connection option specifying a non-default Logger 246func WithLogger(logger Logger) connOption { 247 return func(c *Conn) { 248 c.logger = logger 249 } 250} 251 252// WithLogInfo returns a connection option specifying whether or not information messages 253// shoud be logged. 254func WithLogInfo(logInfo bool) connOption { 255 return func(c *Conn) { 256 c.logInfo = logInfo 257 } 258} 259 260// EventCallback is a function that is called when an Event occurs. 261type EventCallback func(Event) 262 263// WithEventCallback returns a connection option that specifies an event 264// callback. 265// The callback must not block - doing so would delay the ZK go routines. 266func WithEventCallback(cb EventCallback) connOption { 267 return func(c *Conn) { 268 c.eventCallback = cb 269 } 270} 271 272// WithMaxBufferSize sets the maximum buffer size used to read and decode 273// packets received from the Zookeeper server. The standard Zookeeper client for 274// Java defaults to a limit of 1mb. For backwards compatibility, this Go client 275// defaults to unbounded unless overridden via this option. A value that is zero 276// or negative indicates that no limit is enforced. 277// 278// This is meant to prevent resource exhaustion in the face of potentially 279// malicious data in ZK. It should generally match the server setting (which 280// also defaults ot 1mb) so that clients and servers agree on the limits for 281// things like the size of data in an individual znode and the total size of a 282// transaction. 
283// 284// For production systems, this should be set to a reasonable value (ideally 285// that matches the server configuration). For ops tooling, it is handy to use a 286// much larger limit, in order to do things like clean-up problematic state in 287// the ZK tree. For example, if a single znode has a huge number of children, it 288// is possible for the response to a "list children" operation to exceed this 289// buffer size and cause errors in clients. The only way to subsequently clean 290// up the tree (by removing superfluous children) is to use a client configured 291// with a larger buffer size that can successfully query for all of the child 292// names and then remove them. (Note there are other tools that can list all of 293// the child names without an increased buffer size in the client, but they work 294// by inspecting the servers' transaction logs to enumerate children instead of 295// sending an online request to a server. 296func WithMaxBufferSize(maxBufferSize int) connOption { 297 return func(c *Conn) { 298 c.maxBufferSize = maxBufferSize 299 } 300} 301 302// WithMaxConnBufferSize sets maximum buffer size used to send and encode 303// packets to Zookeeper server. The standard Zookeepeer client for java defaults 304// to a limit of 1mb. This option should be used for non-standard server setup 305// where znode is bigger than default 1mb. 306func WithMaxConnBufferSize(maxBufferSize int) connOption { 307 return func(c *Conn) { 308 c.buf = make([]byte, maxBufferSize) 309 } 310} 311 312func (c *Conn) Close() { 313 close(c.shouldQuit) 314 315 select { 316 case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): 317 case <-time.After(time.Second): 318 } 319} 320 321// State returns the current state of the connection. 322func (c *Conn) State() State { 323 return State(atomic.LoadInt32((*int32)(&c.state))) 324} 325 326// SessionID returns the current session id of the connection. 
327func (c *Conn) SessionID() int64 { 328 return atomic.LoadInt64(&c.sessionID) 329} 330 331// SetLogger sets the logger to be used for printing errors. 332// Logger is an interface provided by this package. 333func (c *Conn) SetLogger(l Logger) { 334 c.logger = l 335} 336 337func (c *Conn) setTimeouts(sessionTimeoutMs int32) { 338 c.sessionTimeoutMs = sessionTimeoutMs 339 sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond 340 c.recvTimeout = sessionTimeout * 2 / 3 341 c.pingInterval = c.recvTimeout / 2 342} 343 344func (c *Conn) setState(state State) { 345 atomic.StoreInt32((*int32)(&c.state), int32(state)) 346 c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()}) 347} 348 349func (c *Conn) sendEvent(evt Event) { 350 if c.eventCallback != nil { 351 c.eventCallback(evt) 352 } 353 354 select { 355 case c.eventChan <- evt: 356 default: 357 // panic("zk: event channel full - it must be monitored and never allowed to be full") 358 } 359} 360 361func (c *Conn) connect() error { 362 var retryStart bool 363 for { 364 c.serverMu.Lock() 365 c.server, retryStart = c.hostProvider.Next() 366 c.serverMu.Unlock() 367 c.setState(StateConnecting) 368 if retryStart { 369 c.flushUnsentRequests(ErrNoServer) 370 select { 371 case <-time.After(time.Second): 372 // pass 373 case <-c.shouldQuit: 374 c.setState(StateDisconnected) 375 c.flushUnsentRequests(ErrClosing) 376 return ErrClosing 377 } 378 } 379 380 zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) 381 if err == nil { 382 c.conn = zkConn 383 c.setState(StateConnected) 384 if c.logInfo { 385 c.logger.Printf("Connected to %s", c.Server()) 386 } 387 return nil 388 } 389 390 c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err) 391 } 392} 393 394func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) { 395 shouldCancel := func() bool { 396 select { 397 case <-c.shouldQuit: 398 return true 399 case <-c.closeChan: 400 return true 401 default: 402 return false 403 } 404 } 405 
406 c.credsMu.Lock() 407 defer c.credsMu.Unlock() 408 409 defer close(reauthReadyChan) 410 411 if c.logInfo { 412 c.logger.Printf("Re-submitting `%d` credentials after reconnect", 413 len(c.creds)) 414 } 415 416 for _, cred := range c.creds { 417 if shouldCancel() { 418 c.logger.Printf("Cancel rer-submitting credentials") 419 return 420 } 421 resChan, err := c.sendRequest( 422 opSetAuth, 423 &setAuthRequest{Type: 0, 424 Scheme: cred.scheme, 425 Auth: cred.auth, 426 }, 427 &setAuthResponse{}, 428 nil) 429 430 if err != nil { 431 c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", err) 432 // FIXME(prozlach): lets ignore errors for now 433 continue 434 } 435 436 var res response 437 select { 438 case res = <-resChan: 439 case <-c.closeChan: 440 c.logger.Printf("Recv closed, cancel re-submitting credentials") 441 return 442 case <-c.shouldQuit: 443 c.logger.Printf("Should quit, cancel re-submitting credentials") 444 return 445 } 446 if res.err != nil { 447 c.logger.Printf("Credential re-submit failed: %s", res.err) 448 // FIXME(prozlach): lets ignore errors for now 449 continue 450 } 451 } 452} 453 454func (c *Conn) sendRequest( 455 opcode int32, 456 req interface{}, 457 res interface{}, 458 recvFunc func(*request, *responseHeader, error), 459) ( 460 <-chan response, 461 error, 462) { 463 rq := &request{ 464 xid: c.nextXid(), 465 opcode: opcode, 466 pkt: req, 467 recvStruct: res, 468 recvChan: make(chan response, 1), 469 recvFunc: recvFunc, 470 } 471 472 if err := c.sendData(rq); err != nil { 473 return nil, err 474 } 475 476 return rq.recvChan, nil 477} 478 479func (c *Conn) loop() { 480 for { 481 if err := c.connect(); err != nil { 482 // c.Close() was called 483 return 484 } 485 486 err := c.authenticate() 487 switch { 488 case err == ErrSessionExpired: 489 c.logger.Printf("Authentication failed: %s", err) 490 c.invalidateWatches(err) 491 case err != nil && c.conn != nil: 492 c.logger.Printf("Authentication failed: %s", err) 493 
c.conn.Close() 494 case err == nil: 495 if c.logInfo { 496 c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) 497 } 498 c.hostProvider.Connected() // mark success 499 c.closeChan = make(chan struct{}) // channel to tell send loop stop 500 reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted 501 502 var wg sync.WaitGroup 503 wg.Add(1) 504 go func() { 505 <-reauthChan 506 if c.debugCloseRecvLoop { 507 close(c.debugReauthDone) 508 } 509 err := c.sendLoop() 510 if err != nil || c.logInfo { 511 c.logger.Printf("Send loop terminated: err=%v", err) 512 } 513 c.conn.Close() // causes recv loop to EOF/exit 514 wg.Done() 515 }() 516 517 wg.Add(1) 518 go func() { 519 var err error 520 if c.debugCloseRecvLoop { 521 err = errors.New("DEBUG: close recv loop") 522 } else { 523 err = c.recvLoop(c.conn) 524 } 525 if err != io.EOF || c.logInfo { 526 c.logger.Printf("Recv loop terminated: err=%v", err) 527 } 528 if err == nil { 529 panic("zk: recvLoop should never return nil error") 530 } 531 close(c.closeChan) // tell send loop to exit 532 wg.Done() 533 }() 534 535 c.resendZkAuth(reauthChan) 536 537 c.sendSetWatches() 538 wg.Wait() 539 } 540 541 c.setState(StateDisconnected) 542 543 select { 544 case <-c.shouldQuit: 545 c.flushRequests(ErrClosing) 546 return 547 default: 548 } 549 550 if err != ErrSessionExpired { 551 err = ErrConnectionClosed 552 } 553 c.flushRequests(err) 554 555 if c.reconnectLatch != nil { 556 select { 557 case <-c.shouldQuit: 558 return 559 case <-c.reconnectLatch: 560 } 561 } 562 } 563} 564 565func (c *Conn) flushUnsentRequests(err error) { 566 for { 567 select { 568 default: 569 return 570 case req := <-c.sendChan: 571 req.recvChan <- response{-1, err} 572 } 573 } 574} 575 576// Send error to all pending requests and clear request map 577func (c *Conn) flushRequests(err error) { 578 c.requestsLock.Lock() 579 for _, req := range c.requests { 580 req.recvChan <- response{-1, err} 
581 } 582 c.requests = make(map[int32]*request) 583 c.requestsLock.Unlock() 584} 585 586// Send error to all watchers and clear watchers map 587func (c *Conn) invalidateWatches(err error) { 588 c.watchersLock.Lock() 589 defer c.watchersLock.Unlock() 590 591 if len(c.watchers) >= 0 { 592 for pathType, watchers := range c.watchers { 593 ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} 594 for _, ch := range watchers { 595 ch <- ev 596 close(ch) 597 } 598 } 599 c.watchers = make(map[watchPathType][]chan Event) 600 } 601} 602 603func (c *Conn) sendSetWatches() { 604 c.watchersLock.Lock() 605 defer c.watchersLock.Unlock() 606 607 if len(c.watchers) == 0 { 608 return 609 } 610 611 // NB: A ZK server, by default, rejects packets >1mb. So, if we have too 612 // many watches to reset, we need to break this up into multiple packets 613 // to avoid hitting that limit. Mirroring the Java client behavior: we are 614 // conservative in that we limit requests to 128kb (since server limit is 615 // is actually configurable and could conceivably be configured smaller 616 // than default of 1mb). 
617 limit := 128 * 1024 618 if c.setWatchLimit > 0 { 619 limit = c.setWatchLimit 620 } 621 622 var reqs []*setWatchesRequest 623 var req *setWatchesRequest 624 var sizeSoFar int 625 626 n := 0 627 for pathType, watchers := range c.watchers { 628 if len(watchers) == 0 { 629 continue 630 } 631 addlLen := 4 + len(pathType.path) 632 if req == nil || sizeSoFar+addlLen > limit { 633 if req != nil { 634 // add to set of requests that we'll send 635 reqs = append(reqs, req) 636 } 637 sizeSoFar = 28 // fixed overhead of a set-watches packet 638 req = &setWatchesRequest{ 639 RelativeZxid: c.lastZxid, 640 DataWatches: make([]string, 0), 641 ExistWatches: make([]string, 0), 642 ChildWatches: make([]string, 0), 643 } 644 } 645 sizeSoFar += addlLen 646 switch pathType.wType { 647 case watchTypeData: 648 req.DataWatches = append(req.DataWatches, pathType.path) 649 case watchTypeExist: 650 req.ExistWatches = append(req.ExistWatches, pathType.path) 651 case watchTypeChild: 652 req.ChildWatches = append(req.ChildWatches, pathType.path) 653 } 654 n++ 655 } 656 if n == 0 { 657 return 658 } 659 if req != nil { // don't forget any trailing packet we were building 660 reqs = append(reqs, req) 661 } 662 663 if c.setWatchCallback != nil { 664 c.setWatchCallback(reqs) 665 } 666 667 go func() { 668 res := &setWatchesResponse{} 669 // TODO: Pipeline these so queue all of them up before waiting on any 670 // response. That will require some investigation to make sure there 671 // aren't failure modes where a blocking write to the channel of requests 672 // could hang indefinitely and cause this goroutine to leak... 673 for _, req := range reqs { 674 _, err := c.request(opSetWatches, req, res, nil) 675 if err != nil { 676 c.logger.Printf("Failed to set previous watches: %s", err.Error()) 677 break 678 } 679 } 680 }() 681} 682 683func (c *Conn) authenticate() error { 684 buf := make([]byte, 256) 685 686 // Encode and send a connect request. 
687 n, err := encodePacket(buf[4:], &connectRequest{ 688 ProtocolVersion: protocolVersion, 689 LastZxidSeen: c.lastZxid, 690 TimeOut: c.sessionTimeoutMs, 691 SessionID: c.SessionID(), 692 Passwd: c.passwd, 693 }) 694 if err != nil { 695 return err 696 } 697 698 binary.BigEndian.PutUint32(buf[:4], uint32(n)) 699 700 c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10)) 701 _, err = c.conn.Write(buf[:n+4]) 702 c.conn.SetWriteDeadline(time.Time{}) 703 if err != nil { 704 return err 705 } 706 707 // Receive and decode a connect response. 708 c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)) 709 _, err = io.ReadFull(c.conn, buf[:4]) 710 c.conn.SetReadDeadline(time.Time{}) 711 if err != nil { 712 return err 713 } 714 715 blen := int(binary.BigEndian.Uint32(buf[:4])) 716 if cap(buf) < blen { 717 buf = make([]byte, blen) 718 } 719 720 _, err = io.ReadFull(c.conn, buf[:blen]) 721 if err != nil { 722 return err 723 } 724 725 r := connectResponse{} 726 _, err = decodePacket(buf[:blen], &r) 727 if err != nil { 728 return err 729 } 730 if r.SessionID == 0 { 731 atomic.StoreInt64(&c.sessionID, int64(0)) 732 c.passwd = emptyPassword 733 c.lastZxid = 0 734 c.setState(StateExpired) 735 return ErrSessionExpired 736 } 737 738 atomic.StoreInt64(&c.sessionID, r.SessionID) 739 c.setTimeouts(r.TimeOut) 740 c.passwd = r.Passwd 741 c.setState(StateHasSession) 742 743 return nil 744} 745 746func (c *Conn) sendData(req *request) error { 747 header := &requestHeader{req.xid, req.opcode} 748 n, err := encodePacket(c.buf[4:], header) 749 if err != nil { 750 req.recvChan <- response{-1, err} 751 return nil 752 } 753 754 n2, err := encodePacket(c.buf[4+n:], req.pkt) 755 if err != nil { 756 req.recvChan <- response{-1, err} 757 return nil 758 } 759 760 n += n2 761 762 binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) 763 764 c.requestsLock.Lock() 765 select { 766 case <-c.closeChan: 767 req.recvChan <- response{-1, ErrConnectionClosed} 768 c.requestsLock.Unlock() 769 return 
ErrConnectionClosed 770 default: 771 } 772 c.requests[req.xid] = req 773 c.requestsLock.Unlock() 774 775 c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) 776 _, err = c.conn.Write(c.buf[:n+4]) 777 c.conn.SetWriteDeadline(time.Time{}) 778 if err != nil { 779 req.recvChan <- response{-1, err} 780 c.conn.Close() 781 return err 782 } 783 784 return nil 785} 786 787func (c *Conn) sendLoop() error { 788 pingTicker := time.NewTicker(c.pingInterval) 789 defer pingTicker.Stop() 790 791 for { 792 select { 793 case req := <-c.sendChan: 794 if err := c.sendData(req); err != nil { 795 return err 796 } 797 case <-pingTicker.C: 798 n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) 799 if err != nil { 800 panic("zk: opPing should never fail to serialize") 801 } 802 803 binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) 804 805 c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) 806 _, err = c.conn.Write(c.buf[:n+4]) 807 c.conn.SetWriteDeadline(time.Time{}) 808 if err != nil { 809 c.conn.Close() 810 return err 811 } 812 case <-c.closeChan: 813 return nil 814 } 815 } 816} 817 818func (c *Conn) recvLoop(conn net.Conn) error { 819 sz := bufferSize 820 if c.maxBufferSize > 0 && sz > c.maxBufferSize { 821 sz = c.maxBufferSize 822 } 823 buf := make([]byte, sz) 824 for { 825 // package length 826 conn.SetReadDeadline(time.Now().Add(c.recvTimeout)) 827 _, err := io.ReadFull(conn, buf[:4]) 828 if err != nil { 829 return err 830 } 831 832 blen := int(binary.BigEndian.Uint32(buf[:4])) 833 if cap(buf) < blen { 834 if c.maxBufferSize > 0 && blen > c.maxBufferSize { 835 return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize) 836 } 837 buf = make([]byte, blen) 838 } 839 840 _, err = io.ReadFull(conn, buf[:blen]) 841 conn.SetReadDeadline(time.Time{}) 842 if err != nil { 843 return err 844 } 845 846 res := responseHeader{} 847 _, err = decodePacket(buf[:16], &res) 848 if err != nil { 849 return err 850 
} 851 852 if res.Xid == -1 { 853 res := &watcherEvent{} 854 _, err := decodePacket(buf[16:blen], res) 855 if err != nil { 856 return err 857 } 858 ev := Event{ 859 Type: res.Type, 860 State: res.State, 861 Path: res.Path, 862 Err: nil, 863 } 864 c.sendEvent(ev) 865 wTypes := make([]watchType, 0, 2) 866 switch res.Type { 867 case EventNodeCreated: 868 wTypes = append(wTypes, watchTypeExist) 869 case EventNodeDeleted, EventNodeDataChanged: 870 wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild) 871 case EventNodeChildrenChanged: 872 wTypes = append(wTypes, watchTypeChild) 873 } 874 c.watchersLock.Lock() 875 for _, t := range wTypes { 876 wpt := watchPathType{res.Path, t} 877 if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 { 878 for _, ch := range watchers { 879 ch <- ev 880 close(ch) 881 } 882 delete(c.watchers, wpt) 883 } 884 } 885 c.watchersLock.Unlock() 886 } else if res.Xid == -2 { 887 // Ping response. Ignore. 888 } else if res.Xid < 0 { 889 c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) 890 } else { 891 if res.Zxid > 0 { 892 c.lastZxid = res.Zxid 893 } 894 895 c.requestsLock.Lock() 896 req, ok := c.requests[res.Xid] 897 if ok { 898 delete(c.requests, res.Xid) 899 } 900 c.requestsLock.Unlock() 901 902 if !ok { 903 c.logger.Printf("Response for unknown request with xid %d", res.Xid) 904 } else { 905 if res.Err != 0 { 906 err = res.Err.toError() 907 } else { 908 _, err = decodePacket(buf[16:blen], req.recvStruct) 909 } 910 if req.recvFunc != nil { 911 req.recvFunc(req, &res, err) 912 } 913 req.recvChan <- response{res.Zxid, err} 914 if req.opcode == opClose { 915 return io.EOF 916 } 917 } 918 } 919 } 920} 921 922func (c *Conn) nextXid() int32 { 923 return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) 924} 925 926func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { 927 c.watchersLock.Lock() 928 defer c.watchersLock.Unlock() 929 930 ch := make(chan Event, 1) 931 wpt := 
watchPathType{path, watchType} 932 c.watchers[wpt] = append(c.watchers[wpt], ch) 933 return ch 934} 935 936func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { 937 rq := &request{ 938 xid: c.nextXid(), 939 opcode: opcode, 940 pkt: req, 941 recvStruct: res, 942 recvChan: make(chan response, 1), 943 recvFunc: recvFunc, 944 } 945 c.sendChan <- rq 946 return rq.recvChan 947} 948 949func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { 950 r := <-c.queueRequest(opcode, req, res, recvFunc) 951 return r.zxid, r.err 952} 953 954func (c *Conn) AddAuth(scheme string, auth []byte) error { 955 _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) 956 957 if err != nil { 958 return err 959 } 960 961 // Remember authdata so that it can be re-submitted on reconnect 962 // 963 // FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2" 964 // as two different entries, which will be re-submitted on reconnet. Some 965 // research is needed on how ZK treats these cases and 966 // then maybe switch to something like "map[username] = password" to allow 967 // only single password for given user with users being unique. 
968 obj := authCreds{ 969 scheme: scheme, 970 auth: auth, 971 } 972 973 c.credsMu.Lock() 974 c.creds = append(c.creds, obj) 975 c.credsMu.Unlock() 976 977 return nil 978} 979 980func (c *Conn) Children(path string) ([]string, *Stat, error) { 981 if err := validatePath(path, false); err != nil { 982 return nil, nil, err 983 } 984 985 res := &getChildren2Response{} 986 _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) 987 return res.Children, &res.Stat, err 988} 989 990func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { 991 if err := validatePath(path, false); err != nil { 992 return nil, nil, nil, err 993 } 994 995 var ech <-chan Event 996 res := &getChildren2Response{} 997 _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { 998 if err == nil { 999 ech = c.addWatcher(path, watchTypeChild) 1000 } 1001 }) 1002 if err != nil { 1003 return nil, nil, nil, err 1004 } 1005 return res.Children, &res.Stat, ech, err 1006} 1007 1008func (c *Conn) Get(path string) ([]byte, *Stat, error) { 1009 if err := validatePath(path, false); err != nil { 1010 return nil, nil, err 1011 } 1012 1013 res := &getDataResponse{} 1014 _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) 1015 return res.Data, &res.Stat, err 1016} 1017 1018// GetW returns the contents of a znode and sets a watch 1019func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { 1020 if err := validatePath(path, false); err != nil { 1021 return nil, nil, nil, err 1022 } 1023 1024 var ech <-chan Event 1025 res := &getDataResponse{} 1026 _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { 1027 if err == nil { 1028 ech = c.addWatcher(path, watchTypeData) 1029 } 1030 }) 1031 if err != nil { 1032 return nil, nil, nil, err 1033 } 1034 return 
res.Data, &res.Stat, ech, err 1035} 1036 1037func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { 1038 if err := validatePath(path, false); err != nil { 1039 return nil, err 1040 } 1041 1042 res := &setDataResponse{} 1043 _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil) 1044 return &res.Stat, err 1045} 1046 1047func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) { 1048 if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { 1049 return "", err 1050 } 1051 1052 res := &createResponse{} 1053 _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) 1054 return res.Path, err 1055} 1056 1057// CreateProtectedEphemeralSequential fixes a race condition if the server crashes 1058// after it creates the node. On reconnect the session may still be valid so the 1059// ephemeral node still exists. Therefore, on reconnect we need to check if a node 1060// with a GUID generated on create exists. 1061func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) { 1062 if err := validatePath(path, true); err != nil { 1063 return "", err 1064 } 1065 1066 var guid [16]byte 1067 _, err := io.ReadFull(rand.Reader, guid[:16]) 1068 if err != nil { 1069 return "", err 1070 } 1071 guidStr := fmt.Sprintf("%x", guid) 1072 1073 parts := strings.Split(path, "/") 1074 parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1]) 1075 rootPath := strings.Join(parts[:len(parts)-1], "/") 1076 protectedPath := strings.Join(parts, "/") 1077 1078 var newPath string 1079 for i := 0; i < 3; i++ { 1080 newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl) 1081 switch err { 1082 case ErrSessionExpired: 1083 // No need to search for the node since it can't exist. Just try again. 
1084 case ErrConnectionClosed: 1085 children, _, err := c.Children(rootPath) 1086 if err != nil { 1087 return "", err 1088 } 1089 for _, p := range children { 1090 parts := strings.Split(p, "/") 1091 if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) { 1092 if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr { 1093 return rootPath + "/" + p, nil 1094 } 1095 } 1096 } 1097 case nil: 1098 return newPath, nil 1099 default: 1100 return "", err 1101 } 1102 } 1103 return "", err 1104} 1105 1106func (c *Conn) Delete(path string, version int32) error { 1107 if err := validatePath(path, false); err != nil { 1108 return err 1109 } 1110 1111 _, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil) 1112 return err 1113} 1114 1115func (c *Conn) Exists(path string) (bool, *Stat, error) { 1116 if err := validatePath(path, false); err != nil { 1117 return false, nil, err 1118 } 1119 1120 res := &existsResponse{} 1121 _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil) 1122 exists := true 1123 if err == ErrNoNode { 1124 exists = false 1125 err = nil 1126 } 1127 return exists, &res.Stat, err 1128} 1129 1130func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { 1131 if err := validatePath(path, false); err != nil { 1132 return false, nil, nil, err 1133 } 1134 1135 var ech <-chan Event 1136 res := &existsResponse{} 1137 _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { 1138 if err == nil { 1139 ech = c.addWatcher(path, watchTypeData) 1140 } else if err == ErrNoNode { 1141 ech = c.addWatcher(path, watchTypeExist) 1142 } 1143 }) 1144 exists := true 1145 if err == ErrNoNode { 1146 exists = false 1147 err = nil 1148 } 1149 if err != nil { 1150 return false, nil, nil, err 1151 } 1152 return exists, &res.Stat, ech, err 1153} 1154 1155func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { 1156 
if err := validatePath(path, false); err != nil { 1157 return nil, nil, err 1158 } 1159 1160 res := &getAclResponse{} 1161 _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) 1162 return res.Acl, &res.Stat, err 1163} 1164func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { 1165 if err := validatePath(path, false); err != nil { 1166 return nil, err 1167 } 1168 1169 res := &setAclResponse{} 1170 _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) 1171 return &res.Stat, err 1172} 1173 1174func (c *Conn) Sync(path string) (string, error) { 1175 if err := validatePath(path, false); err != nil { 1176 return "", err 1177 } 1178 1179 res := &syncResponse{} 1180 _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) 1181 return res.Path, err 1182} 1183 1184type MultiResponse struct { 1185 Stat *Stat 1186 String string 1187 Error error 1188} 1189 1190// Multi executes multiple ZooKeeper operations or none of them. The provided 1191// ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or 1192// *CheckVersionRequest. 
1193func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { 1194 req := &multiRequest{ 1195 Ops: make([]multiRequestOp, 0, len(ops)), 1196 DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, 1197 } 1198 for _, op := range ops { 1199 var opCode int32 1200 switch op.(type) { 1201 case *CreateRequest: 1202 opCode = opCreate 1203 case *SetDataRequest: 1204 opCode = opSetData 1205 case *DeleteRequest: 1206 opCode = opDelete 1207 case *CheckVersionRequest: 1208 opCode = opCheck 1209 default: 1210 return nil, fmt.Errorf("unknown operation type %T", op) 1211 } 1212 req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op}) 1213 } 1214 res := &multiResponse{} 1215 _, err := c.request(opMulti, req, res, nil) 1216 mr := make([]MultiResponse, len(res.Ops)) 1217 for i, op := range res.Ops { 1218 mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()} 1219 } 1220 return mr, err 1221} 1222 1223// Server returns the current or last-connected server name. 1224func (c *Conn) Server() string { 1225 c.serverMu.Lock() 1226 defer c.serverMu.Unlock() 1227 return c.server 1228} 1229