1package pool 2 3import ( 4 "context" 5 "errors" 6 "net" 7 "sync" 8 "sync/atomic" 9 "time" 10 11 "github.com/go-redis/redis/v8/internal" 12) 13 14var ( 15 // ErrClosed performs any operation on the closed client will return this error. 16 ErrClosed = errors.New("redis: client is closed") 17 18 // ErrPoolTimeout timed out waiting to get a connection from the connection pool. 19 ErrPoolTimeout = errors.New("redis: connection pool timeout") 20) 21 22var timers = sync.Pool{ 23 New: func() interface{} { 24 t := time.NewTimer(time.Hour) 25 t.Stop() 26 return t 27 }, 28} 29 30// Stats contains pool state information and accumulated stats. 31type Stats struct { 32 Hits uint32 // number of times free connection was found in the pool 33 Misses uint32 // number of times free connection was NOT found in the pool 34 Timeouts uint32 // number of times a wait timeout occurred 35 36 TotalConns uint32 // number of total connections in the pool 37 IdleConns uint32 // number of idle connections in the pool 38 StaleConns uint32 // number of stale connections removed from the pool 39} 40 41type Pooler interface { 42 NewConn(context.Context) (*Conn, error) 43 CloseConn(*Conn) error 44 45 Get(context.Context) (*Conn, error) 46 Put(context.Context, *Conn) 47 Remove(context.Context, *Conn, error) 48 49 Len() int 50 IdleLen() int 51 Stats() *Stats 52 53 Close() error 54} 55 56type Options struct { 57 Dialer func(context.Context) (net.Conn, error) 58 OnClose func(*Conn) error 59 60 PoolFIFO bool 61 PoolSize int 62 MinIdleConns int 63 MaxConnAge time.Duration 64 PoolTimeout time.Duration 65 IdleTimeout time.Duration 66 IdleCheckFrequency time.Duration 67} 68 69type lastDialErrorWrap struct { 70 err error 71} 72 73type ConnPool struct { 74 opt *Options 75 76 dialErrorsNum uint32 // atomic 77 78 lastDialError atomic.Value 79 80 queue chan struct{} 81 82 connsMu sync.Mutex 83 conns []*Conn 84 idleConns []*Conn 85 poolSize int 86 idleConnsLen int 87 88 stats Stats 89 90 _closed uint32 // atomic 91 closedCh chan struct{} 92} 93 94var _ Pooler = (*ConnPool)(nil) 95 96func NewConnPool(opt *Options) *ConnPool { 97 p := &ConnPool{ 98 opt: opt, 99 100 queue: make(chan struct{}, opt.PoolSize), 101 conns: make([]*Conn, 0, opt.PoolSize), 102 idleConns: make([]*Conn, 0, opt.PoolSize), 103 closedCh: make(chan struct{}), 104 } 105 106 p.connsMu.Lock() 107 p.checkMinIdleConns() 108 p.connsMu.Unlock() 109 110 if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { 111 go p.reaper(opt.IdleCheckFrequency) 112 } 113 114 return p 115} 116 117func (p *ConnPool) checkMinIdleConns() { 118 if p.opt.MinIdleConns == 0 { 119 return 120 } 121 for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { 122 p.poolSize++ 123 p.idleConnsLen++ 124 go func() { 125 err := p.addIdleConn() 126 if err != nil { 127 p.connsMu.Lock() 128 p.poolSize-- 129 p.idleConnsLen-- 130 p.connsMu.Unlock() 131 } 132 }() 133 } 134} 135 136func (p *ConnPool) addIdleConn() error { 137 cn, err := p.dialConn(context.TODO(), true) 138 if err != nil { 139 return err 140 } 141 142 p.connsMu.Lock() 143 p.conns = append(p.conns, cn) 144 p.idleConns = append(p.idleConns, cn) 145 p.connsMu.Unlock() 146 return nil 147} 148 149func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { 150 return p.newConn(ctx, false) 151} 152 153func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { 154 cn, err := p.dialConn(ctx, pooled) 155 if err != nil { 156 return nil, err 157 } 158 159 p.connsMu.Lock() 160 p.conns = append(p.conns, cn) 161 if pooled { 162 // If pool is full remove the cn on next Put. 163 if p.poolSize >= p.opt.PoolSize { 164 cn.pooled = false 165 } else { 166 p.poolSize++ 167 } 168 } 169 p.connsMu.Unlock() 170 171 return cn, nil 172} 173 174func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { 175 if p.closed() { 176 return nil, ErrClosed 177 } 178 179 if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { 180 return nil, p.getLastDialError() 181 } 182 183 netConn, err := p.opt.Dialer(ctx) 184 if err != nil { 185 p.setLastDialError(err) 186 if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { 187 go p.tryDial() 188 } 189 return nil, err 190 } 191 192 cn := NewConn(netConn) 193 cn.pooled = pooled 194 return cn, nil 195} 196 197func (p *ConnPool) tryDial() { 198 for { 199 if p.closed() { 200 return 201 } 202 203 conn, err := p.opt.Dialer(context.Background()) 204 if err != nil { 205 p.setLastDialError(err) 206 time.Sleep(time.Second) 207 continue 208 } 209 210 atomic.StoreUint32(&p.dialErrorsNum, 0) 211 _ = conn.Close() 212 return 213 } 214} 215 216func (p *ConnPool) setLastDialError(err error) { 217 p.lastDialError.Store(&lastDialErrorWrap{err: err}) 218} 219 220func (p *ConnPool) getLastDialError() error { 221 err, _ := p.lastDialError.Load().(*lastDialErrorWrap) 222 if err != nil { 223 return err.err 224 } 225 return nil 226} 227 228// Get returns existed connection from the pool or creates a new one. 229func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { 230 if p.closed() { 231 return nil, ErrClosed 232 } 233 234 if err := p.waitTurn(ctx); err != nil { 235 return nil, err 236 } 237 238 for { 239 p.connsMu.Lock() 240 cn := p.popIdle() 241 p.connsMu.Unlock() 242 243 if cn == nil { 244 break 245 } 246 247 if p.isStaleConn(cn) { 248 _ = p.CloseConn(cn) 249 continue 250 } 251 252 atomic.AddUint32(&p.stats.Hits, 1) 253 return cn, nil 254 } 255 256 atomic.AddUint32(&p.stats.Misses, 1) 257 258 newcn, err := p.newConn(ctx, true) 259 if err != nil { 260 p.freeTurn() 261 return nil, err 262 } 263 264 return newcn, nil 265} 266 267func (p *ConnPool) getTurn() { 268 p.queue <- struct{}{} 269} 270 271func (p *ConnPool) waitTurn(ctx context.Context) error { 272 select { 273 case <-ctx.Done(): 274 return ctx.Err() 275 default: 276 } 277 278 select { 279 case p.queue <- struct{}{}: 280 return nil 281 default: 282 } 283 284 timer := timers.Get().(*time.Timer) 285 timer.Reset(p.opt.PoolTimeout) 286 287 select { 288 case <-ctx.Done(): 289 if !timer.Stop() { 290 <-timer.C 291 } 292 timers.Put(timer) 293 return ctx.Err() 294 case p.queue <- struct{}{}: 295 if !timer.Stop() { 296 <-timer.C 297 } 298 timers.Put(timer) 299 return nil 300 case <-timer.C: 301 timers.Put(timer) 302 atomic.AddUint32(&p.stats.Timeouts, 1) 303 return ErrPoolTimeout 304 } 305} 306 307func (p *ConnPool) freeTurn() { 308 <-p.queue 309} 310 311func (p *ConnPool) popIdle() *Conn { 312 n := len(p.idleConns) 313 if n == 0 { 314 return nil 315 } 316 317 var cn *Conn 318 if p.opt.PoolFIFO { 319 cn = p.idleConns[0] 320 copy(p.idleConns, p.idleConns[1:]) 321 p.idleConns = p.idleConns[:n-1] 322 } else { 323 idx := n - 1 324 cn = p.idleConns[idx] 325 p.idleConns = p.idleConns[:idx] 326 } 327 p.idleConnsLen-- 328 p.checkMinIdleConns() 329 return cn 330} 331 332func (p *ConnPool) Put(ctx context.Context, cn *Conn) { 333 if cn.rd.Buffered() > 0 { 334 internal.Logger.Printf(ctx, "Conn has unread data") 335 p.Remove(ctx, cn, BadConnError{}) 336 return 337 } 338 339 if !cn.pooled { 340 p.Remove(ctx, cn, nil) 341 return 342 } 343 344 p.connsMu.Lock() 345 p.idleConns = append(p.idleConns, cn) 346 p.idleConnsLen++ 347 p.connsMu.Unlock() 348 p.freeTurn() 349} 350 351func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { 352 p.removeConnWithLock(cn) 353 p.freeTurn() 354 _ = p.closeConn(cn) 355} 356 357func (p *ConnPool) CloseConn(cn *Conn) error { 358 p.removeConnWithLock(cn) 359 return p.closeConn(cn) 360} 361 362func (p *ConnPool) removeConnWithLock(cn *Conn) { 363 p.connsMu.Lock() 364 p.removeConn(cn) 365 p.connsMu.Unlock() 366} 367 368func (p *ConnPool) removeConn(cn *Conn) { 369 for i, c := range p.conns { 370 if c == cn { 371 p.conns = append(p.conns[:i], p.conns[i+1:]...) 372 if cn.pooled { 373 p.poolSize-- 374 p.checkMinIdleConns() 375 } 376 return 377 } 378 } 379} 380 381func (p *ConnPool) closeConn(cn *Conn) error { 382 if p.opt.OnClose != nil { 383 _ = p.opt.OnClose(cn) 384 } 385 return cn.Close() 386} 387 388// Len returns total number of connections. 389func (p *ConnPool) Len() int { 390 p.connsMu.Lock() 391 n := len(p.conns) 392 p.connsMu.Unlock() 393 return n 394} 395 396// IdleLen returns number of idle connections. 397func (p *ConnPool) IdleLen() int { 398 p.connsMu.Lock() 399 n := p.idleConnsLen 400 p.connsMu.Unlock() 401 return n 402} 403 404func (p *ConnPool) Stats() *Stats { 405 idleLen := p.IdleLen() 406 return &Stats{ 407 Hits: atomic.LoadUint32(&p.stats.Hits), 408 Misses: atomic.LoadUint32(&p.stats.Misses), 409 Timeouts: atomic.LoadUint32(&p.stats.Timeouts), 410 411 TotalConns: uint32(p.Len()), 412 IdleConns: uint32(idleLen), 413 StaleConns: atomic.LoadUint32(&p.stats.StaleConns), 414 } 415} 416 417func (p *ConnPool) closed() bool { 418 return atomic.LoadUint32(&p._closed) == 1 419} 420 421func (p *ConnPool) Filter(fn func(*Conn) bool) error { 422 p.connsMu.Lock() 423 defer p.connsMu.Unlock() 424 425 var firstErr error 426 for _, cn := range p.conns { 427 if fn(cn) { 428 if err := p.closeConn(cn); err != nil && firstErr == nil { 429 firstErr = err 430 } 431 } 432 } 433 return firstErr 434} 435 436func (p *ConnPool) Close() error { 437 if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { 438 return ErrClosed 439 } 440 close(p.closedCh) 441 442 var firstErr error 443 p.connsMu.Lock() 444 for _, cn := range p.conns { 445 if err := p.closeConn(cn); err != nil && firstErr == nil { 446 firstErr = err 447 } 448 } 449 p.conns = nil 450 p.poolSize = 0 451 p.idleConns = nil 452 p.idleConnsLen = 0 453 p.connsMu.Unlock() 454 455 return firstErr 456} 457 458func (p *ConnPool) reaper(frequency time.Duration) { 459 ticker := time.NewTicker(frequency) 460 defer ticker.Stop() 461 462 for { 463 select { 464 case <-ticker.C: 465 // It is possible that ticker and closedCh arrive together, 466 // and select pseudo-randomly pick ticker case, we double 467 // check here to prevent being executed after closed. 468 if p.closed() { 469 return 470 } 471 _, err := p.ReapStaleConns() 472 if err != nil { 473 internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err) 474 continue 475 } 476 case <-p.closedCh: 477 return 478 } 479 } 480} 481 482func (p *ConnPool) ReapStaleConns() (int, error) { 483 var n int 484 for { 485 p.getTurn() 486 487 p.connsMu.Lock() 488 cn := p.reapStaleConn() 489 p.connsMu.Unlock() 490 491 p.freeTurn() 492 493 if cn != nil { 494 _ = p.closeConn(cn) 495 n++ 496 } else { 497 break 498 } 499 } 500 atomic.AddUint32(&p.stats.StaleConns, uint32(n)) 501 return n, nil 502} 503 504func (p *ConnPool) reapStaleConn() *Conn { 505 if len(p.idleConns) == 0 { 506 return nil 507 } 508 509 cn := p.idleConns[0] 510 if !p.isStaleConn(cn) { 511 return nil 512 } 513 514 p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) 515 p.idleConnsLen-- 516 p.removeConn(cn) 517 518 return cn 519} 520 521func (p *ConnPool) isStaleConn(cn *Conn) bool { 522 if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { 523 return false 524 } 525 526 now := time.Now() 527 if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { 528 return true 529 } 530 if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { 531 return true 532 } 533 534 return false 535} 536