1package pool 2 3import ( 4 "container/list" 5 "fmt" 6 "io" 7 "log" 8 "net" 9 "net/rpc" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/hashicorp/consul/lib" 15 hclog "github.com/hashicorp/go-hclog" 16 msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" 17 "github.com/hashicorp/nomad/helper/tlsutil" 18 "github.com/hashicorp/nomad/nomad/structs" 19 "github.com/hashicorp/yamux" 20) 21 22// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls. 23func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { 24 return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle) 25} 26 27// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs. 28func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { 29 return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle) 30} 31 32// streamClient is used to wrap a stream with an RPC client 33type StreamClient struct { 34 stream net.Conn 35 codec rpc.ClientCodec 36} 37 38func (sc *StreamClient) Close() { 39 sc.stream.Close() 40 sc.codec.Close() 41} 42 43// Conn is a pooled connection to a Nomad server 44type Conn struct { 45 refCount int32 46 shouldClose int32 47 48 addr net.Addr 49 session *yamux.Session 50 lastUsed time.Time 51 version int 52 53 pool *ConnPool 54 55 clients *list.List 56 clientLock sync.Mutex 57} 58 59// markForUse does all the bookkeeping required to ready a connection for use. 
60func (c *Conn) markForUse() { 61 c.lastUsed = time.Now() 62 atomic.AddInt32(&c.refCount, 1) 63} 64 65func (c *Conn) Close() error { 66 return c.session.Close() 67} 68 69// getClient is used to get a cached or new client 70func (c *Conn) getRPCClient() (*StreamClient, error) { 71 // Check for cached client 72 c.clientLock.Lock() 73 front := c.clients.Front() 74 if front != nil { 75 c.clients.Remove(front) 76 } 77 c.clientLock.Unlock() 78 if front != nil { 79 return front.Value.(*StreamClient), nil 80 } 81 82 // Open a new session 83 stream, err := c.session.Open() 84 if err != nil { 85 return nil, err 86 } 87 88 if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil { 89 stream.Close() 90 return nil, err 91 } 92 93 // Create a client codec 94 codec := NewClientCodec(stream) 95 96 // Return a new stream client 97 sc := &StreamClient{ 98 stream: stream, 99 codec: codec, 100 } 101 return sc, nil 102} 103 104// returnClient is used when done with a stream 105// to allow re-use by a future RPC 106func (c *Conn) returnClient(client *StreamClient) { 107 didSave := false 108 c.clientLock.Lock() 109 if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 { 110 c.clients.PushFront(client) 111 didSave = true 112 113 // If this is a Yamux stream, shrink the internal buffers so that 114 // we can GC the idle memory 115 if ys, ok := client.stream.(*yamux.Stream); ok { 116 ys.Shrink() 117 } 118 } 119 c.clientLock.Unlock() 120 if !didSave { 121 client.Close() 122 } 123} 124 125// ConnPool is used to maintain a connection pool to other 126// Nomad servers. This is used to reduce the latency of 127// RPC requests between servers. It is only used to pool 128// connections in the rpcNomad mode. Raft connections 129// are pooled separately. 
type ConnPool struct {
	// The embedded mutex is held by the pool's methods while they read or
	// modify pool, limiter, shutdown, and connListener.
	sync.Mutex

	// logger is the logger to be used; it is also handed to yamux when a
	// new session is created.
	logger *log.Logger

	// The maximum time to keep an unused connection open. The reaper
	// goroutine is only started when this is greater than zero.
	maxTime time.Duration

	// The maximum number of idle stream clients to keep cached per
	// connection.
	maxStreams int

	// Pool maps an address (its String() form) to an open connection.
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper applied to new outgoing connections; nil means the
	// connection is not wrapped. Replaced by ReloadTLS.
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shutdown. shutdownCh is closed by
	// Shutdown so that waiters and the reaper can stop.
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made. The send is non-blocking, so a listener that is not ready
	// to receive misses the notification.
	connListener chan<- *yamux.Session
}

