1package memberlist
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"log"
8	"net"
9	"sync"
10	"sync/atomic"
11	"time"
12
13	"github.com/armon/go-metrics"
14	sockaddr "github.com/hashicorp/go-sockaddr"
15)
16
const (
	// udpPacketBufSize is the size, in bytes, of the buffer every UDP read
	// is performed into: 64 KiB, which covers the largest UDP datagram
	// payload.
	udpPacketBufSize = 64 * 1024

	// udpRecvBufSize is the socket receive buffer size (2 MiB) that we first
	// ask the OS for on UDP sockets so bursts of messages are absorbed; see
	// setUDPRecvBuf for the back-off used when the request is refused.
	udpRecvBufSize = 2 << 20
)
26
// NetTransportConfig is used to configure a net transport created by
// NewNetTransport.
type NetTransportConfig struct {
	// BindAddrs is a list of addresses to bind to for both TCP and UDP
	// communications. At least one entry is required. Each entry is parsed
	// with net.ParseIP; NOTE(review): a string that does not parse (for
	// example a hostname or "") yields a nil IP, which makes the listeners
	// bind to all local addresses. The first entry also drives advertise
	// address selection: exactly "0.0.0.0" triggers private-IP discovery.
	BindAddrs []string

	// BindPort is the port to listen on, for each address above. A value of
	// 0 lets the kernel pick a port for the first TCP listener, and that
	// port is then reused for every other TCP and UDP listener.
	BindPort int

	// Logger is a logger for operator messages, such as accept/read errors
	// from the listener goroutines. It is called without a nil check, so it
	// presumably must be non-nil.
	Logger *log.Logger
}
39
// NetTransport is a Transport implementation that uses connectionless UDP for
// packet operations, and ad-hoc TCP connections for stream operations.
//
// Instances are created by NewNetTransport, which opens one TCP listener and
// one UDP socket per bind address and runs a goroutine for each of them.
type NetTransport struct {
	// config is the configuration the transport was built from; its first
	// BindAddrs entry is consulted when choosing an advertise address.
	config       *NetTransportConfig
	// packetCh carries inbound packets from udpListen and IngestPacket; it
	// is handed out read-only by PacketCh.
	packetCh     chan *Packet
	// streamCh carries inbound stream connections from tcpListen and
	// IngestStream; it is handed out read-only by StreamCh.
	streamCh     chan net.Conn
	// logger receives error messages from the listener goroutines.
	logger       *log.Logger
	// wg counts the running tcpListen/udpListen goroutines so Shutdown can
	// wait for them to exit.
	wg           sync.WaitGroup
	// tcpListeners and udpListeners hold one socket per bind address, in
	// BindAddrs order. Index 0 of each is used for the bound port, the
	// default advertise IP and outgoing UDP packets.
	tcpListeners []*net.TCPListener
	udpListeners []*net.UDPConn
	// shutdown is a 0/1 flag, accessed only through sync/atomic, that is
	// raised by Shutdown so listener errors caused by closing the sockets
	// end the goroutines instead of being logged.
	shutdown     int32
}

