1package pool
2
3import (
4	"context"
5	"errors"
6	"net"
7	"sync"
8	"sync/atomic"
9	"time"
10
11	"github.com/go-redis/redis/v8/internal"
12)
13
// ErrClosed is returned by any operation attempted on a client or pool that
// has already been closed.
var ErrClosed = errors.New("redis: client is closed")

// ErrPoolTimeout is returned when no connection slot becomes available
// before the pool's PoolTimeout elapses.
var ErrPoolTimeout = errors.New("redis: connection pool timeout")
21
// timers recycles *time.Timer values for waitTurn, so that waiting for a free
// connection slot does not allocate a fresh timer every time. Newly created
// timers are handed out already stopped; waitTurn Resets a timer before use
// and leaves it stopped with an empty channel before putting it back.
var timers = sync.Pool{
	New: func() interface{} {
		// The initial duration is irrelevant: the timer is stopped before
		// anyone can observe it.
		stopped := time.NewTimer(time.Hour)
		stopped.Stop()
		return stopped
	},
}
29
// Stats contains pool state information and accumulated stats.
//
// Hits, Misses, Timeouts and StaleConns are cumulative counters that are only
// ever incremented (being uint32, they can wrap). TotalConns and IdleConns are
// a point-in-time snapshot taken by ConnPool.Stats.
type Stats struct {
	Hits     uint32 // number of times free connection was found in the pool
	Misses   uint32 // number of times free connection was NOT found in the pool
	Timeouts uint32 // number of times a wait timeout occurred

	TotalConns uint32 // number of total connections in the pool
	IdleConns  uint32 // number of idle connections in the pool (includes slots reserved for in-flight min-idle dials)
	StaleConns uint32 // number of stale connections removed from the pool by the reaper (ReapStaleConns)
}
40
41type Pooler interface {
42	NewConn(context.Context) (*Conn, error)
43	CloseConn(*Conn) error
44
45	Get(context.Context) (*Conn, error)
46	Put(context.Context, *Conn)
47	Remove(context.Context, *Conn, error)
48
49	Len() int
50	IdleLen() int
51	Stats() *Stats
52
53	Close() error
54}
55
// Options configures a ConnPool.
type Options struct {
	// Dialer opens a new network connection. It is used for every new pool
	// connection and by the background probe that runs after repeated dial
	// failures.
	Dialer  func(context.Context) (net.Conn, error)
	// OnClose, if non-nil, is called with each connection right before the
	// pool closes it; its returned error is ignored.
	OnClose func(*Conn) error

	// PoolFIFO makes Get reuse the least recently returned idle connection
	// (queue order); when false the most recently returned one is reused
	// (stack order).
	PoolFIFO           bool
	// PoolSize is the capacity of the turn queue (connections checked out at
	// once through Get), the cap on pooled connections, and the number of
	// dial failures after which dialing is suspended until a background
	// probe succeeds.
	PoolSize           int
	// MinIdleConns is the number of idle connections the pool tries to keep
	// dialed in the background; 0 disables this.
	MinIdleConns       int
	// MaxConnAge, if > 0, marks connections older than this as stale.
	MaxConnAge         time.Duration
	// PoolTimeout is how long Get waits for a free turn before failing with
	// ErrPoolTimeout.
	PoolTimeout        time.Duration
	// IdleTimeout, if > 0, marks connections unused for this long as stale.
	IdleTimeout        time.Duration
	// IdleCheckFrequency is the interval of the background reaper, which is
	// only started when both it and IdleTimeout are > 0.
	IdleCheckFrequency time.Duration
}
68
// lastDialErrorWrap boxes a dial error so it can be kept in an atomic.Value.
// atomic.Value requires every Store to use the same concrete type, while the
// errors returned by a dialer can have arbitrary concrete types.
type lastDialErrorWrap struct {
	err error
}
72
// ConnPool is the Pooler implementation in this package. It is a bounded set
// of *Conn with an idle list, an optional warm-up to MinIdleConns and an
// optional background reaper for stale idle connections.
type ConnPool struct {
	opt *Options

	// dialErrorsNum counts dial failures seen by dialConn. It is reset only
	// when the background tryDial probe connects successfully; while it is
	// >= opt.PoolSize, dialConn returns the last dial error without dialing.
	dialErrorsNum uint32 // atomic

	// lastDialError holds the most recent dial error as *lastDialErrorWrap.
	lastDialError atomic.Value

	// queue is a counting semaphore of capacity opt.PoolSize: sending acquires
	// a turn (getTurn/waitTurn) and receiving releases it (freeTurn).
	queue chan struct{}

	// connsMu guards the four fields that follow it:
	//   conns        - every connection the pool tracks, pooled or not.
	//   idleConns    - connections available for Get to hand out.
	//   poolSize     - number of pooled connections, counting slots that
	//                  checkMinIdleConns has reserved for in-flight dials.
	//   idleConnsLen - idle count, also counting those reserved slots, so it
	//                  can exceed len(idleConns) while dials are in flight.
	connsMu      sync.Mutex
	conns        []*Conn
	idleConns    []*Conn
	poolSize     int
	idleConnsLen int

	// stats holds the cumulative counters (Hits, Misses, Timeouts,
	// StaleConns), updated with sync/atomic. The snapshot fields are filled
	// in by Stats() instead.
	stats Stats

	_closed  uint32        // atomic; 0 while open, set to 1 exactly once by Close
	closedCh chan struct{} // closed by Close to stop the reaper goroutine
}

