1package memberlist
2
3import (
4	"bufio"
5	"bytes"
6	"encoding/binary"
7	"fmt"
8	"io"
9	"net"
10	"time"
11
12	"github.com/armon/go-metrics"
13	"github.com/hashicorp/go-msgpack/codec"
14)
15
// This is the minimum and maximum protocol version that we can
// _understand_. We're allowed to speak at any version within this
// range. This range is inclusive.
const (
	// ProtocolVersionMin is the lowest protocol version we understand.
	ProtocolVersionMin uint8 = 1

	// Version 3 added support for TCP pings but we kept the default
	// protocol version at 2 to ease transition to this new feature.
	// A memberlist speaking version 2 of the protocol will attempt
	// to TCP ping another memberlist who understands version 3 or
	// greater.
	ProtocolVersion2Compatible = 2

	// ProtocolVersionMax is the highest protocol version we understand.
	ProtocolVersionMax = 3
)
31
// messageType is an integer ID of a type of message that can be received
// on network channels from other members. It is carried as the first byte
// of a message (see handleCommand and readTCP).
type messageType uint8

// The list of available message types. The values are assigned by iota
// starting at zero, so the order of this list determines the numeric ID
// of each type.
const (
	pingMsg messageType = iota
	indirectPingMsg
	ackRespMsg
	suspectMsg
	aliveMsg
	deadMsg
	pushPullMsg
	compoundMsg
	userMsg // User message, not handled by us
	compressMsg
	encryptMsg
)
50
// compressionType is used to specify the compression algorithm
type compressionType uint8

