1package memberlist 2 3import ( 4 "bufio" 5 "bytes" 6 "encoding/binary" 7 "fmt" 8 "io" 9 "net" 10 "time" 11 12 "github.com/armon/go-metrics" 13 "github.com/hashicorp/go-msgpack/codec" 14) 15 16// This is the minimum and maximum protocol version that we can 17// _understand_. We're allowed to speak at any version within this 18// range. This range is inclusive. 19const ( 20 ProtocolVersionMin uint8 = 1 21 22 // Version 3 added support for TCP pings but we kept the default 23 // protocol version at 2 to ease transition to this new feature. 24 // A memberlist speaking version 2 of the protocol will attempt 25 // to TCP ping another memberlist who understands version 3 or 26 // greater. 27 ProtocolVersion2Compatible = 2 28 29 ProtocolVersionMax = 3 30) 31 32// messageType is an integer ID of a type of message that can be received 33// on network channels from other members. 34type messageType uint8 35 36// The list of available message types. 37const ( 38 pingMsg messageType = iota 39 indirectPingMsg 40 ackRespMsg 41 suspectMsg 42 aliveMsg 43 deadMsg 44 pushPullMsg 45 compoundMsg 46 userMsg // User mesg, not handled by us 47 compressMsg 48 encryptMsg 49) 50 51// compressionType is used to specify the compression algorithm 52type compressionType uint8 53 54const ( 55 lzwAlgo compressionType = iota 56) 57 58const ( 59 MetaMaxSize = 512 // Maximum size for node meta data 60 compoundHeaderOverhead = 2 // Assumed header overhead 61 compoundOverhead = 2 // Assumed overhead per entry in compoundHeader 62 udpBufSize = 65536 63 udpRecvBuf = 2 * 1024 * 1024 64 udpSendBuf = 1400 65 userMsgOverhead = 1 66 blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process 67 maxPushStateBytes = 10 * 1024 * 1024 68) 69 70// ping request sent directly to node 71type ping struct { 72 SeqNo uint32 73 74 // Node is sent so the target can verify they are 75 // the intended recipient. This is to protect again an agent 76 // restart with a new name. 77 Node string 78} 79 80// indirect ping sent to an indirect ndoe 81type indirectPingReq struct { 82 SeqNo uint32 83 Target []byte 84 Port uint16 85 Node string 86} 87 88// ack response is sent for a ping 89type ackResp struct { 90 SeqNo uint32 91 Payload []byte 92} 93 94// suspect is broadcast when we suspect a node is dead 95type suspect struct { 96 Incarnation uint32 97 Node string 98 From string // Include who is suspecting 99} 100 101// alive is broadcast when we know a node is alive. 102// Overloaded for nodes joining 103type alive struct { 104 Incarnation uint32 105 Node string 106 Addr []byte 107 Port uint16 108 Meta []byte 109 110 // The versions of the protocol/delegate that are being spoken, order: 111 // pmin, pmax, pcur, dmin, dmax, dcur 112 Vsn []uint8 113} 114 115// dead is broadcast when we confirm a node is dead 116// Overloaded for nodes leaving 117type dead struct { 118 Incarnation uint32 119 Node string 120 From string // Include who is suspecting 121} 122 123// pushPullHeader is used to inform the 124// otherside how many states we are transfering 125type pushPullHeader struct { 126 Nodes int 127 UserStateLen int // Encodes the byte lengh of user state 128 Join bool // Is this a join request or a anti-entropy run 129} 130 131// userMsgHeader is used to encapsulate a userMsg 132type userMsgHeader struct { 133 UserMsgLen int // Encodes the byte lengh of user state 134} 135 136// pushNodeState is used for pushPullReq when we are 137// transfering out node states 138type pushNodeState struct { 139 Name string 140 Addr []byte 141 Port uint16 142 Meta []byte 143 Incarnation uint32 144 State nodeStateType 145 Vsn []uint8 // Protocol versions 146} 147 148// compress is used to wrap an underlying payload 149// using a specified compression algorithm 150type compress struct { 151 Algo compressionType 152 Buf []byte 153} 154 155// msgHandoff is used to transfer a message between goroutines 156type msgHandoff struct { 157 msgType messageType 158 buf []byte 159 from net.Addr 160} 161 162// encryptionVersion returns the encryption version to use 163func (m *Memberlist) encryptionVersion() encryptionVersion { 164 switch m.ProtocolVersion() { 165 case 1: 166 return 0 167 default: 168 return 1 169 } 170} 171 172// setUDPRecvBuf is used to resize the UDP receive window. The function 173// attempts to set the read buffer to `udpRecvBuf` but backs off until 174// the read buffer can be set. 175func setUDPRecvBuf(c *net.UDPConn) { 176 size := udpRecvBuf 177 for { 178 if err := c.SetReadBuffer(size); err == nil { 179 break 180 } 181 size = size / 2 182 } 183} 184 185// tcpListen listens for and handles incoming connections 186func (m *Memberlist) tcpListen() { 187 for { 188 conn, err := m.tcpListener.AcceptTCP() 189 if err != nil { 190 if m.shutdown { 191 break 192 } 193 m.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %s", err) 194 continue 195 } 196 go m.handleConn(conn) 197 } 198} 199 200// handleConn handles a single incoming TCP connection 201func (m *Memberlist) handleConn(conn *net.TCPConn) { 202 m.logger.Printf("[DEBUG] memberlist: TCP connection %s", LogConn(conn)) 203 204 defer conn.Close() 205 metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1) 206 207 conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) 208 msgType, bufConn, dec, err := m.readTCP(conn) 209 if err != nil { 210 m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn)) 211 return 212 } 213 214 switch msgType { 215 case userMsg: 216 if err := m.readUserMsg(bufConn, dec); err != nil { 217 m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn)) 218 } 219 case pushPullMsg: 220 join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec) 221 if err != nil { 222 m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn)) 223 return 224 } 225 226 if err := m.sendLocalState(conn, join); err != nil { 227 m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn)) 228 return 229 } 230 231 if err := m.mergeRemoteState(join, remoteNodes, userState); err != nil { 232 m.logger.Printf("[ERR] memberlist: Failed push/pull merge: %s %s", err, LogConn(conn)) 233 return 234 } 235 case pingMsg: 236 var p ping 237 if err := dec.Decode(&p); err != nil { 238 m.logger.Printf("[ERR] memberlist: Failed to decode TCP ping: %s %s", err, LogConn(conn)) 239 return 240 } 241 242 if p.Node != "" && p.Node != m.config.Name { 243 m.logger.Printf("[WARN] memberlist: Got ping for unexpected node %s %s", p.Node, LogConn(conn)) 244 return 245 } 246 247 ack := ackResp{p.SeqNo, nil} 248 out, err := encode(ackRespMsg, &ack) 249 if err != nil { 250 m.logger.Printf("[ERR] memberlist: Failed to encode TCP ack: %s", err) 251 return 252 } 253 254 err = m.rawSendMsgTCP(conn, out.Bytes()) 255 if err != nil { 256 m.logger.Printf("[ERR] memberlist: Failed to send TCP ack: %s %s", err, LogConn(conn)) 257 return 258 } 259 default: 260 m.logger.Printf("[ERR] memberlist: Received invalid msgType (%d) %s", msgType, LogConn(conn)) 261 } 262} 263 264// udpListen listens for and handles incoming UDP packets 265func (m *Memberlist) udpListen() { 266 var n int 267 var addr net.Addr 268 var err error 269 var lastPacket time.Time 270 for { 271 // Do a check for potentially blocking operations 272 if !lastPacket.IsZero() && time.Now().Sub(lastPacket) > blockingWarning { 273 diff := time.Now().Sub(lastPacket) 274 m.logger.Printf( 275 "[DEBUG] memberlist: Potential blocking operation. Last command took %v", 276 diff) 277 } 278 279 // Create a new buffer 280 // TODO: Use Sync.Pool eventually 281 buf := make([]byte, udpBufSize) 282 283 // Read a packet 284 n, addr, err = m.udpListener.ReadFrom(buf) 285 if err != nil { 286 if m.shutdown { 287 break 288 } 289 m.logger.Printf("[ERR] memberlist: Error reading UDP packet: %s", err) 290 continue 291 } 292 293 // Capture the reception time of the packet as close to the 294 // system calls as possible. 295 lastPacket = time.Now() 296 297 // Check the length 298 if n < 1 { 299 m.logger.Printf("[ERR] memberlist: UDP packet too short (%d bytes) %s", 300 len(buf), LogAddress(addr)) 301 continue 302 } 303 304 // Ingest this packet 305 metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n)) 306 m.ingestPacket(buf[:n], addr, lastPacket) 307 } 308} 309 310func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) { 311 // Check if encryption is enabled 312 if m.config.EncryptionEnabled() { 313 // Decrypt the payload 314 plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil) 315 if err != nil { 316 m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v %s", err, LogAddress(from)) 317 return 318 } 319 320 // Continue processing the plaintext buffer 321 buf = plain 322 } 323 324 // Handle the command 325 m.handleCommand(buf, from, timestamp) 326} 327 328func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) { 329 // Decode the message type 330 msgType := messageType(buf[0]) 331 buf = buf[1:] 332 333 // Switch on the msgType 334 switch msgType { 335 case compoundMsg: 336 m.handleCompound(buf, from, timestamp) 337 case compressMsg: 338 m.handleCompressed(buf, from, timestamp) 339 340 case pingMsg: 341 m.handlePing(buf, from) 342 case indirectPingMsg: 343 m.handleIndirectPing(buf, from) 344 case ackRespMsg: 345 m.handleAck(buf, from, timestamp) 346 347 case suspectMsg: 348 fallthrough 349 case aliveMsg: 350 fallthrough 351 case deadMsg: 352 fallthrough 353 case userMsg: 354 select { 355 case m.handoff <- msgHandoff{msgType, buf, from}: 356 default: 357 m.logger.Printf("[WARN] memberlist: UDP handler queue full, dropping message (%d) %s", msgType, LogAddress(from)) 358 } 359 360 default: 361 m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported %s", msgType, LogAddress(from)) 362 } 363} 364 365// udpHandler processes messages received over UDP, but is decoupled 366// from the listener to avoid blocking the listener which may cause 367// ping/ack messages to be delayed. 368func (m *Memberlist) udpHandler() { 369 for { 370 select { 371 case msg := <-m.handoff: 372 msgType := msg.msgType 373 buf := msg.buf 374 from := msg.from 375 376 switch msgType { 377 case suspectMsg: 378 m.handleSuspect(buf, from) 379 case aliveMsg: 380 m.handleAlive(buf, from) 381 case deadMsg: 382 m.handleDead(buf, from) 383 case userMsg: 384 m.handleUser(buf, from) 385 default: 386 m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported %s (handler)", msgType, LogAddress(from)) 387 } 388 389 case <-m.shutdownCh: 390 return 391 } 392 } 393} 394 395func (m *Memberlist) handleCompound(buf []byte, from net.Addr, timestamp time.Time) { 396 // Decode the parts 397 trunc, parts, err := decodeCompoundMessage(buf) 398 if err != nil { 399 m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s %s", err, LogAddress(from)) 400 return 401 } 402 403 // Log any truncation 404 if trunc > 0 { 405 m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages %s", trunc, LogAddress(from)) 406 } 407 408 // Handle each message 409 for _, part := range parts { 410 m.handleCommand(part, from, timestamp) 411 } 412} 413 414func (m *Memberlist) handlePing(buf []byte, from net.Addr) { 415 var p ping 416 if err := decode(buf, &p); err != nil { 417 m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s %s", err, LogAddress(from)) 418 return 419 } 420 // If node is provided, verify that it is for us 421 if p.Node != "" && p.Node != m.config.Name { 422 m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s' %s", p.Node, LogAddress(from)) 423 return 424 } 425 var ack ackResp 426 ack.SeqNo = p.SeqNo 427 if m.config.Ping != nil { 428 ack.Payload = m.config.Ping.AckPayload() 429 } 430 if err := m.encodeAndSendMsg(from, ackRespMsg, &ack); err != nil { 431 m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from)) 432 } 433} 434 435func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { 436 var ind indirectPingReq 437 if err := decode(buf, &ind); err != nil { 438 m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s %s", err, LogAddress(from)) 439 return 440 } 441 442 // For proto versions < 2, there is no port provided. Mask old 443 // behavior by using the configured port 444 if m.ProtocolVersion() < 2 || ind.Port == 0 { 445 ind.Port = uint16(m.config.BindPort) 446 } 447 448 // Send a ping to the correct host 449 localSeqNo := m.nextSeqNo() 450 ping := ping{SeqNo: localSeqNo, Node: ind.Node} 451 destAddr := &net.UDPAddr{IP: ind.Target, Port: int(ind.Port)} 452 453 // Setup a response handler to relay the ack 454 respHandler := func(payload []byte, timestamp time.Time) { 455 ack := ackResp{ind.SeqNo, nil} 456 if err := m.encodeAndSendMsg(from, ackRespMsg, &ack); err != nil { 457 m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogAddress(from)) 458 } 459 } 460 m.setAckHandler(localSeqNo, respHandler, m.config.ProbeTimeout) 461 462 // Send the ping 463 if err := m.encodeAndSendMsg(destAddr, pingMsg, &ping); err != nil { 464 m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from)) 465 } 466} 467 468func (m *Memberlist) handleAck(buf []byte, from net.Addr, timestamp time.Time) { 469 var ack ackResp 470 if err := decode(buf, &ack); err != nil { 471 m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s %s", err, LogAddress(from)) 472 return 473 } 474 m.invokeAckHandler(ack, timestamp) 475} 476 477func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) { 478 var sus suspect 479 if err := decode(buf, &sus); err != nil { 480 m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s %s", err, LogAddress(from)) 481 return 482 } 483 m.suspectNode(&sus) 484} 485 486func (m *Memberlist) handleAlive(buf []byte, from net.Addr) { 487 var live alive 488 if err := decode(buf, &live); err != nil { 489 m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from)) 490 return 491 } 492 493 // For proto versions < 2, there is no port provided. Mask old 494 // behavior by using the configured port 495 if m.ProtocolVersion() < 2 || live.Port == 0 { 496 live.Port = uint16(m.config.BindPort) 497 } 498 499 m.aliveNode(&live, nil, false) 500} 501 502func (m *Memberlist) handleDead(buf []byte, from net.Addr) { 503 var d dead 504 if err := decode(buf, &d); err != nil { 505 m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s %s", err, LogAddress(from)) 506 return 507 } 508 m.deadNode(&d) 509} 510 511// handleUser is used to notify channels of incoming user data 512func (m *Memberlist) handleUser(buf []byte, from net.Addr) { 513 d := m.config.Delegate 514 if d != nil { 515 d.NotifyMsg(buf) 516 } 517} 518 519// handleCompressed is used to unpack a compressed message 520func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.Time) { 521 // Try to decode the payload 522 payload, err := decompressPayload(buf) 523 if err != nil { 524 m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v %s", err, LogAddress(from)) 525 return 526 } 527 528 // Recursively handle the payload 529 m.handleCommand(payload, from, timestamp) 530} 531 532// encodeAndSendMsg is used to combine the encoding and sending steps 533func (m *Memberlist) encodeAndSendMsg(to net.Addr, msgType messageType, msg interface{}) error { 534 out, err := encode(msgType, msg) 535 if err != nil { 536 return err 537 } 538 if err := m.sendMsg(to, out.Bytes()); err != nil { 539 return err 540 } 541 return nil 542} 543 544// sendMsg is used to send a UDP message to another host. It will opportunistically 545// create a compoundMsg and piggy back other broadcasts 546func (m *Memberlist) sendMsg(to net.Addr, msg []byte) error { 547 // Check if we can piggy back any messages 548 bytesAvail := udpSendBuf - len(msg) - compoundHeaderOverhead 549 if m.config.EncryptionEnabled() { 550 bytesAvail -= encryptOverhead(m.encryptionVersion()) 551 } 552 extra := m.getBroadcasts(compoundOverhead, bytesAvail) 553 554 // Fast path if nothing to piggypack 555 if len(extra) == 0 { 556 return m.rawSendMsgUDP(to, msg) 557 } 558 559 // Join all the messages 560 msgs := make([][]byte, 0, 1+len(extra)) 561 msgs = append(msgs, msg) 562 msgs = append(msgs, extra...) 563 564 // Create a compound message 565 compound := makeCompoundMessage(msgs) 566 567 // Send the message 568 return m.rawSendMsgUDP(to, compound.Bytes()) 569} 570 571// rawSendMsgUDP is used to send a UDP message to another host without modification 572func (m *Memberlist) rawSendMsgUDP(to net.Addr, msg []byte) error { 573 // Check if we have compression enabled 574 if m.config.EnableCompression { 575 buf, err := compressPayload(msg) 576 if err != nil { 577 m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err) 578 } else { 579 // Only use compression if it reduced the size 580 if buf.Len() < len(msg) { 581 msg = buf.Bytes() 582 } 583 } 584 } 585 586 // Check if we have encryption enabled 587 if m.config.EncryptionEnabled() { 588 // Encrypt the payload 589 var buf bytes.Buffer 590 primaryKey := m.config.Keyring.GetPrimaryKey() 591 err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf) 592 if err != nil { 593 m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err) 594 return err 595 } 596 msg = buf.Bytes() 597 } 598 599 metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg))) 600 _, err := m.udpListener.WriteTo(msg, to) 601 return err 602} 603 604// rawSendMsgTCP is used to send a TCP message to another host without modification 605func (m *Memberlist) rawSendMsgTCP(conn net.Conn, sendBuf []byte) error { 606 // Check if compresion is enabled 607 if m.config.EnableCompression { 608 compBuf, err := compressPayload(sendBuf) 609 if err != nil { 610 m.logger.Printf("[ERROR] memberlist: Failed to compress payload: %v", err) 611 } else { 612 sendBuf = compBuf.Bytes() 613 } 614 } 615 616 // Check if encryption is enabled 617 if m.config.EncryptionEnabled() { 618 crypt, err := m.encryptLocalState(sendBuf) 619 if err != nil { 620 m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err) 621 return err 622 } 623 sendBuf = crypt 624 } 625 626 // Write out the entire send buffer 627 metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf))) 628 629 if n, err := conn.Write(sendBuf); err != nil { 630 return err 631 } else if n != len(sendBuf) { 632 return fmt.Errorf("only %d of %d bytes written", n, len(sendBuf)) 633 } 634 635 return nil 636} 637 638// sendTCPUserMsg is used to send a TCP userMsg to another host 639func (m *Memberlist) sendTCPUserMsg(to net.Addr, sendBuf []byte) error { 640 dialer := net.Dialer{Timeout: m.config.TCPTimeout} 641 conn, err := dialer.Dial("tcp", to.String()) 642 if err != nil { 643 return err 644 } 645 defer conn.Close() 646 647 bufConn := bytes.NewBuffer(nil) 648 649 if err := bufConn.WriteByte(byte(userMsg)); err != nil { 650 return err 651 } 652 653 // Send our node state 654 header := userMsgHeader{UserMsgLen: len(sendBuf)} 655 hd := codec.MsgpackHandle{} 656 enc := codec.NewEncoder(bufConn, &hd) 657 658 if err := enc.Encode(&header); err != nil { 659 return err 660 } 661 662 if _, err := bufConn.Write(sendBuf); err != nil { 663 return err 664 } 665 666 return m.rawSendMsgTCP(conn, bufConn.Bytes()) 667} 668 669// sendAndReceiveState is used to initiate a push/pull over TCP with a remote node 670func (m *Memberlist) sendAndReceiveState(addr []byte, port uint16, join bool) ([]pushNodeState, []byte, error) { 671 // Attempt to connect 672 dialer := net.Dialer{Timeout: m.config.TCPTimeout} 673 dest := net.TCPAddr{IP: addr, Port: int(port)} 674 conn, err := dialer.Dial("tcp", dest.String()) 675 if err != nil { 676 return nil, nil, err 677 } 678 defer conn.Close() 679 m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr()) 680 metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1) 681 682 // Send our state 683 if err := m.sendLocalState(conn, join); err != nil { 684 return nil, nil, err 685 } 686 687 conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) 688 msgType, bufConn, dec, err := m.readTCP(conn) 689 if err != nil { 690 return nil, nil, err 691 } 692 693 // Quit if not push/pull 694 if msgType != pushPullMsg { 695 err := fmt.Errorf("received invalid msgType (%d), expected pushPullMsg (%d) %s", msgType, pushPullMsg, LogConn(conn)) 696 return nil, nil, err 697 } 698 699 // Read remote state 700 _, remoteNodes, userState, err := m.readRemoteState(bufConn, dec) 701 return remoteNodes, userState, err 702} 703 704// sendLocalState is invoked to send our local state over a tcp connection 705func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { 706 // Setup a deadline 707 conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) 708 709 // Prepare the local node state 710 m.nodeLock.RLock() 711 localNodes := make([]pushNodeState, len(m.nodes)) 712 for idx, n := range m.nodes { 713 localNodes[idx].Name = n.Name 714 localNodes[idx].Addr = n.Addr 715 localNodes[idx].Port = n.Port 716 localNodes[idx].Incarnation = n.Incarnation 717 localNodes[idx].State = n.State 718 localNodes[idx].Meta = n.Meta 719 localNodes[idx].Vsn = []uint8{ 720 n.PMin, n.PMax, n.PCur, 721 n.DMin, n.DMax, n.DCur, 722 } 723 } 724 m.nodeLock.RUnlock() 725 726 // Get the delegate state 727 var userData []byte 728 if m.config.Delegate != nil { 729 userData = m.config.Delegate.LocalState(join) 730 } 731 732 // Create a bytes buffer writer 733 bufConn := bytes.NewBuffer(nil) 734 735 // Send our node state 736 header := pushPullHeader{Nodes: len(localNodes), UserStateLen: len(userData), Join: join} 737 hd := codec.MsgpackHandle{} 738 enc := codec.NewEncoder(bufConn, &hd) 739 740 // Begin state push 741 if _, err := bufConn.Write([]byte{byte(pushPullMsg)}); err != nil { 742 return err 743 } 744 745 if err := enc.Encode(&header); err != nil { 746 return err 747 } 748 for i := 0; i < header.Nodes; i++ { 749 if err := enc.Encode(&localNodes[i]); err != nil { 750 return err 751 } 752 } 753 754 // Write the user state as well 755 if userData != nil { 756 if _, err := bufConn.Write(userData); err != nil { 757 return err 758 } 759 } 760 761 // Get the send buffer 762 return m.rawSendMsgTCP(conn, bufConn.Bytes()) 763} 764 765// encryptLocalState is used to help encrypt local state before sending 766func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { 767 var buf bytes.Buffer 768 769 // Write the encryptMsg byte 770 buf.WriteByte(byte(encryptMsg)) 771 772 // Write the size of the message 773 sizeBuf := make([]byte, 4) 774 encVsn := m.encryptionVersion() 775 encLen := encryptedLength(encVsn, len(sendBuf)) 776 binary.BigEndian.PutUint32(sizeBuf, uint32(encLen)) 777 buf.Write(sizeBuf) 778 779 // Write the encrypted cipher text to the buffer 780 key := m.config.Keyring.GetPrimaryKey() 781 err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf) 782 if err != nil { 783 return nil, err 784 } 785 return buf.Bytes(), nil 786} 787 788// decryptRemoteState is used to help decrypt the remote state 789func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { 790 // Read in enough to determine message length 791 cipherText := bytes.NewBuffer(nil) 792 cipherText.WriteByte(byte(encryptMsg)) 793 _, err := io.CopyN(cipherText, bufConn, 4) 794 if err != nil { 795 return nil, err 796 } 797 798 // Ensure we aren't asked to download too much. This is to guard against 799 // an attack vector where a huge amount of state is sent 800 moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5]) 801 if moreBytes > maxPushStateBytes { 802 return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes) 803 } 804 805 // Read in the rest of the payload 806 _, err = io.CopyN(cipherText, bufConn, int64(moreBytes)) 807 if err != nil { 808 return nil, err 809 } 810 811 // Decrypt the cipherText 812 dataBytes := cipherText.Bytes()[:5] 813 cipherBytes := cipherText.Bytes()[5:] 814 815 // Decrypt the payload 816 keys := m.config.Keyring.GetKeys() 817 return decryptPayload(keys, cipherBytes, dataBytes) 818} 819 820// readTCP is used to read the start of a TCP stream. 821// it decrypts and decompresses the stream if necessary 822func (m *Memberlist) readTCP(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) { 823 // Created a buffered reader 824 var bufConn io.Reader = bufio.NewReader(conn) 825 826 // Read the message type 827 buf := [1]byte{0} 828 if _, err := bufConn.Read(buf[:]); err != nil { 829 return 0, nil, nil, err 830 } 831 msgType := messageType(buf[0]) 832 833 // Check if the message is encrypted 834 if msgType == encryptMsg { 835 if !m.config.EncryptionEnabled() { 836 return 0, nil, nil, 837 fmt.Errorf("Remote state is encrypted and encryption is not configured") 838 } 839 840 plain, err := m.decryptRemoteState(bufConn) 841 if err != nil { 842 return 0, nil, nil, err 843 } 844 845 // Reset message type and bufConn 846 msgType = messageType(plain[0]) 847 bufConn = bytes.NewReader(plain[1:]) 848 } else if m.config.EncryptionEnabled() { 849 return 0, nil, nil, 850 fmt.Errorf("Encryption is configured but remote state is not encrypted") 851 } 852 853 // Get the msgPack decoders 854 hd := codec.MsgpackHandle{} 855 dec := codec.NewDecoder(bufConn, &hd) 856 857 // Check if we have a compressed message 858 if msgType == compressMsg { 859 var c compress 860 if err := dec.Decode(&c); err != nil { 861 return 0, nil, nil, err 862 } 863 decomp, err := decompressBuffer(&c) 864 if err != nil { 865 return 0, nil, nil, err 866 } 867 868 // Reset the message type 869 msgType = messageType(decomp[0]) 870 871 // Create a new bufConn 872 bufConn = bytes.NewReader(decomp[1:]) 873 874 // Create a new decoder 875 dec = codec.NewDecoder(bufConn, &hd) 876 } 877 878 return msgType, bufConn, dec, nil 879} 880 881// readRemoteState is used to read the remote state from a connection 882func (m *Memberlist) readRemoteState(bufConn io.Reader, dec *codec.Decoder) (bool, []pushNodeState, []byte, error) { 883 // Read the push/pull header 884 var header pushPullHeader 885 if err := dec.Decode(&header); err != nil { 886 return false, nil, nil, err 887 } 888 889 // Allocate space for the transfer 890 remoteNodes := make([]pushNodeState, header.Nodes) 891 892 // Try to decode all the states 893 for i := 0; i < header.Nodes; i++ { 894 if err := dec.Decode(&remoteNodes[i]); err != nil { 895 return false, nil, nil, err 896 } 897 } 898 899 // Read the remote user state into a buffer 900 var userBuf []byte 901 if header.UserStateLen > 0 { 902 userBuf = make([]byte, header.UserStateLen) 903 bytes, err := io.ReadAtLeast(bufConn, userBuf, header.UserStateLen) 904 if err == nil && bytes != header.UserStateLen { 905 err = fmt.Errorf( 906 "Failed to read full user state (%d / %d)", 907 bytes, header.UserStateLen) 908 } 909 if err != nil { 910 return false, nil, nil, err 911 } 912 } 913 914 // For proto versions < 2, there is no port provided. Mask old 915 // behavior by using the configured port 916 for idx := range remoteNodes { 917 if m.ProtocolVersion() < 2 || remoteNodes[idx].Port == 0 { 918 remoteNodes[idx].Port = uint16(m.config.BindPort) 919 } 920 } 921 922 return header.Join, remoteNodes, userBuf, nil 923} 924 925// mergeRemoteState is used to merge the remote state with our local state 926func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, userBuf []byte) error { 927 if err := m.verifyProtocol(remoteNodes); err != nil { 928 return err 929 } 930 931 // Invoke the merge delegate if any 932 if join && m.config.Merge != nil { 933 nodes := make([]*Node, len(remoteNodes)) 934 for idx, n := range remoteNodes { 935 nodes[idx] = &Node{ 936 Name: n.Name, 937 Addr: n.Addr, 938 Port: n.Port, 939 Meta: n.Meta, 940 PMin: n.Vsn[0], 941 PMax: n.Vsn[1], 942 PCur: n.Vsn[2], 943 DMin: n.Vsn[3], 944 DMax: n.Vsn[4], 945 DCur: n.Vsn[5], 946 } 947 } 948 if err := m.config.Merge.NotifyMerge(nodes); err != nil { 949 return err 950 } 951 } 952 953 // Merge the membership state 954 m.mergeState(remoteNodes) 955 956 // Invoke the delegate for user state 957 if userBuf != nil && m.config.Delegate != nil { 958 m.config.Delegate.MergeRemoteState(userBuf, join) 959 } 960 return nil 961} 962 963// readUserMsg is used to decode a userMsg from a TCP stream 964func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error { 965 // Read the user message header 966 var header userMsgHeader 967 if err := dec.Decode(&header); err != nil { 968 return err 969 } 970 971 // Read the user message into a buffer 972 var userBuf []byte 973 if header.UserMsgLen > 0 { 974 userBuf = make([]byte, header.UserMsgLen) 975 bytes, err := io.ReadAtLeast(bufConn, userBuf, header.UserMsgLen) 976 if err == nil && bytes != header.UserMsgLen { 977 err = fmt.Errorf( 978 "Failed to read full user message (%d / %d)", 979 bytes, header.UserMsgLen) 980 } 981 if err != nil { 982 return err 983 } 984 985 d := m.config.Delegate 986 if d != nil { 987 d.NotifyMsg(userBuf) 988 } 989 } 990 991 return nil 992} 993 994// sendPingAndWaitForAck makes a TCP connection to the given address, sends 995// a ping, and waits for an ack. All of this is done as a series of blocking 996// operations, given the deadline. The bool return parameter is true if we 997// we able to round trip a ping to the other node. 998func (m *Memberlist) sendPingAndWaitForAck(destAddr net.Addr, ping ping, deadline time.Time) (bool, error) { 999 dialer := net.Dialer{Deadline: deadline} 1000 conn, err := dialer.Dial("tcp", destAddr.String()) 1001 if err != nil { 1002 // If the node is actually dead we expect this to fail, so we 1003 // shouldn't spam the logs with it. After this point, errors 1004 // with the connection are real, unexpected errors and should 1005 // get propagated up. 1006 return false, nil 1007 } 1008 defer conn.Close() 1009 conn.SetDeadline(deadline) 1010 1011 out, err := encode(pingMsg, &ping) 1012 if err != nil { 1013 return false, err 1014 } 1015 1016 if err = m.rawSendMsgTCP(conn, out.Bytes()); err != nil { 1017 return false, err 1018 } 1019 1020 msgType, _, dec, err := m.readTCP(conn) 1021 if err != nil { 1022 return false, err 1023 } 1024 1025 if msgType != ackRespMsg { 1026 return false, fmt.Errorf("Unexpected msgType (%d) from TCP ping %s", msgType, LogConn(conn)) 1027 } 1028 1029 var ack ackResp 1030 if err = dec.Decode(&ack); err != nil { 1031 return false, err 1032 } 1033 1034 if ack.SeqNo != ping.SeqNo { 1035 return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d) from TCP ping %s", ack.SeqNo, ping.SeqNo, LogConn(conn)) 1036 } 1037 1038 return true, nil 1039} 1040