// Compile-time check that *NetTransport implements NodeAwareTransport.
var _ NodeAwareTransport = (*NetTransport)(nil)
54
55// NewNetTransport returns a net transport with the given configuration. On
56// success all the network listeners will be created and listening.
57func NewNetTransport(config *NetTransportConfig) (*NetTransport, error) {
58	// If we reject the empty list outright we can assume that there's at
59	// least one listener of each type later during operation.
60	if len(config.BindAddrs) == 0 {
61		return nil, fmt.Errorf("At least one bind address is required")
62	}
63
64	// Build out the new transport.
65	var ok bool
66	t := NetTransport{
67		config:   config,
68		packetCh: make(chan *Packet),
69		streamCh: make(chan net.Conn),
70		logger:   config.Logger,
71	}
72
73	// Clean up listeners if there's an error.
74	defer func() {
75		if !ok {
76			t.Shutdown()
77		}
78	}()
79
80	// Build all the TCP and UDP listeners.
81	port := config.BindPort
82	for _, addr := range config.BindAddrs {
83		ip := net.ParseIP(addr)
84
85		tcpAddr := &net.TCPAddr{IP: ip, Port: port}
86		tcpLn, err := net.ListenTCP("tcp", tcpAddr)
87		if err != nil {
88			return nil, fmt.Errorf("Failed to start TCP listener on %q port %d: %v", addr, port, err)
89		}
90		t.tcpListeners = append(t.tcpListeners, tcpLn)
91
92		// If the config port given was zero, use the first TCP listener
93		// to pick an available port and then apply that to everything
94		// else.
95		if port == 0 {
96			port = tcpLn.Addr().(*net.TCPAddr).Port
97		}
98
99		udpAddr := &net.UDPAddr{IP: ip, Port: port}
100		udpLn, err := net.ListenUDP("udp", udpAddr)
101		if err != nil {
102			return nil, fmt.Errorf("Failed to start UDP listener on %q port %d: %v", addr, port, err)
103		}
104		if err := setUDPRecvBuf(udpLn); err != nil {
105			return nil, fmt.Errorf("Failed to resize UDP buffer: %v", err)
106		}
107		t.udpListeners = append(t.udpListeners, udpLn)
108	}
109
110	// Fire them up now that we've been able to create them all.
111	for i := 0; i < len(config.BindAddrs); i++ {
112		t.wg.Add(2)
113		go t.tcpListen(t.tcpListeners[i])
114		go t.udpListen(t.udpListeners[i])
115	}
116
117	ok = true
118	return &t, nil
119}
120
121// GetAutoBindPort returns the bind port that was automatically given by the
122// kernel, if a bind port of 0 was given.
123func (t *NetTransport) GetAutoBindPort() int {
124	// We made sure there's at least one TCP listener, and that one's
125	// port was applied to all the others for the dynamic bind case.
126	return t.tcpListeners[0].Addr().(*net.TCPAddr).Port
127}
128
129// See Transport.
130func (t *NetTransport) FinalAdvertiseAddr(ip string, port int) (net.IP, int, error) {
131	var advertiseAddr net.IP
132	var advertisePort int
133	if ip != "" {
134		// If they've supplied an address, use that.
135		advertiseAddr = net.ParseIP(ip)
136		if advertiseAddr == nil {
137			return nil, 0, fmt.Errorf("Failed to parse advertise address %q", ip)
138		}
139
140		// Ensure IPv4 conversion if necessary.
141		if ip4 := advertiseAddr.To4(); ip4 != nil {
142			advertiseAddr = ip4
143		}
144		advertisePort = port
145	} else {
146		if t.config.BindAddrs[0] == "0.0.0.0" {
147			// Otherwise, if we're not bound to a specific IP, let's
148			// use a suitable private IP address.
149			var err error
150			ip, err = sockaddr.GetPrivateIP()
151			if err != nil {
152				return nil, 0, fmt.Errorf("Failed to get interface addresses: %v", err)
153			}
154			if ip == "" {
155				return nil, 0, fmt.Errorf("No private IP address found, and explicit IP not provided")
156			}
157
158			advertiseAddr = net.ParseIP(ip)
159			if advertiseAddr == nil {
160				return nil, 0, fmt.Errorf("Failed to parse advertise address: %q", ip)
161			}
162		} else {
163			// Use the IP that we're bound to, based on the first
164			// TCP listener, which we already ensure is there.
165			advertiseAddr = t.tcpListeners[0].Addr().(*net.TCPAddr).IP
166		}
167
168		// Use the port we are bound to.
169		advertisePort = t.GetAutoBindPort()
170	}
171
172	return advertiseAddr, advertisePort, nil
173}
174
175// See Transport.
176func (t *NetTransport) WriteTo(b []byte, addr string) (time.Time, error) {
177	a := Address{Addr: addr, Name: ""}
178	return t.WriteToAddress(b, a)
179}
180
181// See NodeAwareTransport.
182func (t *NetTransport) WriteToAddress(b []byte, a Address) (time.Time, error) {
183	addr := a.Addr
184
185	udpAddr, err := net.ResolveUDPAddr("udp", addr)
186	if err != nil {
187		return time.Time{}, err
188	}
189
190	// We made sure there's at least one UDP listener, so just use the
191	// packet sending interface on the first one. Take the time after the
192	// write call comes back, which will underestimate the time a little,
193	// but help account for any delays before the write occurs.
194	_, err = t.udpListeners[0].WriteTo(b, udpAddr)
195	return time.Now(), err
196}
197
198// See Transport.
199func (t *NetTransport) PacketCh() <-chan *Packet {
200	return t.packetCh
201}
202
203// See IngestionAwareTransport.
204func (t *NetTransport) IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error {
205	if shouldClose {
206		defer conn.Close()
207	}
208
209	// Copy everything from the stream into packet buffer.
210	var buf bytes.Buffer
211	if _, err := io.Copy(&buf, conn); err != nil {
212		return fmt.Errorf("failed to read packet: %v", err)
213	}
214
215	// Check the length - it needs to have at least one byte to be a proper
216	// message. This is checked elsewhere for writes coming in directly from
217	// the UDP socket.
218	if n := buf.Len(); n < 1 {
219		return fmt.Errorf("packet too short (%d bytes) %s", n, LogAddress(addr))
220	}
221
222	// Inject the packet.
223	t.packetCh <- &Packet{
224		Buf:       buf.Bytes(),
225		From:      addr,
226		Timestamp: now,
227	}
228	return nil
229}
230
231// See Transport.
232func (t *NetTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
233	a := Address{Addr: addr, Name: ""}
234	return t.DialAddressTimeout(a, timeout)
235}
236
237// See NodeAwareTransport.
238func (t *NetTransport) DialAddressTimeout(a Address, timeout time.Duration) (net.Conn, error) {
239	addr := a.Addr
240
241	dialer := net.Dialer{Timeout: timeout}
242	return dialer.Dial("tcp", addr)
243}
244
245// See Transport.
246func (t *NetTransport) StreamCh() <-chan net.Conn {
247	return t.streamCh
248}
249
250// See IngestionAwareTransport.
251func (t *NetTransport) IngestStream(conn net.Conn) error {
252	t.streamCh <- conn
253	return nil
254}
255
256// See Transport.
257func (t *NetTransport) Shutdown() error {
258	// This will avoid log spam about errors when we shut down.
259	atomic.StoreInt32(&t.shutdown, 1)
260
261	// Rip through all the connections and shut them down.
262	for _, conn := range t.tcpListeners {
263		conn.Close()
264	}
265	for _, conn := range t.udpListeners {
266		conn.Close()
267	}
268
269	// Block until all the listener threads have died.
270	t.wg.Wait()
271	return nil
272}
273
274// tcpListen is a long running goroutine that accepts incoming TCP connections
275// and hands them off to the stream channel.
276func (t *NetTransport) tcpListen(tcpLn *net.TCPListener) {
277	defer t.wg.Done()
278
279	// baseDelay is the initial delay after an AcceptTCP() error before attempting again
280	const baseDelay = 5 * time.Millisecond
281
282	// maxDelay is the maximum delay after an AcceptTCP() error before attempting again.
283	// In the case that tcpListen() is error-looping, it will delay the shutdown check.
284	// Therefore, changes to maxDelay may have an effect on the latency of shutdown.
285	const maxDelay = 1 * time.Second
286
287	var loopDelay time.Duration
288	for {
289		conn, err := tcpLn.AcceptTCP()
290		if err != nil {
291			if s := atomic.LoadInt32(&t.shutdown); s == 1 {
292				break
293			}
294
295			if loopDelay == 0 {
296				loopDelay = baseDelay
297			} else {
298				loopDelay *= 2
299			}
300
301			if loopDelay > maxDelay {
302				loopDelay = maxDelay
303			}
304
305			t.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %v", err)
306			time.Sleep(loopDelay)
307			continue
308		}
309		// No error, reset loop delay
310		loopDelay = 0
311
312		t.streamCh <- conn
313	}
314}
315
316// udpListen is a long running goroutine that accepts incoming UDP packets and
317// hands them off to the packet channel.
318func (t *NetTransport) udpListen(udpLn *net.UDPConn) {
319	defer t.wg.Done()
320	for {
321		// Do a blocking read into a fresh buffer. Grab a time stamp as
322		// close as possible to the I/O.
323		buf := make([]byte, udpPacketBufSize)
324		n, addr, err := udpLn.ReadFrom(buf)
325		ts := time.Now()
326		if err != nil {
327			if s := atomic.LoadInt32(&t.shutdown); s == 1 {
328				break
329			}
330
331			t.logger.Printf("[ERR] memberlist: Error reading UDP packet: %v", err)
332			continue
333		}
334
335		// Check the length - it needs to have at least one byte to be a
336		// proper message.
337		if n < 1 {
338			t.logger.Printf("[ERR] memberlist: UDP packet too short (%d bytes) %s",
339				len(buf), LogAddress(addr))
340			continue
341		}
342
343		// Ingest the packet.
344		metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n))
345		t.packetCh <- &Packet{
346			Buf:       buf[:n],
347			From:      addr,
348			Timestamp: ts,
349		}
350	}
351}
352
353// setUDPRecvBuf is used to resize the UDP receive window. The function
354// attempts to set the read buffer to `udpRecvBuf` but backs off until
355// the read buffer can be set.
356func setUDPRecvBuf(c *net.UDPConn) error {
357	size := udpRecvBufSize
358	var err error
359	for size > 0 {
360		if err = c.SetReadBuffer(size); err == nil {
361			return nil
362		}
363		size = size / 2
364	}
365	return err
366}
367