// Compile-time check that *ConnPool satisfies Pooler.
var _ Pooler = (*ConnPool)(nil)
95
96func NewConnPool(opt *Options) *ConnPool {
97	p := &ConnPool{
98		opt: opt,
99
100		queue:     make(chan struct{}, opt.PoolSize),
101		conns:     make([]*Conn, 0, opt.PoolSize),
102		idleConns: make([]*Conn, 0, opt.PoolSize),
103		closedCh:  make(chan struct{}),
104	}
105
106	p.connsMu.Lock()
107	p.checkMinIdleConns()
108	p.connsMu.Unlock()
109
110	if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
111		go p.reaper(opt.IdleCheckFrequency)
112	}
113
114	return p
115}
116
117func (p *ConnPool) checkMinIdleConns() {
118	if p.opt.MinIdleConns == 0 {
119		return
120	}
121	for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
122		p.poolSize++
123		p.idleConnsLen++
124		go func() {
125			err := p.addIdleConn()
126			if err != nil {
127				p.connsMu.Lock()
128				p.poolSize--
129				p.idleConnsLen--
130				p.connsMu.Unlock()
131			}
132		}()
133	}
134}
135
136func (p *ConnPool) addIdleConn() error {
137	cn, err := p.dialConn(context.TODO(), true)
138	if err != nil {
139		return err
140	}
141
142	p.connsMu.Lock()
143	p.conns = append(p.conns, cn)
144	p.idleConns = append(p.idleConns, cn)
145	p.connsMu.Unlock()
146	return nil
147}
148
149func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
150	return p.newConn(ctx, false)
151}
152
153func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
154	cn, err := p.dialConn(ctx, pooled)
155	if err != nil {
156		return nil, err
157	}
158
159	p.connsMu.Lock()
160	p.conns = append(p.conns, cn)
161	if pooled {
162		// If pool is full remove the cn on next Put.
163		if p.poolSize >= p.opt.PoolSize {
164			cn.pooled = false
165		} else {
166			p.poolSize++
167		}
168	}
169	p.connsMu.Unlock()
170
171	return cn, nil
172}
173
174func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
175	if p.closed() {
176		return nil, ErrClosed
177	}
178
179	if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
180		return nil, p.getLastDialError()
181	}
182
183	netConn, err := p.opt.Dialer(ctx)
184	if err != nil {
185		p.setLastDialError(err)
186		if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
187			go p.tryDial()
188		}
189		return nil, err
190	}
191
192	cn := NewConn(netConn)
193	cn.pooled = pooled
194	return cn, nil
195}
196
197func (p *ConnPool) tryDial() {
198	for {
199		if p.closed() {
200			return
201		}
202
203		conn, err := p.opt.Dialer(context.Background())
204		if err != nil {
205			p.setLastDialError(err)
206			time.Sleep(time.Second)
207			continue
208		}
209
210		atomic.StoreUint32(&p.dialErrorsNum, 0)
211		_ = conn.Close()
212		return
213	}
214}
215
216func (p *ConnPool) setLastDialError(err error) {
217	p.lastDialError.Store(&lastDialErrorWrap{err: err})
218}
219
220func (p *ConnPool) getLastDialError() error {
221	err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
222	if err != nil {
223		return err.err
224	}
225	return nil
226}
227
228// Get returns existed connection from the pool or creates a new one.
229func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
230	if p.closed() {
231		return nil, ErrClosed
232	}
233
234	if err := p.waitTurn(ctx); err != nil {
235		return nil, err
236	}
237
238	for {
239		p.connsMu.Lock()
240		cn := p.popIdle()
241		p.connsMu.Unlock()
242
243		if cn == nil {
244			break
245		}
246
247		if p.isStaleConn(cn) {
248			_ = p.CloseConn(cn)
249			continue
250		}
251
252		atomic.AddUint32(&p.stats.Hits, 1)
253		return cn, nil
254	}
255
256	atomic.AddUint32(&p.stats.Misses, 1)
257
258	newcn, err := p.newConn(ctx, true)
259	if err != nil {
260		p.freeTurn()
261		return nil, err
262	}
263
264	return newcn, nil
265}
266
267func (p *ConnPool) getTurn() {
268	p.queue <- struct{}{}
269}
270
// waitTurn acquires one slot of the p.queue semaphore. It gives up when ctx is
// done or after opt.PoolTimeout, whichever comes first.
//
// It returns nil once a slot is held (release it with freeTurn). It returns
// ctx.Err() if the context finished first, or ErrPoolTimeout if the pool
// timeout expired; in the latter case it also increments stats.Timeouts.
func (p *ConnPool) waitTurn(ctx context.Context) error {
	// Fail fast if the context is already done, even when a slot is free.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	// Fast path: grab a free slot without touching a timer.
	select {
	case p.queue <- struct{}{}:
		return nil
	default:
	}

	// Slow path: wait using a pooled timer. Each exit below leaves the timer
	// stopped with no pending tick before returning it to the pool, so the
	// next borrower cannot receive a stale value.
	timer := timers.Get().(*time.Timer)
	timer.Reset(p.opt.PoolTimeout)

	select {
	case <-ctx.Done():
		// If Stop reports the timer already fired, drain the pending tick.
		if !timer.Stop() {
			<-timer.C
		}
		timers.Put(timer)
		return ctx.Err()
	case p.queue <- struct{}{}:
		if !timer.Stop() {
			<-timer.C
		}
		timers.Put(timer)
		return nil
	case <-timer.C:
		// The tick was consumed by this case, so there is nothing to drain.
		timers.Put(timer)
		atomic.AddUint32(&p.stats.Timeouts, 1)
		return ErrPoolTimeout
	}
}
306
307func (p *ConnPool) freeTurn() {
308	<-p.queue
309}
310
311func (p *ConnPool) popIdle() *Conn {
312	n := len(p.idleConns)
313	if n == 0 {
314		return nil
315	}
316
317	var cn *Conn
318	if p.opt.PoolFIFO {
319		cn = p.idleConns[0]
320		copy(p.idleConns, p.idleConns[1:])
321		p.idleConns = p.idleConns[:n-1]
322	} else {
323		idx := n - 1
324		cn = p.idleConns[idx]
325		p.idleConns = p.idleConns[:idx]
326	}
327	p.idleConnsLen--
328	p.checkMinIdleConns()
329	return cn
330}
331
332func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
333	if cn.rd.Buffered() > 0 {
334		internal.Logger.Printf(ctx, "Conn has unread data")
335		p.Remove(ctx, cn, BadConnError{})
336		return
337	}
338
339	if !cn.pooled {
340		p.Remove(ctx, cn, nil)
341		return
342	}
343
344	p.connsMu.Lock()
345	p.idleConns = append(p.idleConns, cn)
346	p.idleConnsLen++
347	p.connsMu.Unlock()
348	p.freeTurn()
349}
350
351func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
352	p.removeConnWithLock(cn)
353	p.freeTurn()
354	_ = p.closeConn(cn)
355}
356
357func (p *ConnPool) CloseConn(cn *Conn) error {
358	p.removeConnWithLock(cn)
359	return p.closeConn(cn)
360}
361
362func (p *ConnPool) removeConnWithLock(cn *Conn) {
363	p.connsMu.Lock()
364	p.removeConn(cn)
365	p.connsMu.Unlock()
366}
367
368func (p *ConnPool) removeConn(cn *Conn) {
369	for i, c := range p.conns {
370		if c == cn {
371			p.conns = append(p.conns[:i], p.conns[i+1:]...)
372			if cn.pooled {
373				p.poolSize--
374				p.checkMinIdleConns()
375			}
376			return
377		}
378	}
379}
380
381func (p *ConnPool) closeConn(cn *Conn) error {
382	if p.opt.OnClose != nil {
383		_ = p.opt.OnClose(cn)
384	}
385	return cn.Close()
386}
387
388// Len returns total number of connections.
389func (p *ConnPool) Len() int {
390	p.connsMu.Lock()
391	n := len(p.conns)
392	p.connsMu.Unlock()
393	return n
394}
395
396// IdleLen returns number of idle connections.
397func (p *ConnPool) IdleLen() int {
398	p.connsMu.Lock()
399	n := p.idleConnsLen
400	p.connsMu.Unlock()
401	return n
402}
403
404func (p *ConnPool) Stats() *Stats {
405	idleLen := p.IdleLen()
406	return &Stats{
407		Hits:     atomic.LoadUint32(&p.stats.Hits),
408		Misses:   atomic.LoadUint32(&p.stats.Misses),
409		Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
410
411		TotalConns: uint32(p.Len()),
412		IdleConns:  uint32(idleLen),
413		StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
414	}
415}
416
417func (p *ConnPool) closed() bool {
418	return atomic.LoadUint32(&p._closed) == 1
419}
420
421func (p *ConnPool) Filter(fn func(*Conn) bool) error {
422	p.connsMu.Lock()
423	defer p.connsMu.Unlock()
424
425	var firstErr error
426	for _, cn := range p.conns {
427		if fn(cn) {
428			if err := p.closeConn(cn); err != nil && firstErr == nil {
429				firstErr = err
430			}
431		}
432	}
433	return firstErr
434}
435
436func (p *ConnPool) Close() error {
437	if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
438		return ErrClosed
439	}
440	close(p.closedCh)
441
442	var firstErr error
443	p.connsMu.Lock()
444	for _, cn := range p.conns {
445		if err := p.closeConn(cn); err != nil && firstErr == nil {
446			firstErr = err
447		}
448	}
449	p.conns = nil
450	p.poolSize = 0
451	p.idleConns = nil
452	p.idleConnsLen = 0
453	p.connsMu.Unlock()
454
455	return firstErr
456}
457
458func (p *ConnPool) reaper(frequency time.Duration) {
459	ticker := time.NewTicker(frequency)
460	defer ticker.Stop()
461
462	for {
463		select {
464		case <-ticker.C:
465			// It is possible that ticker and closedCh arrive together,
466			// and select pseudo-randomly pick ticker case, we double
467			// check here to prevent being executed after closed.
468			if p.closed() {
469				return
470			}
471			_, err := p.ReapStaleConns()
472			if err != nil {
473				internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
474				continue
475			}
476		case <-p.closedCh:
477			return
478		}
479	}
480}
481
482func (p *ConnPool) ReapStaleConns() (int, error) {
483	var n int
484	for {
485		p.getTurn()
486
487		p.connsMu.Lock()
488		cn := p.reapStaleConn()
489		p.connsMu.Unlock()
490
491		p.freeTurn()
492
493		if cn != nil {
494			_ = p.closeConn(cn)
495			n++
496		} else {
497			break
498		}
499	}
500	atomic.AddUint32(&p.stats.StaleConns, uint32(n))
501	return n, nil
502}
503
504func (p *ConnPool) reapStaleConn() *Conn {
505	if len(p.idleConns) == 0 {
506		return nil
507	}
508
509	cn := p.idleConns[0]
510	if !p.isStaleConn(cn) {
511		return nil
512	}
513
514	p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
515	p.idleConnsLen--
516	p.removeConn(cn)
517
518	return cn
519}
520
521func (p *ConnPool) isStaleConn(cn *Conn) bool {
522	if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
523		return false
524	}
525
526	now := time.Now()
527	if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
528		return true
529	}
530	if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
531		return true
532	}
533
534	return false
535}
536