package pool

import (
	"container/list"
	"fmt"
	"io"
	"log"
	"net"
	"net/rpc"
	"sync"
	"sync/atomic"
	"time"

	"github.com/hashicorp/consul/lib"
	hclog "github.com/hashicorp/go-hclog"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/nomad/helper/tlsutil"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/yamux"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}
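
// Both constructors share structs.MsgpackHandle, so clients and servers agree
// on the msgpack encoding options. Either codec wraps any io.ReadWriteCloser
// and can be paired with the standard library rpc package, e.g. (a sketch, not
// part of this package; conn is any io.ReadWriteCloser):
//
//	client := rpc.NewClientWithCodec(NewClientCodec(conn))
//	defer client.Close()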

// StreamClient is used to wrap a stream with an RPC client
type StreamClient struct {
	stream net.Conn
	codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
	sc.stream.Close()
	sc.codec.Close()
}

// Conn is a pooled connection to a Nomad server
type Conn struct {
	refCount    int32
	shouldClose int32

	addr     net.Addr
	session  *yamux.Session
	lastUsed time.Time
	version  int

	pool *ConnPool

	clients    *list.List
	clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use.
func (c *Conn) markForUse() {
	c.lastUsed = time.Now()
	atomic.AddInt32(&c.refCount, 1)
}
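
// Every markForUse is expected to be balanced by a ConnPool.releaseConn; once
// clearConn has set shouldClose, the session is closed as soon as the
// reference count drops back to zero.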

func (c *Conn) Close() error {
	return c.session.Close()
}

// getRPCClient is used to get a cached or new StreamClient
func (c *Conn) getRPCClient() (*StreamClient, error) {
	// Check for cached client
	c.clientLock.Lock()
	front := c.clients.Front()
	if front != nil {
		c.clients.Remove(front)
	}
	c.clientLock.Unlock()
	if front != nil {
		return front.Value.(*StreamClient), nil
	}

	// Open a new stream on the multiplexed session
	stream, err := c.session.Open()
	if err != nil {
		return nil, err
	}

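	// The first byte written on a new stream tells the remote end how to
	// treat it; RpcNomad marks this stream as a regular msgpack RPC stream.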
	if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
		stream.Close()
		return nil, err
	}

	// Create a client codec
	codec := NewClientCodec(stream)

	// Return a new stream client
	sc := &StreamClient{
		stream: stream,
		codec:  codec,
	}
	return sc, nil
}

// returnClient is used when done with a stream
// to allow re-use by a future RPC
func (c *Conn) returnClient(client *StreamClient) {
	didSave := false
	c.clientLock.Lock()
	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
		c.clients.PushFront(client)
		didSave = true

		// If this is a Yamux stream, shrink the internal buffers so that
		// we can GC the idle memory
		if ys, ok := client.stream.(*yamux.Stream); ok {
			ys.Shrink()
		}
	}
	c.clientLock.Unlock()
	if !didSave {
		client.Close()
	}
}

// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// The maximum time to keep a connection open
	maxTime time.Duration

	// The maximum number of open streams to keep
	maxStreams int

	// Pool maps an address to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shut down
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made.
	connListener chan<- *yamux.Session
}

// NewPool is used to make a new connection pool.
// It maintains at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
	pool := &ConnPool{
		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
		maxTime:    maxTime,
		maxStreams: maxStreams,
		pool:       make(map[string]*Conn),
		limiter:    make(map[string]chan struct{}),
		tlsWrap:    tlsWrap,
		shutdownCh: make(chan struct{}),
	}
	if maxTime > 0 {
		go pool.reap()
	}
	return pool
}
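
// Example usage (a sketch; the region, address, version, RPC method name, and
// argument/reply types below are placeholders, not part of this package):
//
//	logger := hclog.New(&hclog.LoggerOptions{Name: "rpc-pool"})
//	pool := NewPool(logger, 5*time.Minute, 4, nil) // nil tlsWrap: plaintext connections
//	defer pool.Shutdown()
//
//	addr, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:4647")
//	var reply struct{}
//	if err := pool.RPC("global", addr, 1, "Some.Method", struct{}{}, &reply); err != nil {
//		// handle the error; the pool has already released the connection
//	}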

// Shutdown is used to close the connection pool
func (p *ConnPool) Shutdown() error {
	p.Lock()
	defer p.Unlock()

	for _, conn := range p.pool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)

	if p.shutdown {
		return nil
	}

	if p.connListener != nil {
		close(p.connListener)
		p.connListener = nil
	}

	p.shutdown = true
	close(p.shutdownCh)
	return nil
}

// ReloadTLS reloads TLS configuration on the fly
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
	p.Lock()
	defer p.Unlock()

	oldPool := p.pool
	for _, conn := range oldPool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)
	p.tlsWrap = tlsWrap
}

// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) {
	p.Lock()
	defer p.Unlock()

	// Close the old listener
	if p.connListener != nil {
		close(p.connListener)
	}

	// Store the new listener
	p.connListener = l
}

// acquire is used to get a pooled connection, or to create a new one
func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addr.String()]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addr.String()]; !ok {
		wait = make(chan struct{})
		p.limiter[addr.String()] = wait
	}
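	// Only the goroutine that created the wait channel (the "lead thread")
	// dials; every other caller blocks on the channel below and then
	// re-checks the pool.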
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		c, err := p.getNewConn(region, addr, version)
		p.Lock()
		delete(p.limiter, addr.String())
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		p.pool[addr.String()] = c

		// If there is a connection listener, notify them of the new connection.
		if p.connListener != nil {
			select {
			case p.connListener <- c.session:
			default:
			}
		}

		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addr.String()]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
	// Try to dial the conn
	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

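	// Connection preamble: if TLS is configured, a single RpcTLS byte
	// announces the upgrade before the conn is wrapped; RpcMultiplexV2 is then
	// sent so the remote end runs yamux over the resulting transport.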
	// Check if TLS is enabled
	if p.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := p.tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
		conn.Close()
		return nil, err
	}

	// Set up the yamux config to use our logger
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = p.logger

	// Create a multiplexed session
	session, err := yamux.Client(conn, conf)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Wrap the connection
	c := &Conn{
		refCount: 1,
		addr:     addr,
		session:  session,
		clients:  list.New(),
		lastUsed: time.Now(),
		version:  version,
		pool:     p,
	}
	return c, nil
}
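
// Note that getNewConn hands back the Conn with refCount already set to 1; the
// lead thread in acquire relies on this rather than calling markForUse.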

// clearConn is used to clear any cached connection, potentially in response to
// an error
func (p *ConnPool) clearConn(conn *Conn) {
	// Ensure returned streams are closed
	atomic.StoreInt32(&conn.shouldClose, 1)

	// Clear from the cache
	p.Lock()
	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
		delete(p.pool, conn.addr.String())
	}
	p.Unlock()

	// Close down immediately if idle
	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
		conn.Close()
	}
}

// releaseConn is invoked when we are done with a conn to reduce the ref count
func (p *ConnPool) releaseConn(conn *Conn) {
	refCount := atomic.AddInt32(&conn.refCount, -1)
	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
		conn.Close()
	}
}

// getRPCClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getRPCClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
	retries := 0
START:
	// Try to get a conn first
	conn, err := p.acquire(region, addr, version)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
	}

	// Get a client
	client, err := conn.getRPCClient()
	if err != nil {
		p.clearConn(conn)
		p.releaseConn(conn)

		// Try to redial; it's possible that the TCP session closed due to a timeout
		if retries == 0 {
			retries++
			goto START
		}
		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
	}
	return conn, client, nil
}

// StreamingRPC is used to make a streaming RPC call. Callers must
// close the connection when done.
func (p *ConnPool) StreamingRPC(region string, addr net.Addr, version int) (net.Conn, error) {
	conn, err := p.acquire(region, addr, version)
	if err != nil {
		return nil, fmt.Errorf("failed to get conn: %v", err)
	}

	s, err := conn.session.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
	}

	if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
		conn.Close()
		return nil, err
	}

	return s, nil
}
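
// A streaming caller owns the returned net.Conn and must close it when done,
// e.g. (a sketch; the region, address, and version are placeholders, and the
// protocol spoken over the stream is up to the caller):
//
//	stream, err := pool.StreamingRPC("global", addr, 1)
//	if err != nil {
//		return err
//	}
//	defer stream.Close()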

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
	// Get a usable client
	conn, sc, err := p.getRPCClient(region, addr, version)
	if err != nil {
		return fmt.Errorf("rpc error: %w", err)
	}

	// Make the RPC call
	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
	if err != nil {
		sc.Close()

		// If we read EOF, the session is toast. Clear it and open a
		// new session next time
		// See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477
		if lib.IsErrEOF(err) {
			p.clearConn(conn)
		}

		p.releaseConn(conn)

		// If the error is an RPC Coded error
		// return the coded error without wrapping
		if structs.IsErrRPCCoded(err) {
			return err
		}

		// TODO wrap with RPCCoded error instead
		return fmt.Errorf("rpc error: %w", err)
	}

	// Done with the connection
	conn.returnClient(sc)
	p.releaseConn(conn)
	return nil
}

// reap is used to close conns that have been idle for longer than maxTime
func (p *ConnPool) reap() {
	for {
		// Sleep for a while
		select {
		case <-p.shutdownCh:
			return
		case <-time.After(time.Second):
		}

		// Reap all old conns
		p.Lock()
		var removed []string
		now := time.Now()
		for host, conn := range p.pool {
			// Skip recently used connections
			if now.Sub(conn.lastUsed) < p.maxTime {
				continue
			}

			// Skip connections with active streams
			if atomic.LoadInt32(&conn.refCount) > 0 {
				continue
			}

			// Close the conn
			conn.Close()

			// Remove from pool
			removed = append(removed, host)
		}
		for _, host := range removed {
			delete(p.pool, host)
		}
		p.Unlock()
	}
}