1package memberlist
2
3import (
4	"bufio"
5	"bytes"
6	"encoding/binary"
7	"fmt"
8	"hash/crc32"
9	"io"
10	"net"
11	"sync/atomic"
12	"time"
13
14	metrics "github.com/armon/go-metrics"
15	"github.com/hashicorp/go-msgpack/codec"
16)
17
// ProtocolVersionMin and ProtocolVersionMax are the minimum and maximum
// protocol versions that we can _understand_. We're allowed to speak at any
// version within this range. This range is inclusive.
const (
	ProtocolVersionMin uint8 = 1

	// Version 3 added support for TCP pings but we kept the default
	// protocol version at 2 to ease transition to this new feature.
	// A memberlist speaking version 2 of the protocol will attempt
	// to TCP ping another memberlist who understands version 3 or
	// greater.
	//
	// Version 4 added support for nacks as part of indirect probes.
	// A memberlist speaking version 2 of the protocol will expect
	// nacks from another memberlist who understands version 4 or
	// greater, and likewise nacks will be sent to memberlists who
	// understand version 4 or greater.
	ProtocolVersion2Compatible = 2

	// Packets sent to a peer whose maximum understood version is 5 or
	// greater are prefixed with a CRC32 checksum (see rawSendMsgPacket).
	ProtocolVersionMax = 5
)
39
// messageType is an integer ID of a type of message that can be received
// on network channels from other members. It is carried as the leading byte
// of a packet or stream message.
type messageType uint8
43
// The list of available message types. The numeric values are written to
// and read from the wire as a single byte, so the order of this list is
// significant.
const (
	pingMsg messageType = iota
	indirectPingMsg
	ackRespMsg
	suspectMsg
	aliveMsg
	deadMsg
	pushPullMsg
	compoundMsg
	userMsg // User message, not handled by us
	compressMsg
	encryptMsg
	nackRespMsg
	hasCrcMsg
	errMsg
)
61
// compressionType is used to specify the compression algorithm
// applied to a wrapped payload (see the compress struct).
type compressionType uint8