// The list of available compression algorithms. lzwAlgo is the only one
// declared here and has the value zero.
const (
	lzwAlgo compressionType = iota
)
57
const (
	MetaMaxSize            = 512                   // Maximum size for node meta data
	compoundHeaderOverhead = 2                     // Assumed header overhead
	compoundOverhead       = 2                     // Assumed overhead per entry in compoundHeader
	udpBufSize             = 65536                 // Size in bytes of the buffer allocated for each UDP read
	udpRecvBuf             = 2 * 1024 * 1024       // Read buffer size that setUDPRecvBuf requests first
	udpSendBuf             = 1400                  // Byte budget sendMsg starts from when piggybacking broadcasts
	userMsgOverhead        = 1                     // NOTE(review): not referenced in this file; presumably per-user-message overhead - confirm
	blockingWarning        = 10 * time.Millisecond // Warn if a UDP packet takes this long to process
	maxPushStateBytes      = 10 * 1024 * 1024      // Largest encrypted TCP payload decryptRemoteState will read
)
69
// ping request sent directly to node
type ping struct {
	SeqNo uint32

	// Node is sent so the target can verify they are
	// the intended recipient. This is to protect against an agent
	// restart with a new name.
	Node string
}
79
// indirect ping sent to an indirect node. The receiver (see
// handleIndirectPing) pings Target:Port on behalf of the sender, passing
// Node along in that ping, and relays an ack carrying SeqNo back to the
// sender once the target answers.
type indirectPingReq struct {
	SeqNo  uint32
	Target []byte
	Port   uint16
	Node   string
}
87
// ack response is sent for a ping. SeqNo echoes the sequence number of the
// ping being answered; Payload is optional and is filled from the
// configured Ping delegate when answering UDP pings (see handlePing).
type ackResp struct {
	SeqNo   uint32
	Payload []byte
}
93
// suspect is broadcast when we suspect a node is dead.
type suspect struct {
	Incarnation uint32
	Node        string
	From        string // Include who is suspecting
}
100
// alive is broadcast when we know a node is alive.
// Overloaded for nodes joining.
type alive struct {
	Incarnation uint32
	Node        string
	Addr        []byte
	Port        uint16
	Meta        []byte

	// The versions of the protocol/delegate that are being spoken, order:
	// pmin, pmax, pcur, dmin, dmax, dcur
	Vsn []uint8
}
114
// dead is broadcast when we confirm a node is dead.
// Overloaded for nodes leaving.
type dead struct {
	Incarnation uint32
	Node        string
	From        string // Include who is suspecting
}
122
// pushPullHeader is used to inform the
// other side how many states we are transferring
type pushPullHeader struct {
	Nodes        int
	UserStateLen int  // Encodes the byte length of user state
	Join         bool // Is this a join request or an anti-entropy run
}
130
// userMsgHeader is used to encapsulate a userMsg
type userMsgHeader struct {
	UserMsgLen int // Encodes the byte length of the user message
}
135
// pushNodeState is used for pushPullReq when we are
// transferring out node states
type pushNodeState struct {
	Name        string
	Addr        []byte
	Port        uint16
	Meta        []byte
	Incarnation uint32
	State       nodeStateType
	Vsn         []uint8 // Protocol versions, order: pmin, pmax, pcur, dmin, dmax, dcur
}
147
// compress is used to wrap an underlying payload
// using a specified compression algorithm. Algo names the algorithm and
// Buf holds the wrapped payload.
type compress struct {
	Algo compressionType
	Buf  []byte
}
154
// msgHandoff is used to transfer a message between goroutines. It is what
// handleCommand queues on the handoff channel and udpHandler consumes:
// the message type, the message body with the type byte already removed,
// and the address the packet came from.
type msgHandoff struct {
	msgType messageType
	buf     []byte
	from    net.Addr
}
161
162// encryptionVersion returns the encryption version to use
163func (m *Memberlist) encryptionVersion() encryptionVersion {
164	switch m.ProtocolVersion() {
165	case 1:
166		return 0
167	default:
168		return 1
169	}
170}
171
172// setUDPRecvBuf is used to resize the UDP receive window. The function
173// attempts to set the read buffer to `udpRecvBuf` but backs off until
174// the read buffer can be set.
175func setUDPRecvBuf(c *net.UDPConn) {
176	size := udpRecvBuf
177	for {
178		if err := c.SetReadBuffer(size); err == nil {
179			break
180		}
181		size = size / 2
182	}
183}
184
185// tcpListen listens for and handles incoming connections
186func (m *Memberlist) tcpListen() {
187	for {
188		conn, err := m.tcpListener.AcceptTCP()
189		if err != nil {
190			if m.shutdown {
191				break
192			}
193			m.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %s", err)
194			continue
195		}
196		go m.handleConn(conn)
197	}
198}
199
200// handleConn handles a single incoming TCP connection
201func (m *Memberlist) handleConn(conn *net.TCPConn) {
202	m.logger.Printf("[DEBUG] memberlist: TCP connection %s", LogConn(conn))
203
204	defer conn.Close()
205	metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1)
206
207	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
208	msgType, bufConn, dec, err := m.readTCP(conn)
209	if err != nil {
210		m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn))
211		return
212	}
213
214	switch msgType {
215	case userMsg:
216		if err := m.readUserMsg(bufConn, dec); err != nil {
217			m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn))
218		}
219	case pushPullMsg:
220		join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
221		if err != nil {
222			m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn))
223			return
224		}
225
226		if err := m.sendLocalState(conn, join); err != nil {
227			m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn))
228			return
229		}
230
231		if err := m.mergeRemoteState(join, remoteNodes, userState); err != nil {
232			m.logger.Printf("[ERR] memberlist: Failed push/pull merge: %s %s", err, LogConn(conn))
233			return
234		}
235	case pingMsg:
236		var p ping
237		if err := dec.Decode(&p); err != nil {
238			m.logger.Printf("[ERR] memberlist: Failed to decode TCP ping: %s %s", err, LogConn(conn))
239			return
240		}
241
242		if p.Node != "" && p.Node != m.config.Name {
243			m.logger.Printf("[WARN] memberlist: Got ping for unexpected node %s %s", p.Node, LogConn(conn))
244			return
245		}
246
247		ack := ackResp{p.SeqNo, nil}
248		out, err := encode(ackRespMsg, &ack)
249		if err != nil {
250			m.logger.Printf("[ERR] memberlist: Failed to encode TCP ack: %s", err)
251			return
252		}
253
254		err = m.rawSendMsgTCP(conn, out.Bytes())
255		if err != nil {
256			m.logger.Printf("[ERR] memberlist: Failed to send TCP ack: %s %s", err, LogConn(conn))
257			return
258		}
259	default:
260		m.logger.Printf("[ERR] memberlist: Received invalid msgType (%d) %s", msgType, LogConn(conn))
261	}
262}
263
264// udpListen listens for and handles incoming UDP packets
265func (m *Memberlist) udpListen() {
266	var n int
267	var addr net.Addr
268	var err error
269	var lastPacket time.Time
270	for {
271		// Do a check for potentially blocking operations
272		if !lastPacket.IsZero() && time.Now().Sub(lastPacket) > blockingWarning {
273			diff := time.Now().Sub(lastPacket)
274			m.logger.Printf(
275				"[DEBUG] memberlist: Potential blocking operation. Last command took %v",
276				diff)
277		}
278
279		// Create a new buffer
280		// TODO: Use Sync.Pool eventually
281		buf := make([]byte, udpBufSize)
282
283		// Read a packet
284		n, addr, err = m.udpListener.ReadFrom(buf)
285		if err != nil {
286			if m.shutdown {
287				break
288			}
289			m.logger.Printf("[ERR] memberlist: Error reading UDP packet: %s", err)
290			continue
291		}
292
293		// Capture the reception time of the packet as close to the
294		// system calls as possible.
295		lastPacket = time.Now()
296
297		// Check the length
298		if n < 1 {
299			m.logger.Printf("[ERR] memberlist: UDP packet too short (%d bytes) %s",
300				len(buf), LogAddress(addr))
301			continue
302		}
303
304		// Ingest this packet
305		metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n))
306		m.ingestPacket(buf[:n], addr, lastPacket)
307	}
308}
309
310func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) {
311	// Check if encryption is enabled
312	if m.config.EncryptionEnabled() {
313		// Decrypt the payload
314		plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil)
315		if err != nil {
316			m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v %s", err, LogAddress(from))
317			return
318		}
319
320		// Continue processing the plaintext buffer
321		buf = plain
322	}
323
324	// Handle the command
325	m.handleCommand(buf, from, timestamp)
326}
327
328func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) {
329	// Decode the message type
330	msgType := messageType(buf[0])
331	buf = buf[1:]
332
333	// Switch on the msgType
334	switch msgType {
335	case compoundMsg:
336		m.handleCompound(buf, from, timestamp)
337	case compressMsg:
338		m.handleCompressed(buf, from, timestamp)
339
340	case pingMsg:
341		m.handlePing(buf, from)
342	case indirectPingMsg:
343		m.handleIndirectPing(buf, from)
344	case ackRespMsg:
345		m.handleAck(buf, from, timestamp)
346
347	case suspectMsg:
348		fallthrough
349	case aliveMsg:
350		fallthrough
351	case deadMsg:
352		fallthrough
353	case userMsg:
354		select {
355		case m.handoff <- msgHandoff{msgType, buf, from}:
356		default:
357			m.logger.Printf("[WARN] memberlist: UDP handler queue full, dropping message (%d) %s", msgType, LogAddress(from))
358		}
359
360	default:
361		m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported %s", msgType, LogAddress(from))
362	}
363}
364
365// udpHandler processes messages received over UDP, but is decoupled
366// from the listener to avoid blocking the listener which may cause
367// ping/ack messages to be delayed.
368func (m *Memberlist) udpHandler() {
369	for {
370		select {
371		case msg := <-m.handoff:
372			msgType := msg.msgType
373			buf := msg.buf
374			from := msg.from
375
376			switch msgType {
377			case suspectMsg:
378				m.handleSuspect(buf, from)
379			case aliveMsg:
380				m.handleAlive(buf, from)
381			case deadMsg:
382				m.handleDead(buf, from)
383			case userMsg:
384				m.handleUser(buf, from)
385			default:
386				m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported %s (handler)", msgType, LogAddress(from))
387			}
388
389		case <-m.shutdownCh:
390			return
391		}
392	}
393}
394
395func (m *Memberlist) handleCompound(buf []byte, from net.Addr, timestamp time.Time) {
396	// Decode the parts
397	trunc, parts, err := decodeCompoundMessage(buf)
398	if err != nil {
399		m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s %s", err, LogAddress(from))
400		return
401	}
402
403	// Log any truncation
404	if trunc > 0 {
405		m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages %s", trunc, LogAddress(from))
406	}
407
408	// Handle each message
409	for _, part := range parts {
410		m.handleCommand(part, from, timestamp)
411	}
412}
413
414func (m *Memberlist) handlePing(buf []byte, from net.Addr) {
415	var p ping
416	if err := decode(buf, &p); err != nil {
417		m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s %s", err, LogAddress(from))
418		return
419	}
420	// If node is provided, verify that it is for us
421	if p.Node != "" && p.Node != m.config.Name {
422		m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s' %s", p.Node, LogAddress(from))
423		return
424	}
425	var ack ackResp
426	ack.SeqNo = p.SeqNo
427	if m.config.Ping != nil {
428		ack.Payload = m.config.Ping.AckPayload()
429	}
430	if err := m.encodeAndSendMsg(from, ackRespMsg, &ack); err != nil {
431		m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from))
432	}
433}
434
435func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
436	var ind indirectPingReq
437	if err := decode(buf, &ind); err != nil {
438		m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s %s", err, LogAddress(from))
439		return
440	}
441
442	// For proto versions < 2, there is no port provided. Mask old
443	// behavior by using the configured port
444	if m.ProtocolVersion() < 2 || ind.Port == 0 {
445		ind.Port = uint16(m.config.BindPort)
446	}
447
448	// Send a ping to the correct host
449	localSeqNo := m.nextSeqNo()
450	ping := ping{SeqNo: localSeqNo, Node: ind.Node}
451	destAddr := &net.UDPAddr{IP: ind.Target, Port: int(ind.Port)}
452
453	// Setup a response handler to relay the ack
454	respHandler := func(payload []byte, timestamp time.Time) {
455		ack := ackResp{ind.SeqNo, nil}
456		if err := m.encodeAndSendMsg(from, ackRespMsg, &ack); err != nil {
457			m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogAddress(from))
458		}
459	}
460	m.setAckHandler(localSeqNo, respHandler, m.config.ProbeTimeout)
461
462	// Send the ping
463	if err := m.encodeAndSendMsg(destAddr, pingMsg, &ping); err != nil {
464		m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from))
465	}
466}
467
468func (m *Memberlist) handleAck(buf []byte, from net.Addr, timestamp time.Time) {
469	var ack ackResp
470	if err := decode(buf, &ack); err != nil {
471		m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s %s", err, LogAddress(from))
472		return
473	}
474	m.invokeAckHandler(ack, timestamp)
475}
476
477func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) {
478	var sus suspect
479	if err := decode(buf, &sus); err != nil {
480		m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s %s", err, LogAddress(from))
481		return
482	}
483	m.suspectNode(&sus)
484}
485
486func (m *Memberlist) handleAlive(buf []byte, from net.Addr) {
487	var live alive
488	if err := decode(buf, &live); err != nil {
489		m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from))
490		return
491	}
492
493	// For proto versions < 2, there is no port provided. Mask old
494	// behavior by using the configured port
495	if m.ProtocolVersion() < 2 || live.Port == 0 {
496		live.Port = uint16(m.config.BindPort)
497	}
498
499	m.aliveNode(&live, nil, false)
500}
501
502func (m *Memberlist) handleDead(buf []byte, from net.Addr) {
503	var d dead
504	if err := decode(buf, &d); err != nil {
505		m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s %s", err, LogAddress(from))
506		return
507	}
508	m.deadNode(&d)
509}
510
511// handleUser is used to notify channels of incoming user data
512func (m *Memberlist) handleUser(buf []byte, from net.Addr) {
513	d := m.config.Delegate
514	if d != nil {
515		d.NotifyMsg(buf)
516	}
517}
518
519// handleCompressed is used to unpack a compressed message
520func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.Time) {
521	// Try to decode the payload
522	payload, err := decompressPayload(buf)
523	if err != nil {
524		m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v %s", err, LogAddress(from))
525		return
526	}
527
528	// Recursively handle the payload
529	m.handleCommand(payload, from, timestamp)
530}
531
532// encodeAndSendMsg is used to combine the encoding and sending steps
533func (m *Memberlist) encodeAndSendMsg(to net.Addr, msgType messageType, msg interface{}) error {
534	out, err := encode(msgType, msg)
535	if err != nil {
536		return err
537	}
538	if err := m.sendMsg(to, out.Bytes()); err != nil {
539		return err
540	}
541	return nil
542}
543
544// sendMsg is used to send a UDP message to another host. It will opportunistically
545// create a compoundMsg and piggy back other broadcasts
546func (m *Memberlist) sendMsg(to net.Addr, msg []byte) error {
547	// Check if we can piggy back any messages
548	bytesAvail := udpSendBuf - len(msg) - compoundHeaderOverhead
549	if m.config.EncryptionEnabled() {
550		bytesAvail -= encryptOverhead(m.encryptionVersion())
551	}
552	extra := m.getBroadcasts(compoundOverhead, bytesAvail)
553
554	// Fast path if nothing to piggypack
555	if len(extra) == 0 {
556		return m.rawSendMsgUDP(to, msg)
557	}
558
559	// Join all the messages
560	msgs := make([][]byte, 0, 1+len(extra))
561	msgs = append(msgs, msg)
562	msgs = append(msgs, extra...)
563
564	// Create a compound message
565	compound := makeCompoundMessage(msgs)
566
567	// Send the message
568	return m.rawSendMsgUDP(to, compound.Bytes())
569}
570
571// rawSendMsgUDP is used to send a UDP message to another host without modification
572func (m *Memberlist) rawSendMsgUDP(to net.Addr, msg []byte) error {
573	// Check if we have compression enabled
574	if m.config.EnableCompression {
575		buf, err := compressPayload(msg)
576		if err != nil {
577			m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err)
578		} else {
579			// Only use compression if it reduced the size
580			if buf.Len() < len(msg) {
581				msg = buf.Bytes()
582			}
583		}
584	}
585
586	// Check if we have encryption enabled
587	if m.config.EncryptionEnabled() {
588		// Encrypt the payload
589		var buf bytes.Buffer
590		primaryKey := m.config.Keyring.GetPrimaryKey()
591		err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf)
592		if err != nil {
593			m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err)
594			return err
595		}
596		msg = buf.Bytes()
597	}
598
599	metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg)))
600	_, err := m.udpListener.WriteTo(msg, to)
601	return err
602}
603
604// rawSendMsgTCP is used to send a TCP message to another host without modification
605func (m *Memberlist) rawSendMsgTCP(conn net.Conn, sendBuf []byte) error {
606	// Check if compresion is enabled
607	if m.config.EnableCompression {
608		compBuf, err := compressPayload(sendBuf)
609		if err != nil {
610			m.logger.Printf("[ERROR] memberlist: Failed to compress payload: %v", err)
611		} else {
612			sendBuf = compBuf.Bytes()
613		}
614	}
615
616	// Check if encryption is enabled
617	if m.config.EncryptionEnabled() {
618		crypt, err := m.encryptLocalState(sendBuf)
619		if err != nil {
620			m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err)
621			return err
622		}
623		sendBuf = crypt
624	}
625
626	// Write out the entire send buffer
627	metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)))
628
629	if n, err := conn.Write(sendBuf); err != nil {
630		return err
631	} else if n != len(sendBuf) {
632		return fmt.Errorf("only %d of %d bytes written", n, len(sendBuf))
633	}
634
635	return nil
636}
637
638// sendTCPUserMsg is used to send a TCP userMsg to another host
639func (m *Memberlist) sendTCPUserMsg(to net.Addr, sendBuf []byte) error {
640	dialer := net.Dialer{Timeout: m.config.TCPTimeout}
641	conn, err := dialer.Dial("tcp", to.String())
642	if err != nil {
643		return err
644	}
645	defer conn.Close()
646
647	bufConn := bytes.NewBuffer(nil)
648
649	if err := bufConn.WriteByte(byte(userMsg)); err != nil {
650		return err
651	}
652
653	// Send our node state
654	header := userMsgHeader{UserMsgLen: len(sendBuf)}
655	hd := codec.MsgpackHandle{}
656	enc := codec.NewEncoder(bufConn, &hd)
657
658	if err := enc.Encode(&header); err != nil {
659		return err
660	}
661
662	if _, err := bufConn.Write(sendBuf); err != nil {
663		return err
664	}
665
666	return m.rawSendMsgTCP(conn, bufConn.Bytes())
667}
668
669// sendAndReceiveState is used to initiate a push/pull over TCP with a remote node
670func (m *Memberlist) sendAndReceiveState(addr []byte, port uint16, join bool) ([]pushNodeState, []byte, error) {
671	// Attempt to connect
672	dialer := net.Dialer{Timeout: m.config.TCPTimeout}
673	dest := net.TCPAddr{IP: addr, Port: int(port)}
674	conn, err := dialer.Dial("tcp", dest.String())
675	if err != nil {
676		return nil, nil, err
677	}
678	defer conn.Close()
679	m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr())
680	metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1)
681
682	// Send our state
683	if err := m.sendLocalState(conn, join); err != nil {
684		return nil, nil, err
685	}
686
687	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
688	msgType, bufConn, dec, err := m.readTCP(conn)
689	if err != nil {
690		return nil, nil, err
691	}
692
693	// Quit if not push/pull
694	if msgType != pushPullMsg {
695		err := fmt.Errorf("received invalid msgType (%d), expected pushPullMsg (%d) %s", msgType, pushPullMsg, LogConn(conn))
696		return nil, nil, err
697	}
698
699	// Read remote state
700	_, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
701	return remoteNodes, userState, err
702}
703
704// sendLocalState is invoked to send our local state over a tcp connection
705func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
706	// Setup a deadline
707	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
708
709	// Prepare the local node state
710	m.nodeLock.RLock()
711	localNodes := make([]pushNodeState, len(m.nodes))
712	for idx, n := range m.nodes {
713		localNodes[idx].Name = n.Name
714		localNodes[idx].Addr = n.Addr
715		localNodes[idx].Port = n.Port
716		localNodes[idx].Incarnation = n.Incarnation
717		localNodes[idx].State = n.State
718		localNodes[idx].Meta = n.Meta
719		localNodes[idx].Vsn = []uint8{
720			n.PMin, n.PMax, n.PCur,
721			n.DMin, n.DMax, n.DCur,
722		}
723	}
724	m.nodeLock.RUnlock()
725
726	// Get the delegate state
727	var userData []byte
728	if m.config.Delegate != nil {
729		userData = m.config.Delegate.LocalState(join)
730	}
731
732	// Create a bytes buffer writer
733	bufConn := bytes.NewBuffer(nil)
734
735	// Send our node state
736	header := pushPullHeader{Nodes: len(localNodes), UserStateLen: len(userData), Join: join}
737	hd := codec.MsgpackHandle{}
738	enc := codec.NewEncoder(bufConn, &hd)
739
740	// Begin state push
741	if _, err := bufConn.Write([]byte{byte(pushPullMsg)}); err != nil {
742		return err
743	}
744
745	if err := enc.Encode(&header); err != nil {
746		return err
747	}
748	for i := 0; i < header.Nodes; i++ {
749		if err := enc.Encode(&localNodes[i]); err != nil {
750			return err
751		}
752	}
753
754	// Write the user state as well
755	if userData != nil {
756		if _, err := bufConn.Write(userData); err != nil {
757			return err
758		}
759	}
760
761	// Get the send buffer
762	return m.rawSendMsgTCP(conn, bufConn.Bytes())
763}
764
765// encryptLocalState is used to help encrypt local state before sending
766func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
767	var buf bytes.Buffer
768
769	// Write the encryptMsg byte
770	buf.WriteByte(byte(encryptMsg))
771
772	// Write the size of the message
773	sizeBuf := make([]byte, 4)
774	encVsn := m.encryptionVersion()
775	encLen := encryptedLength(encVsn, len(sendBuf))
776	binary.BigEndian.PutUint32(sizeBuf, uint32(encLen))
777	buf.Write(sizeBuf)
778
779	// Write the encrypted cipher text to the buffer
780	key := m.config.Keyring.GetPrimaryKey()
781	err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf)
782	if err != nil {
783		return nil, err
784	}
785	return buf.Bytes(), nil
786}
787
788// decryptRemoteState is used to help decrypt the remote state
789func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
790	// Read in enough to determine message length
791	cipherText := bytes.NewBuffer(nil)
792	cipherText.WriteByte(byte(encryptMsg))
793	_, err := io.CopyN(cipherText, bufConn, 4)
794	if err != nil {
795		return nil, err
796	}
797
798	// Ensure we aren't asked to download too much. This is to guard against
799	// an attack vector where a huge amount of state is sent
800	moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5])
801	if moreBytes > maxPushStateBytes {
802		return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes)
803	}
804
805	// Read in the rest of the payload
806	_, err = io.CopyN(cipherText, bufConn, int64(moreBytes))
807	if err != nil {
808		return nil, err
809	}
810
811	// Decrypt the cipherText
812	dataBytes := cipherText.Bytes()[:5]
813	cipherBytes := cipherText.Bytes()[5:]
814
815	// Decrypt the payload
816	keys := m.config.Keyring.GetKeys()
817	return decryptPayload(keys, cipherBytes, dataBytes)
818}
819
820// readTCP is used to read the start of a TCP stream.
821// it decrypts and decompresses the stream if necessary
822func (m *Memberlist) readTCP(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) {
823	// Created a buffered reader
824	var bufConn io.Reader = bufio.NewReader(conn)
825
826	// Read the message type
827	buf := [1]byte{0}
828	if _, err := bufConn.Read(buf[:]); err != nil {
829		return 0, nil, nil, err
830	}
831	msgType := messageType(buf[0])
832
833	// Check if the message is encrypted
834	if msgType == encryptMsg {
835		if !m.config.EncryptionEnabled() {
836			return 0, nil, nil,
837				fmt.Errorf("Remote state is encrypted and encryption is not configured")
838		}
839
840		plain, err := m.decryptRemoteState(bufConn)
841		if err != nil {
842			return 0, nil, nil, err
843		}
844
845		// Reset message type and bufConn
846		msgType = messageType(plain[0])
847		bufConn = bytes.NewReader(plain[1:])
848	} else if m.config.EncryptionEnabled() {
849		return 0, nil, nil,
850			fmt.Errorf("Encryption is configured but remote state is not encrypted")
851	}
852
853	// Get the msgPack decoders
854	hd := codec.MsgpackHandle{}
855	dec := codec.NewDecoder(bufConn, &hd)
856
857	// Check if we have a compressed message
858	if msgType == compressMsg {
859		var c compress
860		if err := dec.Decode(&c); err != nil {
861			return 0, nil, nil, err
862		}
863		decomp, err := decompressBuffer(&c)
864		if err != nil {
865			return 0, nil, nil, err
866		}
867
868		// Reset the message type
869		msgType = messageType(decomp[0])
870
871		// Create a new bufConn
872		bufConn = bytes.NewReader(decomp[1:])
873
874		// Create a new decoder
875		dec = codec.NewDecoder(bufConn, &hd)
876	}
877
878	return msgType, bufConn, dec, nil
879}
880
881// readRemoteState is used to read the remote state from a connection
882func (m *Memberlist) readRemoteState(bufConn io.Reader, dec *codec.Decoder) (bool, []pushNodeState, []byte, error) {
883	// Read the push/pull header
884	var header pushPullHeader
885	if err := dec.Decode(&header); err != nil {
886		return false, nil, nil, err
887	}
888
889	// Allocate space for the transfer
890	remoteNodes := make([]pushNodeState, header.Nodes)
891
892	// Try to decode all the states
893	for i := 0; i < header.Nodes; i++ {
894		if err := dec.Decode(&remoteNodes[i]); err != nil {
895			return false, nil, nil, err
896		}
897	}
898
899	// Read the remote user state into a buffer
900	var userBuf []byte
901	if header.UserStateLen > 0 {
902		userBuf = make([]byte, header.UserStateLen)
903		bytes, err := io.ReadAtLeast(bufConn, userBuf, header.UserStateLen)
904		if err == nil && bytes != header.UserStateLen {
905			err = fmt.Errorf(
906				"Failed to read full user state (%d / %d)",
907				bytes, header.UserStateLen)
908		}
909		if err != nil {
910			return false, nil, nil, err
911		}
912	}
913
914	// For proto versions < 2, there is no port provided. Mask old
915	// behavior by using the configured port
916	for idx := range remoteNodes {
917		if m.ProtocolVersion() < 2 || remoteNodes[idx].Port == 0 {
918			remoteNodes[idx].Port = uint16(m.config.BindPort)
919		}
920	}
921
922	return header.Join, remoteNodes, userBuf, nil
923}
924
925// mergeRemoteState is used to merge the remote state with our local state
926func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, userBuf []byte) error {
927	if err := m.verifyProtocol(remoteNodes); err != nil {
928		return err
929	}
930
931	// Invoke the merge delegate if any
932	if join && m.config.Merge != nil {
933		nodes := make([]*Node, len(remoteNodes))
934		for idx, n := range remoteNodes {
935			nodes[idx] = &Node{
936				Name: n.Name,
937				Addr: n.Addr,
938				Port: n.Port,
939				Meta: n.Meta,
940				PMin: n.Vsn[0],
941				PMax: n.Vsn[1],
942				PCur: n.Vsn[2],
943				DMin: n.Vsn[3],
944				DMax: n.Vsn[4],
945				DCur: n.Vsn[5],
946			}
947		}
948		if err := m.config.Merge.NotifyMerge(nodes); err != nil {
949			return err
950		}
951	}
952
953	// Merge the membership state
954	m.mergeState(remoteNodes)
955
956	// Invoke the delegate for user state
957	if userBuf != nil && m.config.Delegate != nil {
958		m.config.Delegate.MergeRemoteState(userBuf, join)
959	}
960	return nil
961}
962
963// readUserMsg is used to decode a userMsg from a TCP stream
964func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error {
965	// Read the user message header
966	var header userMsgHeader
967	if err := dec.Decode(&header); err != nil {
968		return err
969	}
970
971	// Read the user message into a buffer
972	var userBuf []byte
973	if header.UserMsgLen > 0 {
974		userBuf = make([]byte, header.UserMsgLen)
975		bytes, err := io.ReadAtLeast(bufConn, userBuf, header.UserMsgLen)
976		if err == nil && bytes != header.UserMsgLen {
977			err = fmt.Errorf(
978				"Failed to read full user message (%d / %d)",
979				bytes, header.UserMsgLen)
980		}
981		if err != nil {
982			return err
983		}
984
985		d := m.config.Delegate
986		if d != nil {
987			d.NotifyMsg(userBuf)
988		}
989	}
990
991	return nil
992}
993
994// sendPingAndWaitForAck makes a TCP connection to the given address, sends
995// a ping, and waits for an ack. All of this is done as a series of blocking
996// operations, given the deadline. The bool return parameter is true if we
997// we able to round trip a ping to the other node.
998func (m *Memberlist) sendPingAndWaitForAck(destAddr net.Addr, ping ping, deadline time.Time) (bool, error) {
999	dialer := net.Dialer{Deadline: deadline}
1000	conn, err := dialer.Dial("tcp", destAddr.String())
1001	if err != nil {
1002		// If the node is actually dead we expect this to fail, so we
1003		// shouldn't spam the logs with it. After this point, errors
1004		// with the connection are real, unexpected errors and should
1005		// get propagated up.
1006		return false, nil
1007	}
1008	defer conn.Close()
1009	conn.SetDeadline(deadline)
1010
1011	out, err := encode(pingMsg, &ping)
1012	if err != nil {
1013		return false, err
1014	}
1015
1016	if err = m.rawSendMsgTCP(conn, out.Bytes()); err != nil {
1017		return false, err
1018	}
1019
1020	msgType, _, dec, err := m.readTCP(conn)
1021	if err != nil {
1022		return false, err
1023	}
1024
1025	if msgType != ackRespMsg {
1026		return false, fmt.Errorf("Unexpected msgType (%d) from TCP ping %s", msgType, LogConn(conn))
1027	}
1028
1029	var ack ackResp
1030	if err = dec.Decode(&ack); err != nil {
1031		return false, err
1032	}
1033
1034	if ack.SeqNo != ping.SeqNo {
1035		return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d) from TCP ping %s", ack.SeqNo, ping.SeqNo, LogConn(conn))
1036	}
1037
1038	return true, nil
1039}
1040