// NewPool is used to make a new connection pool.
// Maintain at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
168func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool { 169 pool := &ConnPool{ 170 logger: logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}), 171 maxTime: maxTime, 172 maxStreams: maxStreams, 173 pool: make(map[string]*Conn), 174 limiter: make(map[string]chan struct{}), 175 tlsWrap: tlsWrap, 176 shutdownCh: make(chan struct{}), 177 } 178 if maxTime > 0 { 179 go pool.reap() 180 } 181 return pool 182} 183 184// Shutdown is used to close the connection pool 185func (p *ConnPool) Shutdown() error { 186 p.Lock() 187 defer p.Unlock() 188 189 for _, conn := range p.pool { 190 conn.Close() 191 } 192 p.pool = make(map[string]*Conn) 193 194 if p.shutdown { 195 return nil 196 } 197 198 if p.connListener != nil { 199 close(p.connListener) 200 p.connListener = nil 201 } 202 203 p.shutdown = true 204 close(p.shutdownCh) 205 return nil 206} 207 208// ReloadTLS reloads TLS configuration on the fly 209func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) { 210 p.Lock() 211 defer p.Unlock() 212 213 oldPool := p.pool 214 for _, conn := range oldPool { 215 conn.Close() 216 } 217 p.pool = make(map[string]*Conn) 218 p.tlsWrap = tlsWrap 219} 220 221// SetConnListener is used to listen to new connections being made. The 222// channel will be closed when the conn pool is closed or a new listener is set. 223func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) { 224 p.Lock() 225 defer p.Unlock() 226 227 // Close the old listener 228 if p.connListener != nil { 229 close(p.connListener) 230 } 231 232 // Store the new listener 233 p.connListener = l 234} 235 236// Acquire is used to get a connection that is 237// pooled or to return a new connection 238func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) { 239 // Check to see if there's a pooled connection available. 
This is up 240 // here since it should the vastly more common case than the rest 241 // of the code here. 242 p.Lock() 243 c := p.pool[addr.String()] 244 if c != nil { 245 c.markForUse() 246 p.Unlock() 247 return c, nil 248 } 249 250 // If not (while we are still locked), set up the throttling structure 251 // for this address, which will make everyone else wait until our 252 // attempt is done. 253 var wait chan struct{} 254 var ok bool 255 if wait, ok = p.limiter[addr.String()]; !ok { 256 wait = make(chan struct{}) 257 p.limiter[addr.String()] = wait 258 } 259 isLeadThread := !ok 260 p.Unlock() 261 262 // If we are the lead thread, make the new connection and then wake 263 // everybody else up to see if we got it. 264 if isLeadThread { 265 c, err := p.getNewConn(region, addr, version) 266 p.Lock() 267 delete(p.limiter, addr.String()) 268 close(wait) 269 if err != nil { 270 p.Unlock() 271 return nil, err 272 } 273 274 p.pool[addr.String()] = c 275 276 // If there is a connection listener, notify them of the new connection. 277 if p.connListener != nil { 278 select { 279 case p.connListener <- c.session: 280 default: 281 } 282 } 283 284 p.Unlock() 285 return c, nil 286 } 287 288 // Otherwise, wait for the lead thread to attempt the connection 289 // and use what's in the pool at that point. 290 select { 291 case <-p.shutdownCh: 292 return nil, fmt.Errorf("rpc error: shutdown") 293 case <-wait: 294 } 295 296 // See if the lead thread was able to get us a connection. 
297 p.Lock() 298 if c := p.pool[addr.String()]; c != nil { 299 c.markForUse() 300 p.Unlock() 301 return c, nil 302 } 303 304 p.Unlock() 305 return nil, fmt.Errorf("rpc error: lead thread didn't get connection") 306} 307 308// getNewConn is used to return a new connection 309func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) { 310 // Try to dial the conn 311 conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second) 312 if err != nil { 313 return nil, err 314 } 315 316 // Cast to TCPConn 317 if tcp, ok := conn.(*net.TCPConn); ok { 318 tcp.SetKeepAlive(true) 319 tcp.SetNoDelay(true) 320 } 321 322 // Check if TLS is enabled 323 if p.tlsWrap != nil { 324 // Switch the connection into TLS mode 325 if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { 326 conn.Close() 327 return nil, err 328 } 329 330 // Wrap the connection in a TLS client 331 tlsConn, err := p.tlsWrap(region, conn) 332 if err != nil { 333 conn.Close() 334 return nil, err 335 } 336 conn = tlsConn 337 } 338 339 // Write the multiplex byte to set the mode 340 if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil { 341 conn.Close() 342 return nil, err 343 } 344 345 // Setup the logger 346 conf := yamux.DefaultConfig() 347 conf.LogOutput = nil 348 conf.Logger = p.logger 349 350 // Create a multiplexed session 351 session, err := yamux.Client(conn, conf) 352 if err != nil { 353 conn.Close() 354 return nil, err 355 } 356 357 // Wrap the connection 358 c := &Conn{ 359 refCount: 1, 360 addr: addr, 361 session: session, 362 clients: list.New(), 363 lastUsed: time.Now(), 364 version: version, 365 pool: p, 366 } 367 return c, nil 368} 369 370// clearConn is used to clear any cached connection, potentially in response to 371// an error 372func (p *ConnPool) clearConn(conn *Conn) { 373 // Ensure returned streams are closed 374 atomic.StoreInt32(&conn.shouldClose, 1) 375 376 // Clear from the cache 377 p.Lock() 378 if c, ok := p.pool[conn.addr.String()]; 
ok && c == conn { 379 delete(p.pool, conn.addr.String()) 380 } 381 p.Unlock() 382 383 // Close down immediately if idle 384 if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 { 385 conn.Close() 386 } 387} 388 389// releaseConn is invoked when we are done with a conn to reduce the ref count 390func (p *ConnPool) releaseConn(conn *Conn) { 391 refCount := atomic.AddInt32(&conn.refCount, -1) 392 if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 { 393 conn.Close() 394 } 395} 396 397// getClient is used to get a usable client for an address and protocol version 398func (p *ConnPool) getRPCClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) { 399 retries := 0 400START: 401 // Try to get a conn first 402 conn, err := p.acquire(region, addr, version) 403 if err != nil { 404 return nil, nil, fmt.Errorf("failed to get conn: %v", err) 405 } 406 407 // Get a client 408 client, err := conn.getRPCClient() 409 if err != nil { 410 p.clearConn(conn) 411 p.releaseConn(conn) 412 413 // Try to redial, possible that the TCP session closed due to timeout 414 if retries == 0 { 415 retries++ 416 goto START 417 } 418 return nil, nil, fmt.Errorf("failed to start stream: %v", err) 419 } 420 return conn, client, nil 421} 422 423// StreamingRPC is used to make an streaming RPC call. Callers must 424// close the connection when done. 
425func (p *ConnPool) StreamingRPC(region string, addr net.Addr, version int) (net.Conn, error) { 426 conn, err := p.acquire(region, addr, version) 427 if err != nil { 428 return nil, fmt.Errorf("failed to get conn: %v", err) 429 } 430 431 s, err := conn.session.Open() 432 if err != nil { 433 return nil, fmt.Errorf("failed to open a streaming connection: %v", err) 434 } 435 436 if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil { 437 conn.Close() 438 return nil, err 439 } 440 441 return s, nil 442} 443 444// RPC is used to make an RPC call to a remote host 445func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error { 446 // Get a usable client 447 conn, sc, err := p.getRPCClient(region, addr, version) 448 if err != nil { 449 return fmt.Errorf("rpc error: %w", err) 450 } 451 452 // Make the RPC call 453 err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply) 454 if err != nil { 455 sc.Close() 456 457 // If we read EOF, the session is toast. 
Clear it and open a 458 // new session next time 459 // See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477 460 if lib.IsErrEOF(err) { 461 p.clearConn(conn) 462 } 463 464 p.releaseConn(conn) 465 466 // If the error is an RPC Coded error 467 // return the coded error without wrapping 468 if structs.IsErrRPCCoded(err) { 469 return err 470 } 471 472 // TODO wrap with RPCCoded error instead 473 return fmt.Errorf("rpc error: %w", err) 474 } 475 476 // Done with the connection 477 conn.returnClient(sc) 478 p.releaseConn(conn) 479 return nil 480} 481 482// Reap is used to close conns open over maxTime 483func (p *ConnPool) reap() { 484 for { 485 // Sleep for a while 486 select { 487 case <-p.shutdownCh: 488 return 489 case <-time.After(time.Second): 490 } 491 492 // Reap all old conns 493 p.Lock() 494 var removed []string 495 now := time.Now() 496 for host, conn := range p.pool { 497 // Skip recently used connections 498 if now.Sub(conn.lastUsed) < p.maxTime { 499 continue 500 } 501 502 // Skip connections with active streams 503 if atomic.LoadInt32(&conn.refCount) > 0 { 504 continue 505 } 506 507 // Close the conn 508 conn.Close() 509 510 // Remove from pool 511 removed = append(removed, host) 512 } 513 for _, host := range removed { 514 delete(p.pool, host) 515 } 516 p.Unlock() 517 } 518} 519