// Package zk is a native Go client library for the ZooKeeper orchestration service.
package zk

/*
TODO:
* make sure a ping response comes back in a reasonable time

Possible watcher events:
* Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err}
*/

import (
	"context"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// ErrNoServer indicates that an operation cannot be completed
// because attempts to connect to all servers in the list failed.
var ErrNoServer = errors.New("zk: could not connect to a server")

// ErrInvalidPath indicates that an operation was being attempted on
// an invalid path. (e.g. empty path)
var ErrInvalidPath = errors.New("zk: invalid path")

// DefaultLogger uses the stdlib log package for logging.
var DefaultLogger Logger = defaultLogger{}

const (
	// bufferSize is the default size of the per-connection send buffer and
	// the initial size of the receive buffer.
	bufferSize = 1536 * 1024
	// eventChanSize is the capacity of the session event channel returned by Connect.
	eventChanSize = 6
	// sendChanSize is the capacity of the queue of requests waiting to be sent.
	sendChanSize = 16
	// protectedPrefix is prepended to the node name by
	// CreateProtectedEphemeralSequential.
	protectedPrefix = "_c_"
)

// watchType identifies which kind of watch is registered on a path.
type watchType int

// Kinds of watches tracked per path. Note that these are untyped integer
// constants; they are converted to watchType where they are used.
const (
	watchTypeData = iota
	watchTypeExist
	watchTypeChild
)

// watchPathType is the key of the watcher table: a path together with the
// kind of watch registered on it.
type watchPathType struct {
	path  string
	wType watchType
}

// Dialer is a function used to establish the network connection to a
// server. Its signature matches net.DialTimeout, which is the default.
type Dialer func(network, address string, timeout time.Duration) (net.Conn, error)

// Logger is an interface that can be implemented to provide custom log output.
type Logger interface {
	Printf(string, ...interface{})
}

// authCreds is one scheme/auth pair remembered by AddAuth so that it can be
// re-submitted after a reconnect.
type authCreds struct {
	scheme string
	auth   []byte
}

// Conn is a client connection to a pool of ZooKeeper servers. It is created
// by Connect and torn down by Close.
type Conn struct {
	// NOTE(review): sessionID is accessed with 64-bit atomics; keeping the
	// int64 fields first presumably guarantees their alignment on 32-bit
	// platforms - confirm before reordering fields.
	lastZxid         int64
	sessionID        int64
	state            State // must be 32-bit aligned
	xid              uint32
	sessionTimeoutMs int32 // session timeout in milliseconds
	passwd           []byte

	dialer         Dialer
	hostProvider   HostProvider
	serverMu       sync.Mutex // protects server
	server         string     // remember the address/port of the current server
	conn           net.Conn
	eventChan      chan Event
	eventCallback  EventCallback // may be nil
	shouldQuit     chan struct{} // closed by Close
	shouldQuitOnce sync.Once
	pingInterval   time.Duration
	recvTimeout    time.Duration
	connectTimeout time.Duration
	maxBufferSize  int // receive-side packet size limit; <= 0 means no limit

	creds   []authCreds
	credsMu sync.Mutex // protects creds

	sendChan     chan *request
	requests     map[int32]*request // Xid -> pending request
	requestsLock sync.Mutex
	watchers     map[watchPathType][]chan Event
	watchersLock sync.Mutex
	closeChan    chan struct{} // channel to tell send loop stop

	// Debug (used by unit tests)
	reconnectLatch   chan struct{}
	setWatchLimit    int
	setWatchCallback func([]*setWatchesRequest)

	// Debug (for recurring re-auth hang)
	debugCloseRecvLoop bool
	resendZkAuthFn     func(context.Context, *Conn) error

	logger  Logger
	logInfo bool // true if information messages are logged; false if only errors are logged

	buf []byte // scratch buffer used to encode outgoing packets
}

// connOption represents a connection option.
type connOption func(c *Conn)

// request is a single operation queued for, or in flight to, the server.
type request struct {
	xid        int32
	opcode     int32
	pkt        interface{}
	recvStruct interface{}
	recvChan   chan response

	// Because sending and receiving happen in separate go routines, there's
	// a possible race condition when creating watches from outside the read
	// loop. We must ensure that a watcher gets added to the list synchronously
	// with the response from the server on any request that creates a watch.
	// In order to not hard code the watch logic for each opcode in the recv
	// loop the caller can use recvFunc to insert some synchronous code
	// after a response.
	recvFunc func(*request, *responseHeader, error)
}

// response is the outcome delivered on a request's recvChan: the zxid taken
// from the reply header (or -1 when the request failed locally) and the error.
type response struct {
	zxid int64
	err  error
}

// Event is a session or watch notification delivered to the application.
type Event struct {
	Type   EventType
	State  State
	Path   string // For non-session events, the path of the watched node.
	Err    error
	Server string // For connection events
}

// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java equivalent:
// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
type HostProvider interface {
	// Init is called first, with the servers specified in the connection string.
	Init(servers []string) error
	// Len returns the number of servers.
	Len() int
	// Next returns the next server to connect to. retryStart will be true if we've looped through
	// all known servers without Connected() being called.
	Next() (server string, retryStart bool)
	// Notify the HostProvider of a successful connection.
	Connected()
}

// ConnectWithDialer establishes a new connection to a pool of zookeeper servers
// using a custom Dialer. See Connect for further information about session timeout.
//
// Deprecated: this function is provided for compatibility; use Connect with
// the WithDialer option instead.
func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
	return Connect(servers, sessionTimeout, WithDialer(dialer))
}

