1package pool
2
3import (
4	"context"
5	"errors"
6	"net"
7	"sync"
8	"sync/atomic"
9	"time"
10
11	"github.com/go-redis/redis/v7/internal"
12)
13
var (
	// ErrClosed is returned by operations attempted on a pool that has
	// already been closed.
	ErrClosed = errors.New("redis: client is closed")

	// ErrPoolTimeout is returned when no pool slot became free within
	// Options.PoolTimeout.
	ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
16
// timers recycles *time.Timer values used by waitTurn, so that waiting for a
// pool slot does not allocate a fresh timer every time. Timers produced by
// New are already stopped; waitTurn Resets them before use and stops/drains
// them before putting them back.
var timers = sync.Pool{
	New: func() interface{} {
		tm := time.NewTimer(time.Hour)
		tm.Stop()
		return tm
	},
}
24
// Stats is a snapshot of pool state together with cumulative counters.
type Stats struct {
	// Hits counts Get calls served by a reused idle connection.
	Hits uint32
	// Misses counts Get calls that found no usable idle connection.
	Misses uint32
	// Timeouts counts waits for a pool slot that ran out of time.
	Timeouts uint32

	// TotalConns is the number of connections tracked by the pool.
	TotalConns uint32
	// IdleConns is the number of idle connections in the pool.
	IdleConns uint32
	// StaleConns counts stale connections removed from the pool.
	StaleConns uint32
}
35
36type Pooler interface {
37	NewConn(context.Context) (*Conn, error)
38	CloseConn(*Conn) error
39
40	Get(context.Context) (*Conn, error)
41	Put(*Conn)
42	Remove(*Conn, error)
43
44	Len() int
45	IdleLen() int
46	Stats() *Stats
47
48	Close() error
49}
50
51type Options struct {
52	Dialer  func(context.Context) (net.Conn, error)
53	OnClose func(*Conn) error
54
55	PoolSize           int
56	MinIdleConns       int
57	MaxConnAge         time.Duration
58	PoolTimeout        time.Duration
59	IdleTimeout        time.Duration
60	IdleCheckFrequency time.Duration
61}
62
63type ConnPool struct {
64	opt *Options
65
66	dialErrorsNum uint32 // atomic
67
68	lastDialErrorMu sync.RWMutex
69	lastDialError   error
70
71	queue chan struct{}
72
73	connsMu      sync.Mutex
74	conns        []*Conn
75	idleConns    []*Conn
76	poolSize     int
77	idleConnsLen int
78
79	stats Stats
80
81	_closed  uint32 // atomic
82	closedCh chan struct{}
83}
84
85var _ Pooler = (*ConnPool)(nil)
86
87func NewConnPool(opt *Options) *ConnPool {
88	p := &ConnPool{
89		opt: opt,
90
91		queue:     make(chan struct{}, opt.PoolSize),
92		conns:     make([]*Conn, 0, opt.PoolSize),
93		idleConns: make([]*Conn, 0, opt.PoolSize),
94		closedCh:  make(chan struct{}),
95	}
96
97	p.checkMinIdleConns()
98
99	if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
100		go p.reaper(opt.IdleCheckFrequency)
101	}
102
103	return p
104}
105
106func (p *ConnPool) checkMinIdleConns() {
107	if p.opt.MinIdleConns == 0 {
108		return
109	}
110	for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
111		p.poolSize++
112		p.idleConnsLen++
113		go func() {
114			err := p.addIdleConn()
115			if err != nil {
116				p.connsMu.Lock()
117				p.poolSize--
118				p.idleConnsLen--
119				p.connsMu.Unlock()
120			}
121		}()
122	}
123}
124
125func (p *ConnPool) addIdleConn() error {
126	cn, err := p.dialConn(context.TODO(), true)
127	if err != nil {
128		return err
129	}
130
131	p.connsMu.Lock()
132	p.conns = append(p.conns, cn)
133	p.idleConns = append(p.idleConns, cn)
134	p.connsMu.Unlock()
135	return nil
136}
137
138func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
139	return p.newConn(ctx, false)
140}
141
142func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
143	cn, err := p.dialConn(ctx, pooled)
144	if err != nil {
145		return nil, err
146	}
147
148	p.connsMu.Lock()
149	p.conns = append(p.conns, cn)
150	if pooled {
151		// If pool is full remove the cn on next Put.
152		if p.poolSize >= p.opt.PoolSize {
153			cn.pooled = false
154		} else {
155			p.poolSize++
156		}
157	}
158	p.connsMu.Unlock()
159	return cn, nil
160}
161
162func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
163	if p.closed() {
164		return nil, ErrClosed
165	}
166
167	if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
168		return nil, p.getLastDialError()
169	}
170
171	netConn, err := p.opt.Dialer(ctx)
172	if err != nil {
173		p.setLastDialError(err)
174		if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
175			go p.tryDial()
176		}
177		return nil, err
178	}
179
180	cn := NewConn(netConn)
181	cn.pooled = pooled
182	return cn, nil
183}
184
185func (p *ConnPool) tryDial() {
186	for {
187		if p.closed() {
188			return
189		}
190
191		conn, err := p.opt.Dialer(context.Background())
192		if err != nil {
193			p.setLastDialError(err)
194			time.Sleep(time.Second)
195			continue
196		}
197
198		atomic.StoreUint32(&p.dialErrorsNum, 0)
199		_ = conn.Close()
200		return
201	}
202}
203
204func (p *ConnPool) setLastDialError(err error) {
205	p.lastDialErrorMu.Lock()
206	p.lastDialError = err
207	p.lastDialErrorMu.Unlock()
208}
209
210func (p *ConnPool) getLastDialError() error {
211	p.lastDialErrorMu.RLock()
212	err := p.lastDialError
213	p.lastDialErrorMu.RUnlock()
214	return err
215}
216
217// Get returns existed connection from the pool or creates a new one.
218func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
219	if p.closed() {
220		return nil, ErrClosed
221	}
222
223	err := p.waitTurn(ctx)
224	if err != nil {
225		return nil, err
226	}
227
228	for {
229		p.connsMu.Lock()
230		cn := p.popIdle()
231		p.connsMu.Unlock()
232
233		if cn == nil {
234			break
235		}
236
237		if p.isStaleConn(cn) {
238			_ = p.CloseConn(cn)
239			continue
240		}
241
242		atomic.AddUint32(&p.stats.Hits, 1)
243		return cn, nil
244	}
245
246	atomic.AddUint32(&p.stats.Misses, 1)
247
248	newcn, err := p.newConn(ctx, true)
249	if err != nil {
250		p.freeTurn()
251		return nil, err
252	}
253
254	return newcn, nil
255}
256
257func (p *ConnPool) getTurn() {
258	p.queue <- struct{}{}
259}
260
261func (p *ConnPool) waitTurn(ctx context.Context) error {
262	select {
263	case <-ctx.Done():
264		return ctx.Err()
265	default:
266	}
267
268	select {
269	case p.queue <- struct{}{}:
270		return nil
271	default:
272	}
273
274	timer := timers.Get().(*time.Timer)
275	timer.Reset(p.opt.PoolTimeout)
276
277	select {
278	case <-ctx.Done():
279		if !timer.Stop() {
280			<-timer.C
281		}
282		timers.Put(timer)
283		return ctx.Err()
284	case p.queue <- struct{}{}:
285		if !timer.Stop() {
286			<-timer.C
287		}
288		timers.Put(timer)
289		return nil
290	case <-timer.C:
291		timers.Put(timer)
292		atomic.AddUint32(&p.stats.Timeouts, 1)
293		return ErrPoolTimeout
294	}
295}
296
297func (p *ConnPool) freeTurn() {
298	<-p.queue
299}
300
301func (p *ConnPool) popIdle() *Conn {
302	if len(p.idleConns) == 0 {
303		return nil
304	}
305
306	idx := len(p.idleConns) - 1
307	cn := p.idleConns[idx]
308	p.idleConns = p.idleConns[:idx]
309	p.idleConnsLen--
310	p.checkMinIdleConns()
311	return cn
312}
313
314func (p *ConnPool) Put(cn *Conn) {
315	if cn.rd.Buffered() > 0 {
316		internal.Logger.Printf("Conn has unread data")
317		p.Remove(cn, BadConnError{})
318		return
319	}
320
321	if !cn.pooled {
322		p.Remove(cn, nil)
323		return
324	}
325
326	p.connsMu.Lock()
327	p.idleConns = append(p.idleConns, cn)
328	p.idleConnsLen++
329	p.connsMu.Unlock()
330	p.freeTurn()
331}
332
333func (p *ConnPool) Remove(cn *Conn, reason error) {
334	p.removeConnWithLock(cn)
335	p.freeTurn()
336	_ = p.closeConn(cn)
337}
338
339func (p *ConnPool) CloseConn(cn *Conn) error {
340	p.removeConnWithLock(cn)
341	return p.closeConn(cn)
342}
343
344func (p *ConnPool) removeConnWithLock(cn *Conn) {
345	p.connsMu.Lock()
346	p.removeConn(cn)
347	p.connsMu.Unlock()
348}
349
350func (p *ConnPool) removeConn(cn *Conn) {
351	for i, c := range p.conns {
352		if c == cn {
353			p.conns = append(p.conns[:i], p.conns[i+1:]...)
354			if cn.pooled {
355				p.poolSize--
356				p.checkMinIdleConns()
357			}
358			return
359		}
360	}
361}
362
363func (p *ConnPool) closeConn(cn *Conn) error {
364	if p.opt.OnClose != nil {
365		_ = p.opt.OnClose(cn)
366	}
367	return cn.Close()
368}
369
370// Len returns total number of connections.
371func (p *ConnPool) Len() int {
372	p.connsMu.Lock()
373	n := len(p.conns)
374	p.connsMu.Unlock()
375	return n
376}
377
378// IdleLen returns number of idle connections.
379func (p *ConnPool) IdleLen() int {
380	p.connsMu.Lock()
381	n := p.idleConnsLen
382	p.connsMu.Unlock()
383	return n
384}
385
386func (p *ConnPool) Stats() *Stats {
387	idleLen := p.IdleLen()
388	return &Stats{
389		Hits:     atomic.LoadUint32(&p.stats.Hits),
390		Misses:   atomic.LoadUint32(&p.stats.Misses),
391		Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
392
393		TotalConns: uint32(p.Len()),
394		IdleConns:  uint32(idleLen),
395		StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
396	}
397}
398
399func (p *ConnPool) closed() bool {
400	return atomic.LoadUint32(&p._closed) == 1
401}
402
403func (p *ConnPool) Filter(fn func(*Conn) bool) error {
404	var firstErr error
405	p.connsMu.Lock()
406	for _, cn := range p.conns {
407		if fn(cn) {
408			if err := p.closeConn(cn); err != nil && firstErr == nil {
409				firstErr = err
410			}
411		}
412	}
413	p.connsMu.Unlock()
414	return firstErr
415}
416
417func (p *ConnPool) Close() error {
418	if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
419		return ErrClosed
420	}
421	close(p.closedCh)
422
423	var firstErr error
424	p.connsMu.Lock()
425	for _, cn := range p.conns {
426		if err := p.closeConn(cn); err != nil && firstErr == nil {
427			firstErr = err
428		}
429	}
430	p.conns = nil
431	p.poolSize = 0
432	p.idleConns = nil
433	p.idleConnsLen = 0
434	p.connsMu.Unlock()
435
436	return firstErr
437}
438
439func (p *ConnPool) reaper(frequency time.Duration) {
440	ticker := time.NewTicker(frequency)
441	defer ticker.Stop()
442
443	for {
444		select {
445		case <-ticker.C:
446			// It is possible that ticker and closedCh arrive together,
447			// and select pseudo-randomly pick ticker case, we double
448			// check here to prevent being executed after closed.
449			if p.closed() {
450				return
451			}
452			_, err := p.ReapStaleConns()
453			if err != nil {
454				internal.Logger.Printf("ReapStaleConns failed: %s", err)
455				continue
456			}
457		case <-p.closedCh:
458			return
459		}
460	}
461}
462
463func (p *ConnPool) ReapStaleConns() (int, error) {
464	var n int
465	for {
466		p.getTurn()
467
468		p.connsMu.Lock()
469		cn := p.reapStaleConn()
470		p.connsMu.Unlock()
471		p.freeTurn()
472
473		if cn != nil {
474			_ = p.closeConn(cn)
475			n++
476		} else {
477			break
478		}
479	}
480	atomic.AddUint32(&p.stats.StaleConns, uint32(n))
481	return n, nil
482}
483
484func (p *ConnPool) reapStaleConn() *Conn {
485	if len(p.idleConns) == 0 {
486		return nil
487	}
488
489	cn := p.idleConns[0]
490	if !p.isStaleConn(cn) {
491		return nil
492	}
493
494	p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
495	p.idleConnsLen--
496	p.removeConn(cn)
497
498	return cn
499}
500
501func (p *ConnPool) isStaleConn(cn *Conn) bool {
502	if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
503		return false
504	}
505
506	now := time.Now()
507	if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
508		return true
509	}
510	if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
511		return true
512	}
513
514	return false
515}
516