1package pool 2 3import ( 4 "context" 5 "errors" 6 "net" 7 "sync" 8 "sync/atomic" 9 "time" 10 11 "github.com/go-redis/redis/v7/internal" 12) 13 14var ErrClosed = errors.New("redis: client is closed") 15var ErrPoolTimeout = errors.New("redis: connection pool timeout") 16 17var timers = sync.Pool{ 18 New: func() interface{} { 19 t := time.NewTimer(time.Hour) 20 t.Stop() 21 return t 22 }, 23} 24 25// Stats contains pool state information and accumulated stats. 26type Stats struct { 27 Hits uint32 // number of times free connection was found in the pool 28 Misses uint32 // number of times free connection was NOT found in the pool 29 Timeouts uint32 // number of times a wait timeout occurred 30 31 TotalConns uint32 // number of total connections in the pool 32 IdleConns uint32 // number of idle connections in the pool 33 StaleConns uint32 // number of stale connections removed from the pool 34} 35 36type Pooler interface { 37 NewConn(context.Context) (*Conn, error) 38 CloseConn(*Conn) error 39 40 Get(context.Context) (*Conn, error) 41 Put(*Conn) 42 Remove(*Conn, error) 43 44 Len() int 45 IdleLen() int 46 Stats() *Stats 47 48 Close() error 49} 50 51type Options struct { 52 Dialer func(context.Context) (net.Conn, error) 53 OnClose func(*Conn) error 54 55 PoolSize int 56 MinIdleConns int 57 MaxConnAge time.Duration 58 PoolTimeout time.Duration 59 IdleTimeout time.Duration 60 IdleCheckFrequency time.Duration 61} 62 63type ConnPool struct { 64 opt *Options 65 66 dialErrorsNum uint32 // atomic 67 68 lastDialErrorMu sync.RWMutex 69 lastDialError error 70 71 queue chan struct{} 72 73 connsMu sync.Mutex 74 conns []*Conn 75 idleConns []*Conn 76 poolSize int 77 idleConnsLen int 78 79 stats Stats 80 81 _closed uint32 // atomic 82 closedCh chan struct{} 83} 84 85var _ Pooler = (*ConnPool)(nil) 86 87func NewConnPool(opt *Options) *ConnPool { 88 p := &ConnPool{ 89 opt: opt, 90 91 queue: make(chan struct{}, opt.PoolSize), 92 conns: make([]*Conn, 0, opt.PoolSize), 93 idleConns: make([]*Conn, 0, opt.PoolSize), 94 closedCh: make(chan struct{}), 95 } 96 97 p.checkMinIdleConns() 98 99 if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { 100 go p.reaper(opt.IdleCheckFrequency) 101 } 102 103 return p 104} 105 106func (p *ConnPool) checkMinIdleConns() { 107 if p.opt.MinIdleConns == 0 { 108 return 109 } 110 for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { 111 p.poolSize++ 112 p.idleConnsLen++ 113 go func() { 114 err := p.addIdleConn() 115 if err != nil { 116 p.connsMu.Lock() 117 p.poolSize-- 118 p.idleConnsLen-- 119 p.connsMu.Unlock() 120 } 121 }() 122 } 123} 124 125func (p *ConnPool) addIdleConn() error { 126 cn, err := p.dialConn(context.TODO(), true) 127 if err != nil { 128 return err 129 } 130 131 p.connsMu.Lock() 132 p.conns = append(p.conns, cn) 133 p.idleConns = append(p.idleConns, cn) 134 p.connsMu.Unlock() 135 return nil 136} 137 138func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { 139 return p.newConn(ctx, false) 140} 141 142func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { 143 cn, err := p.dialConn(ctx, pooled) 144 if err != nil { 145 return nil, err 146 } 147 148 p.connsMu.Lock() 149 p.conns = append(p.conns, cn) 150 if pooled { 151 // If pool is full remove the cn on next Put. 152 if p.poolSize >= p.opt.PoolSize { 153 cn.pooled = false 154 } else { 155 p.poolSize++ 156 } 157 } 158 p.connsMu.Unlock() 159 return cn, nil 160} 161 162func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { 163 if p.closed() { 164 return nil, ErrClosed 165 } 166 167 if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { 168 return nil, p.getLastDialError() 169 } 170 171 netConn, err := p.opt.Dialer(ctx) 172 if err != nil { 173 p.setLastDialError(err) 174 if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { 175 go p.tryDial() 176 } 177 return nil, err 178 } 179 180 cn := NewConn(netConn) 181 cn.pooled = pooled 182 return cn, nil 183} 184 185func (p *ConnPool) tryDial() { 186 for { 187 if p.closed() { 188 return 189 } 190 191 conn, err := p.opt.Dialer(context.Background()) 192 if err != nil { 193 p.setLastDialError(err) 194 time.Sleep(time.Second) 195 continue 196 } 197 198 atomic.StoreUint32(&p.dialErrorsNum, 0) 199 _ = conn.Close() 200 return 201 } 202} 203 204func (p *ConnPool) setLastDialError(err error) { 205 p.lastDialErrorMu.Lock() 206 p.lastDialError = err 207 p.lastDialErrorMu.Unlock() 208} 209 210func (p *ConnPool) getLastDialError() error { 211 p.lastDialErrorMu.RLock() 212 err := p.lastDialError 213 p.lastDialErrorMu.RUnlock() 214 return err 215} 216 217// Get returns existed connection from the pool or creates a new one. 218func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { 219 if p.closed() { 220 return nil, ErrClosed 221 } 222 223 err := p.waitTurn(ctx) 224 if err != nil { 225 return nil, err 226 } 227 228 for { 229 p.connsMu.Lock() 230 cn := p.popIdle() 231 p.connsMu.Unlock() 232 233 if cn == nil { 234 break 235 } 236 237 if p.isStaleConn(cn) { 238 _ = p.CloseConn(cn) 239 continue 240 } 241 242 atomic.AddUint32(&p.stats.Hits, 1) 243 return cn, nil 244 } 245 246 atomic.AddUint32(&p.stats.Misses, 1) 247 248 newcn, err := p.newConn(ctx, true) 249 if err != nil { 250 p.freeTurn() 251 return nil, err 252 } 253 254 return newcn, nil 255} 256 257func (p *ConnPool) getTurn() { 258 p.queue <- struct{}{} 259} 260 261func (p *ConnPool) waitTurn(ctx context.Context) error { 262 select { 263 case <-ctx.Done(): 264 return ctx.Err() 265 default: 266 } 267 268 select { 269 case p.queue <- struct{}{}: 270 return nil 271 default: 272 } 273 274 timer := timers.Get().(*time.Timer) 275 timer.Reset(p.opt.PoolTimeout) 276 277 select { 278 case <-ctx.Done(): 279 if !timer.Stop() { 280 <-timer.C 281 } 282 timers.Put(timer) 283 return ctx.Err() 284 case p.queue <- struct{}{}: 285 if !timer.Stop() { 286 <-timer.C 287 } 288 timers.Put(timer) 289 return nil 290 case <-timer.C: 291 timers.Put(timer) 292 atomic.AddUint32(&p.stats.Timeouts, 1) 293 return ErrPoolTimeout 294 } 295} 296 297func (p *ConnPool) freeTurn() { 298 <-p.queue 299} 300 301func (p *ConnPool) popIdle() *Conn { 302 if len(p.idleConns) == 0 { 303 return nil 304 } 305 306 idx := len(p.idleConns) - 1 307 cn := p.idleConns[idx] 308 p.idleConns = p.idleConns[:idx] 309 p.idleConnsLen-- 310 p.checkMinIdleConns() 311 return cn 312} 313 314func (p *ConnPool) Put(cn *Conn) { 315 if cn.rd.Buffered() > 0 { 316 internal.Logger.Printf("Conn has unread data") 317 p.Remove(cn, BadConnError{}) 318 return 319 } 320 321 if !cn.pooled { 322 p.Remove(cn, nil) 323 return 324 } 325 326 p.connsMu.Lock() 327 p.idleConns = append(p.idleConns, cn) 328 p.idleConnsLen++ 329 p.connsMu.Unlock() 330 p.freeTurn() 331} 332 333func (p *ConnPool) Remove(cn *Conn, reason error) { 334 p.removeConnWithLock(cn) 335 p.freeTurn() 336 _ = p.closeConn(cn) 337} 338 339func (p *ConnPool) CloseConn(cn *Conn) error { 340 p.removeConnWithLock(cn) 341 return p.closeConn(cn) 342} 343 344func (p *ConnPool) removeConnWithLock(cn *Conn) { 345 p.connsMu.Lock() 346 p.removeConn(cn) 347 p.connsMu.Unlock() 348} 349 350func (p *ConnPool) removeConn(cn *Conn) { 351 for i, c := range p.conns { 352 if c == cn { 353 p.conns = append(p.conns[:i], p.conns[i+1:]...) 354 if cn.pooled { 355 p.poolSize-- 356 p.checkMinIdleConns() 357 } 358 return 359 } 360 } 361} 362 363func (p *ConnPool) closeConn(cn *Conn) error { 364 if p.opt.OnClose != nil { 365 _ = p.opt.OnClose(cn) 366 } 367 return cn.Close() 368} 369 370// Len returns total number of connections. 371func (p *ConnPool) Len() int { 372 p.connsMu.Lock() 373 n := len(p.conns) 374 p.connsMu.Unlock() 375 return n 376} 377 378// IdleLen returns number of idle connections. 379func (p *ConnPool) IdleLen() int { 380 p.connsMu.Lock() 381 n := p.idleConnsLen 382 p.connsMu.Unlock() 383 return n 384} 385 386func (p *ConnPool) Stats() *Stats { 387 idleLen := p.IdleLen() 388 return &Stats{ 389 Hits: atomic.LoadUint32(&p.stats.Hits), 390 Misses: atomic.LoadUint32(&p.stats.Misses), 391 Timeouts: atomic.LoadUint32(&p.stats.Timeouts), 392 393 TotalConns: uint32(p.Len()), 394 IdleConns: uint32(idleLen), 395 StaleConns: atomic.LoadUint32(&p.stats.StaleConns), 396 } 397} 398 399func (p *ConnPool) closed() bool { 400 return atomic.LoadUint32(&p._closed) == 1 401} 402 403func (p *ConnPool) Filter(fn func(*Conn) bool) error { 404 var firstErr error 405 p.connsMu.Lock() 406 for _, cn := range p.conns { 407 if fn(cn) { 408 if err := p.closeConn(cn); err != nil && firstErr == nil { 409 firstErr = err 410 } 411 } 412 } 413 p.connsMu.Unlock() 414 return firstErr 415} 416 417func (p *ConnPool) Close() error { 418 if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { 419 return ErrClosed 420 } 421 close(p.closedCh) 422 423 var firstErr error 424 p.connsMu.Lock() 425 for _, cn := range p.conns { 426 if err := p.closeConn(cn); err != nil && firstErr == nil { 427 firstErr = err 428 } 429 } 430 p.conns = nil 431 p.poolSize = 0 432 p.idleConns = nil 433 p.idleConnsLen = 0 434 p.connsMu.Unlock() 435 436 return firstErr 437} 438 439func (p *ConnPool) reaper(frequency time.Duration) { 440 ticker := time.NewTicker(frequency) 441 defer ticker.Stop() 442 443 for { 444 select { 445 case <-ticker.C: 446 // It is possible that ticker and closedCh arrive together, 447 // and select pseudo-randomly pick ticker case, we double 448 // check here to prevent being executed after closed. 449 if p.closed() { 450 return 451 } 452 _, err := p.ReapStaleConns() 453 if err != nil { 454 internal.Logger.Printf("ReapStaleConns failed: %s", err) 455 continue 456 } 457 case <-p.closedCh: 458 return 459 } 460 } 461} 462 463func (p *ConnPool) ReapStaleConns() (int, error) { 464 var n int 465 for { 466 p.getTurn() 467 468 p.connsMu.Lock() 469 cn := p.reapStaleConn() 470 p.connsMu.Unlock() 471 p.freeTurn() 472 473 if cn != nil { 474 _ = p.closeConn(cn) 475 n++ 476 } else { 477 break 478 } 479 } 480 atomic.AddUint32(&p.stats.StaleConns, uint32(n)) 481 return n, nil 482} 483 484func (p *ConnPool) reapStaleConn() *Conn { 485 if len(p.idleConns) == 0 { 486 return nil 487 } 488 489 cn := p.idleConns[0] 490 if !p.isStaleConn(cn) { 491 return nil 492 } 493 494 p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) 495 p.idleConnsLen-- 496 p.removeConn(cn) 497 498 return cn 499} 500 501func (p *ConnPool) isStaleConn(cn *Conn) bool { 502 if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { 503 return false 504 } 505 506 now := time.Now() 507 if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { 508 return true 509 } 510 if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { 511 return true 512 } 513 514 return false 515} 516