// Connect establishes a new connection to a pool of zookeeper
// servers. The provided session timeout sets the amount of time for which
// a session is considered valid after losing connection to a server.
Within 174// the session timeout it's possible to reestablish a connection to a different 175// server and keep the same session. This is means any ephemeral nodes and 176// watches are maintained. 177func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { 178 if len(servers) == 0 { 179 return nil, nil, errors.New("zk: server list must not be empty") 180 } 181 182 srvs := FormatServers(servers) 183 184 // Randomize the order of the servers to avoid creating hotspots 185 stringShuffle(srvs) 186 187 ec := make(chan Event, eventChanSize) 188 conn := &Conn{ 189 dialer: net.DialTimeout, 190 hostProvider: &DNSHostProvider{}, 191 conn: nil, 192 state: StateDisconnected, 193 eventChan: ec, 194 shouldQuit: make(chan struct{}), 195 connectTimeout: 1 * time.Second, 196 sendChan: make(chan *request, sendChanSize), 197 requests: make(map[int32]*request), 198 watchers: make(map[watchPathType][]chan Event), 199 passwd: emptyPassword, 200 logger: DefaultLogger, 201 logInfo: true, // default is true for backwards compatability 202 buf: make([]byte, bufferSize), 203 resendZkAuthFn: resendZkAuth, 204 } 205 206 // Set provided options. 207 for _, option := range options { 208 option(conn) 209 } 210 211 if err := conn.hostProvider.Init(srvs); err != nil { 212 return nil, nil, err 213 } 214 215 conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) 216 // TODO: This context should be passed in by the caller to be the connection lifecycle context. 217 ctx := context.Background() 218 219 go func() { 220 conn.loop(ctx) 221 conn.flushRequests(ErrClosing) 222 conn.invalidateWatches(ErrClosing) 223 close(conn.eventChan) 224 }() 225 return conn, ec, nil 226} 227 228// WithDialer returns a connection option specifying a non-default Dialer. 
229func WithDialer(dialer Dialer) connOption { 230 return func(c *Conn) { 231 c.dialer = dialer 232 } 233} 234 235// WithHostProvider returns a connection option specifying a non-default HostProvider. 236func WithHostProvider(hostProvider HostProvider) connOption { 237 return func(c *Conn) { 238 c.hostProvider = hostProvider 239 } 240} 241 242// WithLogger returns a connection option specifying a non-default Logger 243func WithLogger(logger Logger) connOption { 244 return func(c *Conn) { 245 c.logger = logger 246 } 247} 248 249// WithLogInfo returns a connection option specifying whether or not information messages 250// shoud be logged. 251func WithLogInfo(logInfo bool) connOption { 252 return func(c *Conn) { 253 c.logInfo = logInfo 254 } 255} 256 257// EventCallback is a function that is called when an Event occurs. 258type EventCallback func(Event) 259 260// WithEventCallback returns a connection option that specifies an event 261// callback. 262// The callback must not block - doing so would delay the ZK go routines. 263func WithEventCallback(cb EventCallback) connOption { 264 return func(c *Conn) { 265 c.eventCallback = cb 266 } 267} 268 269// WithMaxBufferSize sets the maximum buffer size used to read and decode 270// packets received from the Zookeeper server. The standard Zookeeper client for 271// Java defaults to a limit of 1mb. For backwards compatibility, this Go client 272// defaults to unbounded unless overridden via this option. A value that is zero 273// or negative indicates that no limit is enforced. 274// 275// This is meant to prevent resource exhaustion in the face of potentially 276// malicious data in ZK. It should generally match the server setting (which 277// also defaults ot 1mb) so that clients and servers agree on the limits for 278// things like the size of data in an individual znode and the total size of a 279// transaction. 
280// 281// For production systems, this should be set to a reasonable value (ideally 282// that matches the server configuration). For ops tooling, it is handy to use a 283// much larger limit, in order to do things like clean-up problematic state in 284// the ZK tree. For example, if a single znode has a huge number of children, it 285// is possible for the response to a "list children" operation to exceed this 286// buffer size and cause errors in clients. The only way to subsequently clean 287// up the tree (by removing superfluous children) is to use a client configured 288// with a larger buffer size that can successfully query for all of the child 289// names and then remove them. (Note there are other tools that can list all of 290// the child names without an increased buffer size in the client, but they work 291// by inspecting the servers' transaction logs to enumerate children instead of 292// sending an online request to a server. 293func WithMaxBufferSize(maxBufferSize int) connOption { 294 return func(c *Conn) { 295 c.maxBufferSize = maxBufferSize 296 } 297} 298 299// WithMaxConnBufferSize sets maximum buffer size used to send and encode 300// packets to Zookeeper server. The standard Zookeepeer client for java defaults 301// to a limit of 1mb. This option should be used for non-standard server setup 302// where znode is bigger than default 1mb. 303func WithMaxConnBufferSize(maxBufferSize int) connOption { 304 return func(c *Conn) { 305 c.buf = make([]byte, maxBufferSize) 306 } 307} 308 309// Close will submit a close request with ZK and signal the connection to stop 310// sending and receiving packets. 311func (c *Conn) Close() { 312 c.shouldQuitOnce.Do(func() { 313 close(c.shouldQuit) 314 315 select { 316 case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): 317 case <-time.After(time.Second): 318 } 319 }) 320} 321 322// State returns the current state of the connection. 
323func (c *Conn) State() State { 324 return State(atomic.LoadInt32((*int32)(&c.state))) 325} 326 327// SessionID returns the current session id of the connection. 328func (c *Conn) SessionID() int64 { 329 return atomic.LoadInt64(&c.sessionID) 330} 331 332// SetLogger sets the logger to be used for printing errors. 333// Logger is an interface provided by this package. 334func (c *Conn) SetLogger(l Logger) { 335 c.logger = l 336} 337 338func (c *Conn) setTimeouts(sessionTimeoutMs int32) { 339 c.sessionTimeoutMs = sessionTimeoutMs 340 sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond 341 c.recvTimeout = sessionTimeout * 2 / 3 342 c.pingInterval = c.recvTimeout / 2 343} 344 345func (c *Conn) setState(state State) { 346 atomic.StoreInt32((*int32)(&c.state), int32(state)) 347 c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()}) 348} 349 350func (c *Conn) sendEvent(evt Event) { 351 if c.eventCallback != nil { 352 c.eventCallback(evt) 353 } 354 355 select { 356 case c.eventChan <- evt: 357 default: 358 // panic("zk: event channel full - it must be monitored and never allowed to be full") 359 } 360} 361 362func (c *Conn) connect() error { 363 var retryStart bool 364 for { 365 c.serverMu.Lock() 366 c.server, retryStart = c.hostProvider.Next() 367 c.serverMu.Unlock() 368 369 c.setState(StateConnecting) 370 371 if retryStart { 372 c.flushUnsentRequests(ErrNoServer) 373 select { 374 case <-time.After(time.Second): 375 // pass 376 case <-c.shouldQuit: 377 c.setState(StateDisconnected) 378 c.flushUnsentRequests(ErrClosing) 379 return ErrClosing 380 } 381 } 382 383 zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) 384 if err == nil { 385 c.conn = zkConn 386 c.setState(StateConnected) 387 if c.logInfo { 388 c.logger.Printf("connected to %s", c.Server()) 389 } 390 return nil 391 } 392 393 c.logger.Printf("failed to connect to %s: %v", c.Server(), err) 394 } 395} 396 397func (c *Conn) sendRequest( 398 opcode int32, 399 req 
interface{}, 400 res interface{}, 401 recvFunc func(*request, *responseHeader, error), 402) ( 403 <-chan response, 404 error, 405) { 406 rq := &request{ 407 xid: c.nextXid(), 408 opcode: opcode, 409 pkt: req, 410 recvStruct: res, 411 recvChan: make(chan response, 1), 412 recvFunc: recvFunc, 413 } 414 415 if err := c.sendData(rq); err != nil { 416 return nil, err 417 } 418 419 return rq.recvChan, nil 420} 421 422func (c *Conn) loop(ctx context.Context) { 423 for { 424 if err := c.connect(); err != nil { 425 // c.Close() was called 426 return 427 } 428 429 err := c.authenticate() 430 switch { 431 case err == ErrSessionExpired: 432 c.logger.Printf("authentication failed: %s", err) 433 c.invalidateWatches(err) 434 case err != nil && c.conn != nil: 435 c.logger.Printf("authentication failed: %s", err) 436 c.conn.Close() 437 case err == nil: 438 if c.logInfo { 439 c.logger.Printf("authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) 440 } 441 c.hostProvider.Connected() // mark success 442 c.closeChan = make(chan struct{}) // channel to tell send loop stop 443 444 var wg sync.WaitGroup 445 446 wg.Add(1) 447 go func() { 448 defer c.conn.Close() // causes recv loop to EOF/exit 449 defer wg.Done() 450 451 if err := c.resendZkAuthFn(ctx, c); err != nil { 452 c.logger.Printf("error in resending auth creds: %v", err) 453 return 454 } 455 456 if err := c.sendLoop(); err != nil || c.logInfo { 457 c.logger.Printf("send loop terminated: %v", err) 458 } 459 }() 460 461 wg.Add(1) 462 go func() { 463 defer close(c.closeChan) // tell send loop to exit 464 defer wg.Done() 465 466 var err error 467 if c.debugCloseRecvLoop { 468 err = errors.New("DEBUG: close recv loop") 469 } else { 470 err = c.recvLoop(c.conn) 471 } 472 if err != io.EOF || c.logInfo { 473 c.logger.Printf("recv loop terminated: %v", err) 474 } 475 if err == nil { 476 panic("zk: recvLoop should never return nil error") 477 } 478 }() 479 480 c.sendSetWatches() 481 wg.Wait() 482 } 483 484 
c.setState(StateDisconnected) 485 486 select { 487 case <-c.shouldQuit: 488 c.flushRequests(ErrClosing) 489 return 490 default: 491 } 492 493 if err != ErrSessionExpired { 494 err = ErrConnectionClosed 495 } 496 c.flushRequests(err) 497 498 if c.reconnectLatch != nil { 499 select { 500 case <-c.shouldQuit: 501 return 502 case <-c.reconnectLatch: 503 } 504 } 505 } 506} 507 508func (c *Conn) flushUnsentRequests(err error) { 509 for { 510 select { 511 default: 512 return 513 case req := <-c.sendChan: 514 req.recvChan <- response{-1, err} 515 } 516 } 517} 518 519// Send error to all pending requests and clear request map 520func (c *Conn) flushRequests(err error) { 521 c.requestsLock.Lock() 522 for _, req := range c.requests { 523 req.recvChan <- response{-1, err} 524 } 525 c.requests = make(map[int32]*request) 526 c.requestsLock.Unlock() 527} 528 529// Send error to all watchers and clear watchers map 530func (c *Conn) invalidateWatches(err error) { 531 c.watchersLock.Lock() 532 defer c.watchersLock.Unlock() 533 534 if len(c.watchers) >= 0 { 535 for pathType, watchers := range c.watchers { 536 ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} 537 for _, ch := range watchers { 538 ch <- ev 539 close(ch) 540 } 541 } 542 c.watchers = make(map[watchPathType][]chan Event) 543 } 544} 545 546func (c *Conn) sendSetWatches() { 547 c.watchersLock.Lock() 548 defer c.watchersLock.Unlock() 549 550 if len(c.watchers) == 0 { 551 return 552 } 553 554 // NB: A ZK server, by default, rejects packets >1mb. So, if we have too 555 // many watches to reset, we need to break this up into multiple packets 556 // to avoid hitting that limit. Mirroring the Java client behavior: we are 557 // conservative in that we limit requests to 128kb (since server limit is 558 // is actually configurable and could conceivably be configured smaller 559 // than default of 1mb). 
560 limit := 128 * 1024 561 if c.setWatchLimit > 0 { 562 limit = c.setWatchLimit 563 } 564 565 var reqs []*setWatchesRequest 566 var req *setWatchesRequest 567 var sizeSoFar int 568 569 n := 0 570 for pathType, watchers := range c.watchers { 571 if len(watchers) == 0 { 572 continue 573 } 574 addlLen := 4 + len(pathType.path) 575 if req == nil || sizeSoFar+addlLen > limit { 576 if req != nil { 577 // add to set of requests that we'll send 578 reqs = append(reqs, req) 579 } 580 sizeSoFar = 28 // fixed overhead of a set-watches packet 581 req = &setWatchesRequest{ 582 RelativeZxid: c.lastZxid, 583 DataWatches: make([]string, 0), 584 ExistWatches: make([]string, 0), 585 ChildWatches: make([]string, 0), 586 } 587 } 588 sizeSoFar += addlLen 589 switch pathType.wType { 590 case watchTypeData: 591 req.DataWatches = append(req.DataWatches, pathType.path) 592 case watchTypeExist: 593 req.ExistWatches = append(req.ExistWatches, pathType.path) 594 case watchTypeChild: 595 req.ChildWatches = append(req.ChildWatches, pathType.path) 596 } 597 n++ 598 } 599 if n == 0 { 600 return 601 } 602 if req != nil { // don't forget any trailing packet we were building 603 reqs = append(reqs, req) 604 } 605 606 if c.setWatchCallback != nil { 607 c.setWatchCallback(reqs) 608 } 609 610 go func() { 611 res := &setWatchesResponse{} 612 // TODO: Pipeline these so queue all of them up before waiting on any 613 // response. That will require some investigation to make sure there 614 // aren't failure modes where a blocking write to the channel of requests 615 // could hang indefinitely and cause this goroutine to leak... 616 for _, req := range reqs { 617 _, err := c.request(opSetWatches, req, res, nil) 618 if err != nil { 619 c.logger.Printf("Failed to set previous watches: %v", err) 620 break 621 } 622 } 623 }() 624} 625 626func (c *Conn) authenticate() error { 627 buf := make([]byte, 256) 628 629 // Encode and send a connect request. 
630 n, err := encodePacket(buf[4:], &connectRequest{ 631 ProtocolVersion: protocolVersion, 632 LastZxidSeen: c.lastZxid, 633 TimeOut: c.sessionTimeoutMs, 634 SessionID: c.SessionID(), 635 Passwd: c.passwd, 636 }) 637 if err != nil { 638 return err 639 } 640 641 binary.BigEndian.PutUint32(buf[:4], uint32(n)) 642 643 c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10)) 644 _, err = c.conn.Write(buf[:n+4]) 645 c.conn.SetWriteDeadline(time.Time{}) 646 if err != nil { 647 return err 648 } 649 650 // Receive and decode a connect response. 651 c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)) 652 _, err = io.ReadFull(c.conn, buf[:4]) 653 c.conn.SetReadDeadline(time.Time{}) 654 if err != nil { 655 return err 656 } 657 658 blen := int(binary.BigEndian.Uint32(buf[:4])) 659 if cap(buf) < blen { 660 buf = make([]byte, blen) 661 } 662 663 _, err = io.ReadFull(c.conn, buf[:blen]) 664 if err != nil { 665 return err 666 } 667 668 r := connectResponse{} 669 _, err = decodePacket(buf[:blen], &r) 670 if err != nil { 671 return err 672 } 673 if r.SessionID == 0 { 674 atomic.StoreInt64(&c.sessionID, int64(0)) 675 c.passwd = emptyPassword 676 c.lastZxid = 0 677 c.setState(StateExpired) 678 return ErrSessionExpired 679 } 680 681 atomic.StoreInt64(&c.sessionID, r.SessionID) 682 c.setTimeouts(r.TimeOut) 683 c.passwd = r.Passwd 684 c.setState(StateHasSession) 685 686 return nil 687} 688 689func (c *Conn) sendData(req *request) error { 690 header := &requestHeader{req.xid, req.opcode} 691 n, err := encodePacket(c.buf[4:], header) 692 if err != nil { 693 req.recvChan <- response{-1, err} 694 return nil 695 } 696 697 n2, err := encodePacket(c.buf[4+n:], req.pkt) 698 if err != nil { 699 req.recvChan <- response{-1, err} 700 return nil 701 } 702 703 n += n2 704 705 binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) 706 707 c.requestsLock.Lock() 708 select { 709 case <-c.closeChan: 710 req.recvChan <- response{-1, ErrConnectionClosed} 711 c.requestsLock.Unlock() 712 return 
ErrConnectionClosed 713 default: 714 } 715 c.requests[req.xid] = req 716 c.requestsLock.Unlock() 717 718 c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) 719 _, err = c.conn.Write(c.buf[:n+4]) 720 c.conn.SetWriteDeadline(time.Time{}) 721 if err != nil { 722 req.recvChan <- response{-1, err} 723 c.conn.Close() 724 return err 725 } 726 727 return nil 728} 729 730func (c *Conn) sendLoop() error { 731 pingTicker := time.NewTicker(c.pingInterval) 732 defer pingTicker.Stop() 733 734 for { 735 select { 736 case req := <-c.sendChan: 737 if err := c.sendData(req); err != nil { 738 return err 739 } 740 case <-pingTicker.C: 741 n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) 742 if err != nil { 743 panic("zk: opPing should never fail to serialize") 744 } 745 746 binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) 747 748 c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) 749 _, err = c.conn.Write(c.buf[:n+4]) 750 c.conn.SetWriteDeadline(time.Time{}) 751 if err != nil { 752 c.conn.Close() 753 return err 754 } 755 case <-c.closeChan: 756 return nil 757 } 758 } 759} 760 761func (c *Conn) recvLoop(conn net.Conn) error { 762 sz := bufferSize 763 if c.maxBufferSize > 0 && sz > c.maxBufferSize { 764 sz = c.maxBufferSize 765 } 766 buf := make([]byte, sz) 767 for { 768 // package length 769 if err := conn.SetReadDeadline(time.Now().Add(c.recvTimeout)); err != nil { 770 c.logger.Printf("failed to set connection deadline: %v", err) 771 } 772 _, err := io.ReadFull(conn, buf[:4]) 773 if err != nil { 774 return fmt.Errorf("failed to read from connection: %v", err) 775 } 776 777 blen := int(binary.BigEndian.Uint32(buf[:4])) 778 if cap(buf) < blen { 779 if c.maxBufferSize > 0 && blen > c.maxBufferSize { 780 return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize) 781 } 782 buf = make([]byte, blen) 783 } 784 785 _, err = io.ReadFull(conn, buf[:blen]) 786 conn.SetReadDeadline(time.Time{}) 787 if 
err != nil { 788 return err 789 } 790 791 res := responseHeader{} 792 _, err = decodePacket(buf[:16], &res) 793 if err != nil { 794 return err 795 } 796 797 if res.Xid == -1 { 798 res := &watcherEvent{} 799 _, err := decodePacket(buf[16:blen], res) 800 if err != nil { 801 return err 802 } 803 ev := Event{ 804 Type: res.Type, 805 State: res.State, 806 Path: res.Path, 807 Err: nil, 808 } 809 c.sendEvent(ev) 810 wTypes := make([]watchType, 0, 2) 811 switch res.Type { 812 case EventNodeCreated: 813 wTypes = append(wTypes, watchTypeExist) 814 case EventNodeDeleted, EventNodeDataChanged: 815 wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild) 816 case EventNodeChildrenChanged: 817 wTypes = append(wTypes, watchTypeChild) 818 } 819 c.watchersLock.Lock() 820 for _, t := range wTypes { 821 wpt := watchPathType{res.Path, t} 822 if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 { 823 for _, ch := range watchers { 824 ch <- ev 825 close(ch) 826 } 827 delete(c.watchers, wpt) 828 } 829 } 830 c.watchersLock.Unlock() 831 } else if res.Xid == -2 { 832 // Ping response. Ignore. 
833 } else if res.Xid < 0 { 834 c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) 835 } else { 836 if res.Zxid > 0 { 837 c.lastZxid = res.Zxid 838 } 839 840 c.requestsLock.Lock() 841 req, ok := c.requests[res.Xid] 842 if ok { 843 delete(c.requests, res.Xid) 844 } 845 c.requestsLock.Unlock() 846 847 if !ok { 848 c.logger.Printf("Response for unknown request with xid %d", res.Xid) 849 } else { 850 if res.Err != 0 { 851 err = res.Err.toError() 852 } else { 853 _, err = decodePacket(buf[16:blen], req.recvStruct) 854 } 855 if req.recvFunc != nil { 856 req.recvFunc(req, &res, err) 857 } 858 req.recvChan <- response{res.Zxid, err} 859 if req.opcode == opClose { 860 return io.EOF 861 } 862 } 863 } 864 } 865} 866 867func (c *Conn) nextXid() int32 { 868 return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) 869} 870 871func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { 872 c.watchersLock.Lock() 873 defer c.watchersLock.Unlock() 874 875 ch := make(chan Event, 1) 876 wpt := watchPathType{path, watchType} 877 c.watchers[wpt] = append(c.watchers[wpt], ch) 878 return ch 879} 880 881func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { 882 rq := &request{ 883 xid: c.nextXid(), 884 opcode: opcode, 885 pkt: req, 886 recvStruct: res, 887 recvChan: make(chan response, 2), 888 recvFunc: recvFunc, 889 } 890 891 switch opcode { 892 case opClose: 893 // always attempt to send close ops. 894 select { 895 case c.sendChan <- rq: 896 case <-time.After(c.connectTimeout * 2): 897 c.logger.Printf("gave up trying to send opClose to server") 898 rq.recvChan <- response{-1, ErrConnectionClosed} 899 } 900 default: 901 // otherwise avoid deadlocks for dumb clients who aren't aware that 902 // the ZK connection is closed yet. 
903 select { 904 case <-c.shouldQuit: 905 rq.recvChan <- response{-1, ErrConnectionClosed} 906 case c.sendChan <- rq: 907 // check for a tie 908 select { 909 case <-c.shouldQuit: 910 // maybe the caller gets this, maybe not- we tried. 911 rq.recvChan <- response{-1, ErrConnectionClosed} 912 default: 913 } 914 } 915 } 916 return rq.recvChan 917} 918 919func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { 920 r := <-c.queueRequest(opcode, req, res, recvFunc) 921 select { 922 case <-c.shouldQuit: 923 // queueRequest() can be racy, double-check for the race here and avoid 924 // a potential data-race. otherwise the client of this func may try to 925 // access `res` fields concurrently w/ the async response processor. 926 // NOTE: callers of this func should check for (at least) ErrConnectionClosed 927 // and avoid accessing fields of the response object if such error is present. 928 return -1, ErrConnectionClosed 929 default: 930 return r.zxid, r.err 931 } 932} 933 934func (c *Conn) AddAuth(scheme string, auth []byte) error { 935 _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) 936 937 if err != nil { 938 return err 939 } 940 941 // Remember authdata so that it can be re-submitted on reconnect 942 // 943 // FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2" 944 // as two different entries, which will be re-submitted on reconnet. Some 945 // research is needed on how ZK treats these cases and 946 // then maybe switch to something like "map[username] = password" to allow 947 // only single password for given user with users being unique. 
948 obj := authCreds{ 949 scheme: scheme, 950 auth: auth, 951 } 952 953 c.credsMu.Lock() 954 c.creds = append(c.creds, obj) 955 c.credsMu.Unlock() 956 957 return nil 958} 959 960func (c *Conn) Children(path string) ([]string, *Stat, error) { 961 if err := validatePath(path, false); err != nil { 962 return nil, nil, err 963 } 964 965 res := &getChildren2Response{} 966 _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) 967 if err == ErrConnectionClosed { 968 return nil, nil, err 969 } 970 return res.Children, &res.Stat, err 971} 972 973func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { 974 if err := validatePath(path, false); err != nil { 975 return nil, nil, nil, err 976 } 977 978 var ech <-chan Event 979 res := &getChildren2Response{} 980 _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { 981 if err == nil { 982 ech = c.addWatcher(path, watchTypeChild) 983 } 984 }) 985 if err != nil { 986 return nil, nil, nil, err 987 } 988 return res.Children, &res.Stat, ech, err 989} 990 991func (c *Conn) Get(path string) ([]byte, *Stat, error) { 992 if err := validatePath(path, false); err != nil { 993 return nil, nil, err 994 } 995 996 res := &getDataResponse{} 997 _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) 998 if err == ErrConnectionClosed { 999 return nil, nil, err 1000 } 1001 return res.Data, &res.Stat, err 1002} 1003 1004// GetW returns the contents of a znode and sets a watch 1005func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { 1006 if err := validatePath(path, false); err != nil { 1007 return nil, nil, nil, err 1008 } 1009 1010 var ech <-chan Event 1011 res := &getDataResponse{} 1012 _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { 1013 if err == nil { 1014 ech = 
c.addWatcher(path, watchTypeData)
		}
	})
	if err != nil {
		return nil, nil, nil, err
	}
	return res.Data, &res.Stat, ech, err
}