// The list of available compression algorithms.
const (
	lzwAlgo compressionType = iota
)
68
// Size limits and overhead estimates used by the network layer.
const (
	MetaMaxSize            = 512 // Maximum size for node meta data
	compoundHeaderOverhead = 2   // Assumed header overhead
	compoundOverhead       = 2   // Assumed overhead per entry in compoundHeader
	userMsgOverhead        = 1
	blockingWarning        = 10 * time.Millisecond // Warn if a UDP packet takes this long to process
	maxPushStateBytes      = 20 * 1024 * 1024      // Largest encrypted stream payload accepted by decryptRemoteState
	maxPushPullRequests    = 128                   // Maximum number of concurrent push/pull requests
)
78
// ping is a ping request sent directly to a node.
type ping struct {
	SeqNo uint32

	// Node is sent so the target can verify they are
	// the intended recipient. This is to protect against an agent
	// restart with a new name.
	Node string

	SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply
	SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply
	SourceNode string `codec:",omitempty"` // Source name, used for a direct reply
}
92
// indirectPingReq is an indirect ping request sent to an intermediate node,
// asking it to ping Target on the sender's behalf.
type indirectPingReq struct {
	SeqNo  uint32
	Target []byte // Address of the node to ping
	Port   uint16 // Port of the node to ping; the receiver substitutes its configured bind port when this is 0

	// Node is sent so the target can verify they are
	// the intended recipient. This is to protect against an agent
	// restart with a new name.
	Node string

	Nack bool // true if we'd like a nack back

	SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply
	SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply
	SourceNode string `codec:",omitempty"` // Source name, used for a direct reply
}
110
// ackResp is the ack response sent for a ping. SeqNo echoes the sequence
// number of the ping being acknowledged.
type ackResp struct {
	SeqNo   uint32
	Payload []byte // Optional payload supplied by the configured Ping delegate
}
116
// nackResp is the nack response sent for an indirect ping when the pinger
// doesn't hear from the ping-ee within the configured timeout. This lets the
// original node know that the indirect ping attempt happened but didn't
// succeed.
type nackResp struct {
	SeqNo uint32
}
123
// errResp is the error response sent over a stream to relay an error from
// the remote end back to the initiator.
type errResp struct {
	Error string
}
128
// suspect is broadcast when we suspect a node is dead.
type suspect struct {
	Incarnation uint32
	Node        string
	From        string // Include who is suspecting
}
135
// alive is broadcast when we know a node is alive.
// It is also used (overloaded) for nodes joining.
type alive struct {
	Incarnation uint32
	Node        string
	Addr        []byte
	Port        uint16
	Meta        []byte

	// The versions of the protocol/delegate that are being spoken, order:
	// pmin, pmax, pcur, dmin, dmax, dcur
	Vsn []uint8
}
149
// dead is broadcast when we confirm a node is dead.
// It is also used (overloaded) for nodes leaving.
type dead struct {
	Incarnation uint32
	Node        string
	From        string // Include who is suspecting
}
157
// pushPullHeader is used to inform the other side how many node states we
// are transferring and how much user state follows them.
type pushPullHeader struct {
	Nodes        int
	UserStateLen int  // Encodes the byte length of user state
	Join         bool // Is this a join request or an anti-entropy run
}
165
// userMsgHeader is used to encapsulate a userMsg sent over a stream; it
// precedes the raw message bytes.
type userMsgHeader struct {
	UserMsgLen int // Encodes the byte length of the user message
}
170
// pushNodeState is used for pushPullReq when we are
// transferring out node states. One entry is sent per known node.
type pushNodeState struct {
	Name        string
	Addr        []byte
	Port        uint16
	Meta        []byte
	Incarnation uint32
	State       NodeStateType
	Vsn         []uint8 // Protocol versions, order: pmin, pmax, pcur, dmin, dmax, dcur
}
182
// compress is used to wrap an underlying payload
// using a specified compression algorithm.
type compress struct {
	Algo compressionType // Algorithm used to produce Buf
	Buf  []byte          // Compressed payload
}
189
// msgHandoff is used to transfer a message between goroutines: the packet
// listener queues these and packetHandler consumes them.
type msgHandoff struct {
	msgType messageType
	buf     []byte   // Message body with the leading type byte already removed
	from    net.Addr // Address the packet was received from
}
196
197// encryptionVersion returns the encryption version to use
198func (m *Memberlist) encryptionVersion() encryptionVersion {
199	switch m.ProtocolVersion() {
200	case 1:
201		return 0
202	default:
203		return 1
204	}
205}
206
207// streamListen is a long running goroutine that pulls incoming streams from the
208// transport and hands them off for processing.
209func (m *Memberlist) streamListen() {
210	for {
211		select {
212		case conn := <-m.transport.StreamCh():
213			go m.handleConn(conn)
214
215		case <-m.shutdownCh:
216			return
217		}
218	}
219}
220
221// handleConn handles a single incoming stream connection from the transport.
222func (m *Memberlist) handleConn(conn net.Conn) {
223	defer conn.Close()
224	m.logger.Printf("[DEBUG] memberlist: Stream connection %s", LogConn(conn))
225
226	metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1)
227
228	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
229	msgType, bufConn, dec, err := m.readStream(conn)
230	if err != nil {
231		if err != io.EOF {
232			m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn))
233
234			resp := errResp{err.Error()}
235			out, err := encode(errMsg, &resp)
236			if err != nil {
237				m.logger.Printf("[ERR] memberlist: Failed to encode error response: %s", err)
238				return
239			}
240
241			err = m.rawSendMsgStream(conn, out.Bytes())
242			if err != nil {
243				m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn))
244				return
245			}
246		}
247		return
248	}
249
250	switch msgType {
251	case userMsg:
252		if err := m.readUserMsg(bufConn, dec); err != nil {
253			m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn))
254		}
255	case pushPullMsg:
256		// Increment counter of pending push/pulls
257		numConcurrent := atomic.AddUint32(&m.pushPullReq, 1)
258		defer atomic.AddUint32(&m.pushPullReq, ^uint32(0))
259
260		// Check if we have too many open push/pull requests
261		if numConcurrent >= maxPushPullRequests {
262			m.logger.Printf("[ERR] memberlist: Too many pending push/pull requests")
263			return
264		}
265
266		join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
267		if err != nil {
268			m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn))
269			return
270		}
271
272		if err := m.sendLocalState(conn, join); err != nil {
273			m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn))
274			return
275		}
276
277		if err := m.mergeRemoteState(join, remoteNodes, userState); err != nil {
278			m.logger.Printf("[ERR] memberlist: Failed push/pull merge: %s %s", err, LogConn(conn))
279			return
280		}
281	case pingMsg:
282		var p ping
283		if err := dec.Decode(&p); err != nil {
284			m.logger.Printf("[ERR] memberlist: Failed to decode ping: %s %s", err, LogConn(conn))
285			return
286		}
287
288		if p.Node != "" && p.Node != m.config.Name {
289			m.logger.Printf("[WARN] memberlist: Got ping for unexpected node %s %s", p.Node, LogConn(conn))
290			return
291		}
292
293		ack := ackResp{p.SeqNo, nil}
294		out, err := encode(ackRespMsg, &ack)
295		if err != nil {
296			m.logger.Printf("[ERR] memberlist: Failed to encode ack: %s", err)
297			return
298		}
299
300		err = m.rawSendMsgStream(conn, out.Bytes())
301		if err != nil {
302			m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn))
303			return
304		}
305	default:
306		m.logger.Printf("[ERR] memberlist: Received invalid msgType (%d) %s", msgType, LogConn(conn))
307	}
308}
309
310// packetListen is a long running goroutine that pulls packets out of the
311// transport and hands them off for processing.
312func (m *Memberlist) packetListen() {
313	for {
314		select {
315		case packet := <-m.transport.PacketCh():
316			m.ingestPacket(packet.Buf, packet.From, packet.Timestamp)
317
318		case <-m.shutdownCh:
319			return
320		}
321	}
322}
323
324func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) {
325	// Check if encryption is enabled
326	if m.config.EncryptionEnabled() {
327		// Decrypt the payload
328		plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil)
329		if err != nil {
330			if !m.config.GossipVerifyIncoming {
331				// Treat the message as plaintext
332				plain = buf
333			} else {
334				m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v %s", err, LogAddress(from))
335				return
336			}
337		}
338
339		// Continue processing the plaintext buffer
340		buf = plain
341	}
342
343	// See if there's a checksum included to verify the contents of the message
344	if len(buf) >= 5 && messageType(buf[0]) == hasCrcMsg {
345		crc := crc32.ChecksumIEEE(buf[5:])
346		expected := binary.BigEndian.Uint32(buf[1:5])
347		if crc != expected {
348			m.logger.Printf("[WARN] memberlist: Got invalid checksum for UDP packet: %x, %x", crc, expected)
349			return
350		}
351		m.handleCommand(buf[5:], from, timestamp)
352	} else {
353		m.handleCommand(buf, from, timestamp)
354	}
355}
356
357func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) {
358	// Decode the message type
359	msgType := messageType(buf[0])
360	buf = buf[1:]
361
362	// Switch on the msgType
363	switch msgType {
364	case compoundMsg:
365		m.handleCompound(buf, from, timestamp)
366	case compressMsg:
367		m.handleCompressed(buf, from, timestamp)
368
369	case pingMsg:
370		m.handlePing(buf, from)
371	case indirectPingMsg:
372		m.handleIndirectPing(buf, from)
373	case ackRespMsg:
374		m.handleAck(buf, from, timestamp)
375	case nackRespMsg:
376		m.handleNack(buf, from)
377
378	case suspectMsg:
379		fallthrough
380	case aliveMsg:
381		fallthrough
382	case deadMsg:
383		fallthrough
384	case userMsg:
385		// Determine the message queue, prioritize alive
386		queue := m.lowPriorityMsgQueue
387		if msgType == aliveMsg {
388			queue = m.highPriorityMsgQueue
389		}
390
391		// Check for overflow and append if not full
392		m.msgQueueLock.Lock()
393		if queue.Len() >= m.config.HandoffQueueDepth {
394			m.logger.Printf("[WARN] memberlist: handler queue full, dropping message (%d) %s", msgType, LogAddress(from))
395		} else {
396			queue.PushBack(msgHandoff{msgType, buf, from})
397		}
398		m.msgQueueLock.Unlock()
399
400		// Notify of pending message
401		select {
402		case m.handoffCh <- struct{}{}:
403		default:
404		}
405
406	default:
407		m.logger.Printf("[ERR] memberlist: msg type (%d) not supported %s", msgType, LogAddress(from))
408	}
409}
410
411// getNextMessage returns the next message to process in priority order, using LIFO
412func (m *Memberlist) getNextMessage() (msgHandoff, bool) {
413	m.msgQueueLock.Lock()
414	defer m.msgQueueLock.Unlock()
415
416	if el := m.highPriorityMsgQueue.Back(); el != nil {
417		m.highPriorityMsgQueue.Remove(el)
418		msg := el.Value.(msgHandoff)
419		return msg, true
420	} else if el := m.lowPriorityMsgQueue.Back(); el != nil {
421		m.lowPriorityMsgQueue.Remove(el)
422		msg := el.Value.(msgHandoff)
423		return msg, true
424	}
425	return msgHandoff{}, false
426}
427
428// packetHandler is a long running goroutine that processes messages received
429// over the packet interface, but is decoupled from the listener to avoid
430// blocking the listener which may cause ping/ack messages to be delayed.
431func (m *Memberlist) packetHandler() {
432	for {
433		select {
434		case <-m.handoffCh:
435			for {
436				msg, ok := m.getNextMessage()
437				if !ok {
438					break
439				}
440				msgType := msg.msgType
441				buf := msg.buf
442				from := msg.from
443
444				switch msgType {
445				case suspectMsg:
446					m.handleSuspect(buf, from)
447				case aliveMsg:
448					m.handleAlive(buf, from)
449				case deadMsg:
450					m.handleDead(buf, from)
451				case userMsg:
452					m.handleUser(buf, from)
453				default:
454					m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msgType, LogAddress(from))
455				}
456			}
457
458		case <-m.shutdownCh:
459			return
460		}
461	}
462}
463
464func (m *Memberlist) handleCompound(buf []byte, from net.Addr, timestamp time.Time) {
465	// Decode the parts
466	trunc, parts, err := decodeCompoundMessage(buf)
467	if err != nil {
468		m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s %s", err, LogAddress(from))
469		return
470	}
471
472	// Log any truncation
473	if trunc > 0 {
474		m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages %s", trunc, LogAddress(from))
475	}
476
477	// Handle each message
478	for _, part := range parts {
479		m.handleCommand(part, from, timestamp)
480	}
481}
482
483func (m *Memberlist) handlePing(buf []byte, from net.Addr) {
484	var p ping
485	if err := decode(buf, &p); err != nil {
486		m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s %s", err, LogAddress(from))
487		return
488	}
489	// If node is provided, verify that it is for us
490	if p.Node != "" && p.Node != m.config.Name {
491		m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s' %s", p.Node, LogAddress(from))
492		return
493	}
494	var ack ackResp
495	ack.SeqNo = p.SeqNo
496	if m.config.Ping != nil {
497		ack.Payload = m.config.Ping.AckPayload()
498	}
499
500	addr := ""
501	if len(p.SourceAddr) > 0 && p.SourcePort > 0 {
502		addr = joinHostPort(net.IP(p.SourceAddr).String(), p.SourcePort)
503	} else {
504		addr = from.String()
505	}
506
507	a := Address{
508		Addr: addr,
509		Name: p.SourceNode,
510	}
511	if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil {
512		m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from))
513	}
514}
515
516func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
517	var ind indirectPingReq
518	if err := decode(buf, &ind); err != nil {
519		m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s %s", err, LogAddress(from))
520		return
521	}
522
523	// For proto versions < 2, there is no port provided. Mask old
524	// behavior by using the configured port.
525	if m.ProtocolVersion() < 2 || ind.Port == 0 {
526		ind.Port = uint16(m.config.BindPort)
527	}
528
529	// Send a ping to the correct host.
530	localSeqNo := m.nextSeqNo()
531	selfAddr, selfPort := m.getAdvertise()
532	ping := ping{
533		SeqNo: localSeqNo,
534		Node:  ind.Node,
535		// The outbound message is addressed FROM us.
536		SourceAddr: selfAddr,
537		SourcePort: selfPort,
538		SourceNode: m.config.Name,
539	}
540
541	// Forward the ack back to the requestor. If the request encodes an origin
542	// use that otherwise assume that the other end of the UDP socket is
543	// usable.
544	indAddr := ""
545	if len(ind.SourceAddr) > 0 && ind.SourcePort > 0 {
546		indAddr = joinHostPort(net.IP(ind.SourceAddr).String(), ind.SourcePort)
547	} else {
548		indAddr = from.String()
549	}
550
551	// Setup a response handler to relay the ack
552	cancelCh := make(chan struct{})
553	respHandler := func(payload []byte, timestamp time.Time) {
554		// Try to prevent the nack if we've caught it in time.
555		close(cancelCh)
556
557		ack := ackResp{ind.SeqNo, nil}
558		a := Address{
559			Addr: indAddr,
560			Name: ind.SourceNode,
561		}
562		if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil {
563			m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogStringAddress(indAddr))
564		}
565	}
566	m.setAckHandler(localSeqNo, respHandler, m.config.ProbeTimeout)
567
568	// Send the ping.
569	addr := joinHostPort(net.IP(ind.Target).String(), ind.Port)
570	a := Address{
571		Addr: addr,
572		Name: ind.Node,
573	}
574	if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil {
575		m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s %s", err, LogStringAddress(indAddr))
576	}
577
578	// Setup a timer to fire off a nack if no ack is seen in time.
579	if ind.Nack {
580		go func() {
581			select {
582			case <-cancelCh:
583				return
584			case <-time.After(m.config.ProbeTimeout):
585				nack := nackResp{ind.SeqNo}
586				a := Address{
587					Addr: indAddr,
588					Name: ind.SourceNode,
589				}
590				if err := m.encodeAndSendMsg(a, nackRespMsg, &nack); err != nil {
591					m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogStringAddress(indAddr))
592				}
593			}
594		}()
595	}
596}
597
598func (m *Memberlist) handleAck(buf []byte, from net.Addr, timestamp time.Time) {
599	var ack ackResp
600	if err := decode(buf, &ack); err != nil {
601		m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s %s", err, LogAddress(from))
602		return
603	}
604	m.invokeAckHandler(ack, timestamp)
605}
606
607func (m *Memberlist) handleNack(buf []byte, from net.Addr) {
608	var nack nackResp
609	if err := decode(buf, &nack); err != nil {
610		m.logger.Printf("[ERR] memberlist: Failed to decode nack response: %s %s", err, LogAddress(from))
611		return
612	}
613	m.invokeNackHandler(nack)
614}
615
616func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) {
617	var sus suspect
618	if err := decode(buf, &sus); err != nil {
619		m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s %s", err, LogAddress(from))
620		return
621	}
622	m.suspectNode(&sus)
623}
624
625// ensureCanConnect return the IP from a RemoteAddress
626// return error if this client must not connect
627func (m *Memberlist) ensureCanConnect(from net.Addr) error {
628	if !m.config.IPMustBeChecked() {
629		return nil
630	}
631	source := from.String()
632	if source == "pipe" {
633		return nil
634	}
635	host, _, err := net.SplitHostPort(source)
636	if err != nil {
637		return err
638	}
639
640	ip := net.ParseIP(host)
641	if ip == nil {
642		return fmt.Errorf("Cannot parse IP from %s", host)
643	}
644	return m.config.IPAllowed(ip)
645}
646
647func (m *Memberlist) handleAlive(buf []byte, from net.Addr) {
648	if err := m.ensureCanConnect(from); err != nil {
649		m.logger.Printf("[DEBUG] memberlist: Blocked alive message: %s %s", err, LogAddress(from))
650		return
651	}
652	var live alive
653	if err := decode(buf, &live); err != nil {
654		m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from))
655		return
656	}
657	if m.config.IPMustBeChecked() {
658		innerIP := net.IP(live.Addr)
659		if innerIP != nil {
660			if err := m.config.IPAllowed(innerIP); err != nil {
661				m.logger.Printf("[DEBUG] memberlist: Blocked alive.Addr=%s message from: %s %s", innerIP.String(), err, LogAddress(from))
662				return
663			}
664		}
665	}
666
667	// For proto versions < 2, there is no port provided. Mask old
668	// behavior by using the configured port
669	if m.ProtocolVersion() < 2 || live.Port == 0 {
670		live.Port = uint16(m.config.BindPort)
671	}
672
673	m.aliveNode(&live, nil, false)
674}
675
676func (m *Memberlist) handleDead(buf []byte, from net.Addr) {
677	var d dead
678	if err := decode(buf, &d); err != nil {
679		m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s %s", err, LogAddress(from))
680		return
681	}
682	m.deadNode(&d)
683}
684
685// handleUser is used to notify channels of incoming user data
686func (m *Memberlist) handleUser(buf []byte, from net.Addr) {
687	d := m.config.Delegate
688	if d != nil {
689		d.NotifyMsg(buf)
690	}
691}
692
693// handleCompressed is used to unpack a compressed message
694func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.Time) {
695	// Try to decode the payload
696	payload, err := decompressPayload(buf)
697	if err != nil {
698		m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v %s", err, LogAddress(from))
699		return
700	}
701
702	// Recursively handle the payload
703	m.handleCommand(payload, from, timestamp)
704}
705
706// encodeAndSendMsg is used to combine the encoding and sending steps
707func (m *Memberlist) encodeAndSendMsg(a Address, msgType messageType, msg interface{}) error {
708	out, err := encode(msgType, msg)
709	if err != nil {
710		return err
711	}
712	if err := m.sendMsg(a, out.Bytes()); err != nil {
713		return err
714	}
715	return nil
716}
717
718// sendMsg is used to send a message via packet to another host. It will
719// opportunistically create a compoundMsg and piggy back other broadcasts.
720func (m *Memberlist) sendMsg(a Address, msg []byte) error {
721	// Check if we can piggy back any messages
722	bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead
723	if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
724		bytesAvail -= encryptOverhead(m.encryptionVersion())
725	}
726	extra := m.getBroadcasts(compoundOverhead, bytesAvail)
727
728	// Fast path if nothing to piggypack
729	if len(extra) == 0 {
730		return m.rawSendMsgPacket(a, nil, msg)
731	}
732
733	// Join all the messages
734	msgs := make([][]byte, 0, 1+len(extra))
735	msgs = append(msgs, msg)
736	msgs = append(msgs, extra...)
737
738	// Create a compound message
739	compound := makeCompoundMessage(msgs)
740
741	// Send the message
742	return m.rawSendMsgPacket(a, nil, compound.Bytes())
743}
744
745// rawSendMsgPacket is used to send message via packet to another host without
746// modification, other than compression or encryption if enabled.
747func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error {
748	if a.Name == "" && m.config.RequireNodeNames {
749		return errNodeNamesAreRequired
750	}
751
752	// Check if we have compression enabled
753	if m.config.EnableCompression {
754		buf, err := compressPayload(msg)
755		if err != nil {
756			m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err)
757		} else {
758			// Only use compression if it reduced the size
759			if buf.Len() < len(msg) {
760				msg = buf.Bytes()
761			}
762		}
763	}
764
765	// Try to look up the destination node. Note this will only work if the
766	// bare ip address is used as the node name, which is not guaranteed.
767	if node == nil {
768		toAddr, _, err := net.SplitHostPort(a.Addr)
769		if err != nil {
770			m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", a.Addr, err)
771			return err
772		}
773		m.nodeLock.RLock()
774		nodeState, ok := m.nodeMap[toAddr]
775		m.nodeLock.RUnlock()
776		if ok {
777			node = &nodeState.Node
778		}
779	}
780
781	// Add a CRC to the end of the payload if the recipient understands
782	// ProtocolVersion >= 5
783	if node != nil && node.PMax >= 5 {
784		crc := crc32.ChecksumIEEE(msg)
785		header := make([]byte, 5, 5+len(msg))
786		header[0] = byte(hasCrcMsg)
787		binary.BigEndian.PutUint32(header[1:], crc)
788		msg = append(header, msg...)
789	}
790
791	// Check if we have encryption enabled
792	if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
793		// Encrypt the payload
794		var buf bytes.Buffer
795		primaryKey := m.config.Keyring.GetPrimaryKey()
796		err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf)
797		if err != nil {
798			m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err)
799			return err
800		}
801		msg = buf.Bytes()
802	}
803
804	metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg)))
805	_, err := m.transport.WriteToAddress(msg, a)
806	return err
807}
808
809// rawSendMsgStream is used to stream a message to another host without
810// modification, other than applying compression and encryption if enabled.
811func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error {
812	// Check if compression is enabled
813	if m.config.EnableCompression {
814		compBuf, err := compressPayload(sendBuf)
815		if err != nil {
816			m.logger.Printf("[ERROR] memberlist: Failed to compress payload: %v", err)
817		} else {
818			sendBuf = compBuf.Bytes()
819		}
820	}
821
822	// Check if encryption is enabled
823	if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
824		crypt, err := m.encryptLocalState(sendBuf)
825		if err != nil {
826			m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err)
827			return err
828		}
829		sendBuf = crypt
830	}
831
832	// Write out the entire send buffer
833	metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)))
834
835	if n, err := conn.Write(sendBuf); err != nil {
836		return err
837	} else if n != len(sendBuf) {
838		return fmt.Errorf("only %d of %d bytes written", n, len(sendBuf))
839	}
840
841	return nil
842}
843
844// sendUserMsg is used to stream a user message to another host.
845func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error {
846	if a.Name == "" && m.config.RequireNodeNames {
847		return errNodeNamesAreRequired
848	}
849
850	conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout)
851	if err != nil {
852		return err
853	}
854	defer conn.Close()
855
856	bufConn := bytes.NewBuffer(nil)
857	if err := bufConn.WriteByte(byte(userMsg)); err != nil {
858		return err
859	}
860
861	header := userMsgHeader{UserMsgLen: len(sendBuf)}
862	hd := codec.MsgpackHandle{}
863	enc := codec.NewEncoder(bufConn, &hd)
864	if err := enc.Encode(&header); err != nil {
865		return err
866	}
867	if _, err := bufConn.Write(sendBuf); err != nil {
868		return err
869	}
870	return m.rawSendMsgStream(conn, bufConn.Bytes())
871}
872
873// sendAndReceiveState is used to initiate a push/pull over a stream with a
874// remote host.
875func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, []byte, error) {
876	if a.Name == "" && m.config.RequireNodeNames {
877		return nil, nil, errNodeNamesAreRequired
878	}
879
880	// Attempt to connect
881	conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout)
882	if err != nil {
883		return nil, nil, err
884	}
885	defer conn.Close()
886	m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s %s", a.Name, conn.RemoteAddr())
887	metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1)
888
889	// Send our state
890	if err := m.sendLocalState(conn, join); err != nil {
891		return nil, nil, err
892	}
893
894	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
895	msgType, bufConn, dec, err := m.readStream(conn)
896	if err != nil {
897		return nil, nil, err
898	}
899
900	if msgType == errMsg {
901		var resp errResp
902		if err := dec.Decode(&resp); err != nil {
903			return nil, nil, err
904		}
905		return nil, nil, fmt.Errorf("remote error: %v", resp.Error)
906	}
907
908	// Quit if not push/pull
909	if msgType != pushPullMsg {
910		err := fmt.Errorf("received invalid msgType (%d), expected pushPullMsg (%d) %s", msgType, pushPullMsg, LogConn(conn))
911		return nil, nil, err
912	}
913
914	// Read remote state
915	_, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
916	return remoteNodes, userState, err
917}
918
919// sendLocalState is invoked to send our local state over a stream connection.
920func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
921	// Setup a deadline
922	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
923
924	// Prepare the local node state
925	m.nodeLock.RLock()
926	localNodes := make([]pushNodeState, len(m.nodes))
927	for idx, n := range m.nodes {
928		localNodes[idx].Name = n.Name
929		localNodes[idx].Addr = n.Addr
930		localNodes[idx].Port = n.Port
931		localNodes[idx].Incarnation = n.Incarnation
932		localNodes[idx].State = n.State
933		localNodes[idx].Meta = n.Meta
934		localNodes[idx].Vsn = []uint8{
935			n.PMin, n.PMax, n.PCur,
936			n.DMin, n.DMax, n.DCur,
937		}
938	}
939	m.nodeLock.RUnlock()
940
941	// Get the delegate state
942	var userData []byte
943	if m.config.Delegate != nil {
944		userData = m.config.Delegate.LocalState(join)
945	}
946
947	// Create a bytes buffer writer
948	bufConn := bytes.NewBuffer(nil)
949
950	// Send our node state
951	header := pushPullHeader{Nodes: len(localNodes), UserStateLen: len(userData), Join: join}
952	hd := codec.MsgpackHandle{}
953	enc := codec.NewEncoder(bufConn, &hd)
954
955	// Begin state push
956	if _, err := bufConn.Write([]byte{byte(pushPullMsg)}); err != nil {
957		return err
958	}
959
960	if err := enc.Encode(&header); err != nil {
961		return err
962	}
963	for i := 0; i < header.Nodes; i++ {
964		if err := enc.Encode(&localNodes[i]); err != nil {
965			return err
966		}
967	}
968
969	// Write the user state as well
970	if userData != nil {
971		if _, err := bufConn.Write(userData); err != nil {
972			return err
973		}
974	}
975
976	// Get the send buffer
977	return m.rawSendMsgStream(conn, bufConn.Bytes())
978}
979
980// encryptLocalState is used to help encrypt local state before sending
981func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
982	var buf bytes.Buffer
983
984	// Write the encryptMsg byte
985	buf.WriteByte(byte(encryptMsg))
986
987	// Write the size of the message
988	sizeBuf := make([]byte, 4)
989	encVsn := m.encryptionVersion()
990	encLen := encryptedLength(encVsn, len(sendBuf))
991	binary.BigEndian.PutUint32(sizeBuf, uint32(encLen))
992	buf.Write(sizeBuf)
993
994	// Write the encrypted cipher text to the buffer
995	key := m.config.Keyring.GetPrimaryKey()
996	err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf)
997	if err != nil {
998		return nil, err
999	}
1000	return buf.Bytes(), nil
1001}
1002
1003// decryptRemoteState is used to help decrypt the remote state
1004func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
1005	// Read in enough to determine message length
1006	cipherText := bytes.NewBuffer(nil)
1007	cipherText.WriteByte(byte(encryptMsg))
1008	_, err := io.CopyN(cipherText, bufConn, 4)
1009	if err != nil {
1010		return nil, err
1011	}
1012
1013	// Ensure we aren't asked to download too much. This is to guard against
1014	// an attack vector where a huge amount of state is sent
1015	moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5])
1016	if moreBytes > maxPushStateBytes {
1017		return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes)
1018	}
1019
1020	// Read in the rest of the payload
1021	_, err = io.CopyN(cipherText, bufConn, int64(moreBytes))
1022	if err != nil {
1023		return nil, err
1024	}
1025
1026	// Decrypt the cipherText
1027	dataBytes := cipherText.Bytes()[:5]
1028	cipherBytes := cipherText.Bytes()[5:]
1029
1030	// Decrypt the payload
1031	keys := m.config.Keyring.GetKeys()
1032	return decryptPayload(keys, cipherBytes, dataBytes)
1033}
1034
1035// readStream is used to read from a stream connection, decrypting and
1036// decompressing the stream if necessary.
1037func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) {
1038	// Created a buffered reader
1039	var bufConn io.Reader = bufio.NewReader(conn)
1040
1041	// Read the message type
1042	buf := [1]byte{0}
1043	if _, err := bufConn.Read(buf[:]); err != nil {
1044		return 0, nil, nil, err
1045	}
1046	msgType := messageType(buf[0])
1047
1048	// Check if the message is encrypted
1049	if msgType == encryptMsg {
1050		if !m.config.EncryptionEnabled() {
1051			return 0, nil, nil,
1052				fmt.Errorf("Remote state is encrypted and encryption is not configured")
1053		}
1054
1055		plain, err := m.decryptRemoteState(bufConn)
1056		if err != nil {
1057			return 0, nil, nil, err
1058		}
1059
1060		// Reset message type and bufConn
1061		msgType = messageType(plain[0])
1062		bufConn = bytes.NewReader(plain[1:])
1063	} else if m.config.EncryptionEnabled() && m.config.GossipVerifyIncoming {
1064		return 0, nil, nil,
1065			fmt.Errorf("Encryption is configured but remote state is not encrypted")
1066	}
1067
1068	// Get the msgPack decoders
1069	hd := codec.MsgpackHandle{}
1070	dec := codec.NewDecoder(bufConn, &hd)
1071
1072	// Check if we have a compressed message
1073	if msgType == compressMsg {
1074		var c compress
1075		if err := dec.Decode(&c); err != nil {
1076			return 0, nil, nil, err
1077		}
1078		decomp, err := decompressBuffer(&c)
1079		if err != nil {
1080			return 0, nil, nil, err
1081		}
1082
1083		// Reset the message type
1084		msgType = messageType(decomp[0])
1085
1086		// Create a new bufConn
1087		bufConn = bytes.NewReader(decomp[1:])
1088
1089		// Create a new decoder
1090		dec = codec.NewDecoder(bufConn, &hd)
1091	}
1092
1093	return msgType, bufConn, dec, nil
1094}
1095
1096// readRemoteState is used to read the remote state from a connection
1097func (m *Memberlist) readRemoteState(bufConn io.Reader, dec *codec.Decoder) (bool, []pushNodeState, []byte, error) {
1098	// Read the push/pull header
1099	var header pushPullHeader
1100	if err := dec.Decode(&header); err != nil {
1101		return false, nil, nil, err
1102	}
1103
1104	// Allocate space for the transfer
1105	remoteNodes := make([]pushNodeState, header.Nodes)
1106
1107	// Try to decode all the states
1108	for i := 0; i < header.Nodes; i++ {
1109		if err := dec.Decode(&remoteNodes[i]); err != nil {
1110			return false, nil, nil, err
1111		}
1112	}
1113
1114	// Read the remote user state into a buffer
1115	var userBuf []byte
1116	if header.UserStateLen > 0 {
1117		userBuf = make([]byte, header.UserStateLen)
1118		bytes, err := io.ReadAtLeast(bufConn, userBuf, header.UserStateLen)
1119		if err == nil && bytes != header.UserStateLen {
1120			err = fmt.Errorf(
1121				"Failed to read full user state (%d / %d)",
1122				bytes, header.UserStateLen)
1123		}
1124		if err != nil {
1125			return false, nil, nil, err
1126		}
1127	}
1128
1129	// For proto versions < 2, there is no port provided. Mask old
1130	// behavior by using the configured port
1131	for idx := range remoteNodes {
1132		if m.ProtocolVersion() < 2 || remoteNodes[idx].Port == 0 {
1133			remoteNodes[idx].Port = uint16(m.config.BindPort)
1134		}
1135	}
1136
1137	return header.Join, remoteNodes, userBuf, nil
1138}
1139
1140// mergeRemoteState is used to merge the remote state with our local state
1141func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, userBuf []byte) error {
1142	if err := m.verifyProtocol(remoteNodes); err != nil {
1143		return err
1144	}
1145
1146	// Invoke the merge delegate if any
1147	if join && m.config.Merge != nil {
1148		nodes := make([]*Node, len(remoteNodes))
1149		for idx, n := range remoteNodes {
1150			nodes[idx] = &Node{
1151				Name:  n.Name,
1152				Addr:  n.Addr,
1153				Port:  n.Port,
1154				Meta:  n.Meta,
1155				State: n.State,
1156				PMin:  n.Vsn[0],
1157				PMax:  n.Vsn[1],
1158				PCur:  n.Vsn[2],
1159				DMin:  n.Vsn[3],
1160				DMax:  n.Vsn[4],
1161				DCur:  n.Vsn[5],
1162			}
1163		}
1164		if err := m.config.Merge.NotifyMerge(nodes); err != nil {
1165			return err
1166		}
1167	}
1168
1169	// Merge the membership state
1170	m.mergeState(remoteNodes)
1171
1172	// Invoke the delegate for user state
1173	if userBuf != nil && m.config.Delegate != nil {
1174		m.config.Delegate.MergeRemoteState(userBuf, join)
1175	}
1176	return nil
1177}
1178
1179// readUserMsg is used to decode a userMsg from a stream.
1180func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error {
1181	// Read the user message header
1182	var header userMsgHeader
1183	if err := dec.Decode(&header); err != nil {
1184		return err
1185	}
1186
1187	// Read the user message into a buffer
1188	var userBuf []byte
1189	if header.UserMsgLen > 0 {
1190		userBuf = make([]byte, header.UserMsgLen)
1191		bytes, err := io.ReadAtLeast(bufConn, userBuf, header.UserMsgLen)
1192		if err == nil && bytes != header.UserMsgLen {
1193			err = fmt.Errorf(
1194				"Failed to read full user message (%d / %d)",
1195				bytes, header.UserMsgLen)
1196		}
1197		if err != nil {
1198			return err
1199		}
1200
1201		d := m.config.Delegate
1202		if d != nil {
1203			d.NotifyMsg(userBuf)
1204		}
1205	}
1206
1207	return nil
1208}
1209
1210// sendPingAndWaitForAck makes a stream connection to the given address, sends
1211// a ping, and waits for an ack. All of this is done as a series of blocking
1212// operations, given the deadline. The bool return parameter is true if we
1213// we able to round trip a ping to the other node.
1214func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.Time) (bool, error) {
1215	if a.Name == "" && m.config.RequireNodeNames {
1216		return false, errNodeNamesAreRequired
1217	}
1218
1219	conn, err := m.transport.DialAddressTimeout(a, deadline.Sub(time.Now()))
1220	if err != nil {
1221		// If the node is actually dead we expect this to fail, so we
1222		// shouldn't spam the logs with it. After this point, errors
1223		// with the connection are real, unexpected errors and should
1224		// get propagated up.
1225		return false, nil
1226	}
1227	defer conn.Close()
1228	conn.SetDeadline(deadline)
1229
1230	out, err := encode(pingMsg, &ping)
1231	if err != nil {
1232		return false, err
1233	}
1234
1235	if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil {
1236		return false, err
1237	}
1238
1239	msgType, _, dec, err := m.readStream(conn)
1240	if err != nil {
1241		return false, err
1242	}
1243
1244	if msgType != ackRespMsg {
1245		return false, fmt.Errorf("Unexpected msgType (%d) from ping %s", msgType, LogConn(conn))
1246	}
1247
1248	var ack ackResp
1249	if err = dec.Decode(&ack); err != nil {
1250		return false, err
1251	}
1252
1253	if ack.SeqNo != ping.SeqNo {
1254		return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d)", ack.SeqNo, ping.SeqNo)
1255	}
1256
1257	return true, nil
1258}
1259