1package pool
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"sync/atomic"
8)
9
// Lifecycle states of a StickyConnPool. They are stored in the pool's
// state field and read/updated with sync/atomic.
const (
	stateDefault = iota // no sticky connection has been acquired yet
	stateInited         // a connection has been taken from the underlying pool
	stateClosed         // the sticky pool has been closed
)
15
// BadConnError is the error reported for a Conn that has been marked as
// being in a bad state. It optionally wraps the underlying cause.
type BadConnError struct {
	wrapped error
}

var _ error = (*BadConnError)(nil)

// Error implements the error interface. The message is the fixed prefix
// "redis: Conn is in a bad state", followed by ": " and the cause's message
// when a cause is wrapped.
func (e BadConnError) Error() string {
	const prefix = "redis: Conn is in a bad state"
	if e.wrapped == nil {
		return prefix
	}
	return prefix + ": " + e.wrapped.Error()
}

// Unwrap exposes the wrapped cause (nil if there is none) so errors.Is and
// errors.As can see through a BadConnError.
func (e BadConnError) Unwrap() error {
	return e.wrapped
}
33
34//------------------------------------------------------------------------------
35
// StickyConnPool is a Pooler that pins a single connection taken from an
// underlying pool. Get hands out that same connection each time it has been
// returned with Put, until Remove marks it bad or the pool is closed.
//
// One instance can be shared by several owners. NewStickyConnPool increments
// the shared reference count, and Close only tears the pool down when the
// last reference is released.
type StickyConnPool struct {
	pool   Pooler // underlying pool the sticky connection is borrowed from
	shared int32  // atomic; number of outstanding NewStickyConnPool references

	state uint32     // atomic; one of stateDefault, stateInited, stateClosed
	ch    chan *Conn // parks the sticky connection while it is not checked out

	// Holds the BadConnError stored by Remove. A stored value whose wrapped
	// error is nil (as stored by Reset) means no bad-connection error.
	_badConnError atomic.Value
}