// Set issues an opSetData request carrying data and version for the node at
// path and returns the Stat decoded from the response.
//
// A nil Stat is returned when path fails validation or when the request ends
// with ErrConnectionClosed. For any other error the Stat pointer is non-nil
// but its contents should not be relied upon.
func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {
	if err := validatePath(path, false); err != nil {
		return nil, err
	}

	res := &setDataResponse{}
	_, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil)
	if err == ErrConnectionClosed {
		return nil, err
	}
	return &res.Stat, err
}

// Create issues an opCreate request for path with the given data, flags and
// ACL, and returns the path reported by the server.
//
// validatePath is invoked with its second argument set to true when flags
// contains FlagSequence.
func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) {
	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
		return "", err
	}

	res := &createResponse{}
	_, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
	if err == ErrConnectionClosed {
		return "", err
	}
	return res.Path, err
}

// CreateContainer issues an opCreateContainer request for path and returns
// the path reported by the server. ErrInvalidFlags is returned, without
// contacting the server, unless every bit of FlagTTL is set in flags.
//
// NOTE(review): the guard tests FlagTTL rather than a container-specific
// flag; confirm against the flag constants that this is intended.
func (c *Conn) CreateContainer(path string, data []byte, flags int32, acl []ACL) (string, error) {
	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
		return "", err
	}
	if flags&FlagTTL != FlagTTL {
		return "", ErrInvalidFlags
	}

	res := &createResponse{}
	_, err := c.request(opCreateContainer, &CreateContainerRequest{path, data, acl, flags}, res, nil)
	return res.Path, err
}

