1package pool 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync/atomic" 8) 9 10const ( 11 stateDefault = 0 12 stateInited = 1 13 stateClosed = 2 14) 15 16type BadConnError struct { 17 wrapped error 18} 19 20var _ error = (*BadConnError)(nil) 21 22func (e BadConnError) Error() string { 23 s := "redis: Conn is in a bad state" 24 if e.wrapped != nil { 25 s += ": " + e.wrapped.Error() 26 } 27 return s 28} 29 30func (e BadConnError) Unwrap() error { 31 return e.wrapped 32} 33 34//------------------------------------------------------------------------------ 35 36type StickyConnPool struct { 37 pool Pooler 38 shared int32 // atomic 39 40 state uint32 // atomic 41 ch chan *Conn 42 43 _badConnError atomic.Value 44} 45 46var _ Pooler = (*StickyConnPool)(nil) 47 48func NewStickyConnPool(pool Pooler) *StickyConnPool { 49 p, ok := pool.(*StickyConnPool) 50 if !ok { 51 p = &StickyConnPool{ 52 pool: pool, 53 ch: make(chan *Conn, 1), 54 } 55 } 56 atomic.AddInt32(&p.shared, 1) 57 return p 58} 59 60func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { 61 return p.pool.NewConn(ctx) 62} 63 64func (p *StickyConnPool) CloseConn(cn *Conn) error { 65 return p.pool.CloseConn(cn) 66} 67 68func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { 69 // In worst case this races with Close which is not a very common operation. 70 for i := 0; i < 1000; i++ { 71 switch atomic.LoadUint32(&p.state) { 72 case stateDefault: 73 cn, err := p.pool.Get(ctx) 74 if err != nil { 75 return nil, err 76 } 77 if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { 78 return cn, nil 79 } 80 p.pool.Remove(ctx, cn, ErrClosed) 81 case stateInited: 82 if err := p.badConnError(); err != nil { 83 return nil, err 84 } 85 cn, ok := <-p.ch 86 if !ok { 87 return nil, ErrClosed 88 } 89 return cn, nil 90 case stateClosed: 91 return nil, ErrClosed 92 default: 93 panic("not reached") 94 } 95 } 96 return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop") 97} 98 99func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { 100 defer func() { 101 if recover() != nil { 102 p.freeConn(ctx, cn) 103 } 104 }() 105 p.ch <- cn 106} 107 108func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { 109 if err := p.badConnError(); err != nil { 110 p.pool.Remove(ctx, cn, err) 111 } else { 112 p.pool.Put(ctx, cn) 113 } 114} 115 116func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { 117 defer func() { 118 if recover() != nil { 119 p.pool.Remove(ctx, cn, ErrClosed) 120 } 121 }() 122 p._badConnError.Store(BadConnError{wrapped: reason}) 123 p.ch <- cn 124} 125 126func (p *StickyConnPool) Close() error { 127 if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { 128 return nil 129 } 130 131 for i := 0; i < 1000; i++ { 132 state := atomic.LoadUint32(&p.state) 133 if state == stateClosed { 134 return ErrClosed 135 } 136 if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { 137 close(p.ch) 138 cn, ok := <-p.ch 139 if ok { 140 p.freeConn(context.TODO(), cn) 141 } 142 return nil 143 } 144 } 145 146 return errors.New("redis: StickyConnPool.Close: infinite loop") 147} 148 149func (p *StickyConnPool) Reset(ctx context.Context) error { 150 if p.badConnError() == nil { 151 return nil 152 } 153 154 select { 155 case cn, ok := <-p.ch: 156 if !ok { 157 return ErrClosed 158 } 159 p.pool.Remove(ctx, cn, ErrClosed) 160 p._badConnError.Store(BadConnError{wrapped: nil}) 161 default: 162 return errors.New("redis: StickyConnPool does not have a Conn") 163 } 164 165 if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { 166 state := atomic.LoadUint32(&p.state) 167 return fmt.Errorf("redis: invalid StickyConnPool state: %d", state) 168 } 169 170 return nil 171} 172 173func (p *StickyConnPool) badConnError() error { 174 if v := p._badConnError.Load(); v != nil { 175 err := v.(BadConnError) 176 if err.wrapped != nil { 177 return err 178 } 179 } 180 return nil 181} 182 183func (p *StickyConnPool) Len() int { 184 switch atomic.LoadUint32(&p.state) { 185 case stateDefault: 186 return 0 187 case stateInited: 188 return 1 189 case stateClosed: 190 return 0 191 default: 192 panic("not reached") 193 } 194} 195 196func (p *StickyConnPool) IdleLen() int { 197 return len(p.ch) 198} 199 200func (p *StickyConnPool) Stats() *Stats { 201 return &Stats{} 202} 203