// Compile-time assertion that *StickyConnPool satisfies Pooler.
var _ Pooler = (*StickyConnPool)(nil)
47
// NewStickyConnPool returns a StickyConnPool on top of pool. If pool is
// already a *StickyConnPool it is reused instead of being wrapped again.
// Either way the returned pool's shared reference count is incremented;
// Close only releases the pool once every such reference has been closed.
func NewStickyConnPool(pool Pooler) *StickyConnPool {
	if sticky, isSticky := pool.(*StickyConnPool); isSticky {
		atomic.AddInt32(&sticky.shared, 1)
		return sticky
	}

	fresh := &StickyConnPool{
		pool: pool,
		// Capacity 1: at most the single sticky connection is ever parked.
		ch: make(chan *Conn, 1),
	}
	atomic.AddInt32(&fresh.shared, 1)
	return fresh
}
59
// NewConn creates a connection through the underlying pool. It does not
// touch this pool's sticky state or channel.
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
	underlying := p.pool
	return underlying.NewConn(ctx)
}
63
64func (p *StickyConnPool) CloseConn(cn *Conn) error {
65	return p.pool.CloseConn(cn)
66}
67
// Get returns the sticky connection, acquiring it on first use.
//
// In stateDefault a connection is taken from the underlying pool and the pool
// moves to stateInited. If another goroutine changed the state first, that
// connection is removed from the underlying pool and the state is re-read.
//
// In stateInited, Get fails with the error recorded by Remove (if any).
// Otherwise it waits for the connection to be handed back to the channel.
// The wait honors ctx: if ctx is canceled or its deadline passes while
// another caller still holds the connection, Get returns ctx.Err() instead
// of blocking forever.
//
// ErrClosed is returned once the pool (or its channel) has been closed.
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
	// In worst case this races with Close which is not a very common operation.
	for i := 0; i < 1000; i++ {
		switch atomic.LoadUint32(&p.state) {
		case stateDefault:
			cn, err := p.pool.Get(ctx)
			if err != nil {
				return nil, err
			}
			if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
				return cn, nil
			}
			// Lost the race (concurrent Get or Close): discard and re-check state.
			p.pool.Remove(ctx, cn, ErrClosed)
		case stateInited:
			if err := p.badConnError(); err != nil {
				return nil, err
			}
			// The connection may be checked out by another caller; do not wait
			// for it past ctx's cancellation or deadline.
			select {
			case cn, ok := <-p.ch:
				if !ok {
					return nil, ErrClosed
				}
				return cn, nil
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		case stateClosed:
			return nil, ErrClosed
		default:
			panic("not reached")
		}
	}
	return nil, errors.New("redis: StickyConnPool.Get: infinite loop")
}
98
// Put parks cn in the sticky channel so the next Get can hand it out again.
// If the send panics (for example because Close has already closed the
// channel), the panic is recovered and cn is released through freeConn.
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
	defer func() {
		if r := recover(); r == nil {
			return
		}
		p.freeConn(ctx, cn)
	}()
	p.ch <- cn
}
107
// freeConn gives cn back to the underlying pool. A connection with no
// recorded bad-connection error is returned with Put. Otherwise it is
// removed, using the recorded BadConnError as the reason.
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
	reason := p.badConnError()
	if reason == nil {
		p.pool.Put(ctx, cn)
		return
	}
	p.pool.Remove(ctx, cn, reason)
}
115
// Remove records reason as a BadConnError and parks cn back in the sticky
// channel, where Reset or Close will later dispose of it. While the recorded
// reason is non-nil, Get reports it instead of returning the connection.
// Passing a nil reason stores an empty BadConnError, which badConnError
// treats as "no error".
//
// If parking the connection panics (for example because the channel has been
// closed), the panic is recovered and cn is removed from the underlying pool
// with ErrClosed.
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
	defer func() {
		if r := recover(); r == nil {
			return
		}
		p.pool.Remove(ctx, cn, ErrClosed)
	}()

	bad := BadConnError{wrapped: reason}
	p._badConnError.Store(bad)
	p.ch <- cn
}
125
// Close releases one reference obtained from NewStickyConnPool. While other
// references remain it does nothing and returns nil.
//
// When the last reference is released, Close moves the pool to stateClosed,
// closes the channel, and frees a connection that is still parked there. A
// connection that is checked out at that moment is released later, when its
// Put or Remove hits the closed channel. Close returns ErrClosed if the pool
// was already closed.
func (p *StickyConnPool) Close() error {
	if atomic.AddInt32(&p.shared, -1) > 0 {
		return nil
	}

	for attempt := 0; attempt < 1000; attempt++ {
		current := atomic.LoadUint32(&p.state)
		if current == stateClosed {
			return ErrClosed
		}
		if !atomic.CompareAndSwapUint32(&p.state, current, stateClosed) {
			continue
		}

		close(p.ch)
		if cn, ok := <-p.ch; ok {
			p.freeConn(context.TODO(), cn)
		}
		return nil
	}

	return errors.New("redis: StickyConnPool.Close: infinite loop")
}
148
149func (p *StickyConnPool) Reset(ctx context.Context) error {
150	if p.badConnError() == nil {
151		return nil
152	}
153
154	select {
155	case cn, ok := <-p.ch:
156		if !ok {
157			return ErrClosed
158		}
159		p.pool.Remove(ctx, cn, ErrClosed)
160		p._badConnError.Store(BadConnError{wrapped: nil})
161	default:
162		return errors.New("redis: StickyConnPool does not have a Conn")
163	}
164
165	if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
166		state := atomic.LoadUint32(&p.state)
167		return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
168	}
169
170	return nil
171}
172
173func (p *StickyConnPool) badConnError() error {
174	if v := p._badConnError.Load(); v != nil {
175		err := v.(BadConnError)
176		if err.wrapped != nil {
177			return err
178		}
179	}
180	return nil
181}
182
// Len reports 1 once a sticky connection has been acquired (stateInited) and
// 0 in stateDefault or stateClosed. Any other state value is a bug and panics.
func (p *StickyConnPool) Len() int {
	current := atomic.LoadUint32(&p.state)
	if current == stateInited {
		return 1
	}
	if current == stateDefault || current == stateClosed {
		return 0
	}
	panic("not reached")
}
195
196func (p *StickyConnPool) IdleLen() int {
197	return len(p.ch)
198}
199
200func (p *StickyConnPool) Stats() *Stats {
201	return &Stats{}
202}
203