// CreateTTL issues an opCreateTTL request for path and returns the path
// reported by the server. The ttl is transmitted in milliseconds.
// ErrInvalidFlags is returned, without contacting the server, unless every
// bit of FlagTTL is set in flags.
func (c *Conn) CreateTTL(path string, data []byte, flags int32, acl []ACL, ttl time.Duration) (string, error) {
	if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
		return "", err
	}
	if flags&FlagTTL != FlagTTL {
		return "", ErrInvalidFlags
	}

	res := &createResponse{}
	_, err := c.request(opCreateTTL, &CreateTTLRequest{path, data, acl, flags, ttl.Milliseconds()}, res, nil)
	return res.Path, err
}

// CreateProtectedEphemeralSequential fixes a race condition if the server crashes
// after it creates the node. On reconnect the session may still be valid so the
// ephemeral node still exists. Therefore, on reconnect we need to check if a node
// with a GUID generated on create exists.
//
// The final path element is rewritten to "<protectedPrefix><guid>-<name>" and
// the create is attempted up to three times with FlagEphemeral|FlagSequence.
// After ErrConnectionClosed the parent's children are listed and, if one of
// them carries this call's GUID, its full path is returned. If all attempts
// are used up, the error of the last create attempt is returned.
func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) {
	if err := validatePath(path, true); err != nil {
		return "", err
	}

	var guid [16]byte
	if _, err := io.ReadFull(rand.Reader, guid[:]); err != nil {
		return "", err
	}
	guidStr := fmt.Sprintf("%x", guid)

	parts := strings.Split(path, "/")
	parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1])
	rootPath := strings.Join(parts[:len(parts)-1], "/")
	protectedPath := strings.Join(parts, "/")

	// A node directly under the root yields an empty rootPath; list "/" in
	// that case instead of handing an empty path to Children.
	parentPath := rootPath
	if parentPath == "" {
		parentPath = "/"
	}
	// Matching on prefix+GUID as a whole avoids slicing child names that are
	// shorter than the prefix plus the 32 hex digits of the GUID.
	marker := protectedPrefix + guidStr

	var (
		newPath string
		err     error
	)
	for i := 0; i < 3; i++ {
		newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl)
		switch err {
		case ErrSessionExpired:
			// No need to search for the node since it can't exist. Just try again.
		case ErrConnectionClosed:
			children, _, childErr := c.Children(parentPath)
			if childErr != nil {
				return "", childErr
			}
			for _, p := range children {
				segs := strings.Split(p, "/")
				if strings.HasPrefix(segs[len(segs)-1], marker) {
					return rootPath + "/" + p, nil
				}
			}
		case nil:
			return newPath, nil
		default:
			return "", err
		}
	}
	return "", err
}

// Delete issues an opDelete request for the node at path with the given
// version and returns the resulting error, if any.
func (c *Conn) Delete(path string, version int32) error {
	if err := validatePath(path, false); err != nil {
		return err
	}

	_, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil)
	return err
}

// Exists issues an opExists request (without a watch) for path.
//
// ErrNoNode is translated into exists == false with a nil error. Any other
// error is returned with exists == false and a nil Stat, matching ExistsW.
func (c *Conn) Exists(path string) (bool, *Stat, error) {
	if err := validatePath(path, false); err != nil {
		return false, nil, err
	}

	res := &existsResponse{}
	_, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
	switch err {
	case nil:
		return true, &res.Stat, nil
	case ErrNoNode:
		return false, &res.Stat, nil
	default:
		// Never report "exists" together with a failure.
		return false, nil, err
	}
}

// ExistsW issues an opExists request with the watch flag set and returns,
// in addition to the Exists results, a channel of watch events.
//
// A watchTypeData watcher is registered when the request succeeds and a
// watchTypeExist watcher when it fails with ErrNoNode. ErrNoNode itself is
// translated into exists == false with a nil error; any other error yields
// false and nil results.
func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
	if err := validatePath(path, false); err != nil {
		return false, nil, nil, err
	}

	var ech <-chan Event
	res := &existsResponse{}
	_, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
		if err == nil {
			ech = c.addWatcher(path, watchTypeData)
		} else if err == ErrNoNode {
			ech = c.addWatcher(path, watchTypeExist)
		}
	})
	exists := true
	if err == ErrNoNode {
		exists = false
		err = nil
	}
	if err != nil {
		return false, nil, nil, err
	}
	return exists, &res.Stat, ech, err
}
1175 1176func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { 1177 if err := validatePath(path, false); err != nil { 1178 return nil, nil, err 1179 } 1180 1181 res := &getAclResponse{} 1182 _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) 1183 if err == ErrConnectionClosed { 1184 return nil, nil, err 1185 } 1186 return res.Acl, &res.Stat, err 1187} 1188func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { 1189 if err := validatePath(path, false); err != nil { 1190 return nil, err 1191 } 1192 1193 res := &setAclResponse{} 1194 _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) 1195 if err == ErrConnectionClosed { 1196 return nil, err 1197 } 1198 return &res.Stat, err 1199} 1200 1201func (c *Conn) Sync(path string) (string, error) { 1202 if err := validatePath(path, false); err != nil { 1203 return "", err 1204 } 1205 1206 res := &syncResponse{} 1207 _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) 1208 if err == ErrConnectionClosed { 1209 return "", err 1210 } 1211 return res.Path, err 1212} 1213 1214type MultiResponse struct { 1215 Stat *Stat 1216 String string 1217 Error error 1218} 1219 1220// Multi executes multiple ZooKeeper operations or none of them. The provided 1221// ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or 1222// *CheckVersionRequest. 
1223func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { 1224 req := &multiRequest{ 1225 Ops: make([]multiRequestOp, 0, len(ops)), 1226 DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, 1227 } 1228 for _, op := range ops { 1229 var opCode int32 1230 switch op.(type) { 1231 case *CreateRequest: 1232 opCode = opCreate 1233 case *SetDataRequest: 1234 opCode = opSetData 1235 case *DeleteRequest: 1236 opCode = opDelete 1237 case *CheckVersionRequest: 1238 opCode = opCheck 1239 default: 1240 return nil, fmt.Errorf("unknown operation type %T", op) 1241 } 1242 req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op}) 1243 } 1244 res := &multiResponse{} 1245 _, err := c.request(opMulti, req, res, nil) 1246 if err == ErrConnectionClosed { 1247 return nil, err 1248 } 1249 mr := make([]MultiResponse, len(res.Ops)) 1250 for i, op := range res.Ops { 1251 mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()} 1252 } 1253 return mr, err 1254} 1255 1256// IncrementalReconfig is the zookeeper reconfiguration api that allows adding and removing servers 1257// by lists of members. For more info refer to the ZK documentation. 1258// 1259// An optional version allows for conditional reconfigurations, -1 ignores the condition. 1260// 1261// Returns the new configuration znode stat. 1262func (c *Conn) IncrementalReconfig(joining, leaving []string, version int64) (*Stat, error) { 1263 // TODO: validate the shape of the member string to give early feedback. 1264 request := &reconfigRequest{ 1265 JoiningServers: []byte(strings.Join(joining, ",")), 1266 LeavingServers: []byte(strings.Join(leaving, ",")), 1267 CurConfigId: version, 1268 } 1269 1270 return c.internalReconfig(request) 1271} 1272 1273// Reconfig is the non-incremental update functionality for Zookeeper where the list provided 1274// is the entire new member list. For more info refer to the ZK documentation. 
1275// 1276// An optional version allows for conditional reconfigurations, -1 ignores the condition. 1277// 1278// Returns the new configuration znode stat. 1279func (c *Conn) Reconfig(members []string, version int64) (*Stat, error) { 1280 request := &reconfigRequest{ 1281 NewMembers: []byte(strings.Join(members, ",")), 1282 CurConfigId: version, 1283 } 1284 1285 return c.internalReconfig(request) 1286} 1287 1288func (c *Conn) internalReconfig(request *reconfigRequest) (*Stat, error) { 1289 response := &reconfigReponse{} 1290 _, err := c.request(opReconfig, request, response, nil) 1291 return &response.Stat, err 1292} 1293 1294// Server returns the current or last-connected server name. 1295func (c *Conn) Server() string { 1296 c.serverMu.Lock() 1297 defer c.serverMu.Unlock() 1298 return c.server 1299} 1300 1301func resendZkAuth(ctx context.Context, c *Conn) error { 1302 shouldCancel := func() bool { 1303 select { 1304 case <-c.shouldQuit: 1305 return true 1306 case <-c.closeChan: 1307 return true 1308 default: 1309 return false 1310 } 1311 } 1312 1313 c.credsMu.Lock() 1314 defer c.credsMu.Unlock() 1315 1316 if c.logInfo { 1317 c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds)) 1318 } 1319 1320 for _, cred := range c.creds { 1321 // return early before attempting to send request. 
1322 if shouldCancel() { 1323 return nil 1324 } 1325 // do not use the public API for auth since it depends on the send/recv loops 1326 // that are waiting for this to return 1327 resChan, err := c.sendRequest( 1328 opSetAuth, 1329 &setAuthRequest{Type: 0, 1330 Scheme: cred.scheme, 1331 Auth: cred.auth, 1332 }, 1333 &setAuthResponse{}, 1334 nil, /* recvFunc*/ 1335 ) 1336 if err != nil { 1337 return fmt.Errorf("failed to send auth request: %v", err) 1338 } 1339 1340 var res response 1341 select { 1342 case res = <-resChan: 1343 case <-c.closeChan: 1344 c.logger.Printf("recv closed, cancel re-submitting credentials") 1345 return nil 1346 case <-c.shouldQuit: 1347 c.logger.Printf("should quit, cancel re-submitting credentials") 1348 return nil 1349 case <-ctx.Done(): 1350 return ctx.Err() 1351 } 1352 if res.err != nil { 1353 return fmt.Errorf("failed conneciton setAuth request: %v", res.err) 1354 } 1355 } 1356 1357 return nil 1358} 1359