1// Copyright 2012 Gary Burd
2//
3// Licensed under the Apache License, Version 2.0 (the "License"): you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations
13// under the License.
14
15package redis
16
17import (
18	"bytes"
19	"container/list"
20	"crypto/rand"
21	"crypto/sha1"
22	"errors"
23	"io"
24	"strconv"
25	"sync"
26	"time"
27
28	"github.com/garyburd/redigo/internal"
29)
30
// nowFunc reports the current time. Pool code calls it instead of time.Now
// directly so that tests can substitute a controllable clock.
var nowFunc = time.Now
32
// ErrPoolExhausted is reported by the methods (Do, Send, Receive, Flush, Err
// and Close) of a connection obtained from Pool.Get when the pool already has
// MaxActive connections allocated and Pool.Wait is false.
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")
37
var (
	// errConnClosed is reported by the methods of a pooled connection
	// (other than Close) after the application has closed it.
	errConnClosed = errors.New("redigo: connection closed")

	// errPoolClosed describes use of a closed pool.
	// NOTE(review): not referenced in this file (get builds its own
	// "get on closed pool" error) — confirm whether it is used elsewhere.
	errPoolClosed = errors.New("redigo: connection pool closed")
)
42
43// Pool maintains a pool of connections. The application calls the Get method
44// to get a connection from the pool and the connection's Close method to
45// return the connection's resources to the pool.
46//
47// The following example shows how to use a pool in a web application. The
48// application creates a pool at application startup and makes it available to
49// request handlers using a global variable.
50//
51//  func newPool(server, password string) *redis.Pool {
52//      return &redis.Pool{
53//          MaxIdle: 3,
54//          IdleTimeout: 240 * time.Second,
55//          Dial: func () (redis.Conn, error) {
56//              c, err := redis.Dial("tcp", server)
57//              if err != nil {
58//                  return nil, err
59//              }
60//              if _, err := c.Do("AUTH", password); err != nil {
61//                  c.Close()
62//                  return nil, err
63//              }
64//              return c, err
65//          },
66//          TestOnBorrow: func(c redis.Conn, t time.Time) error {
67//              _, err := c.Do("PING")
68//              return err
69//          },
70//      }
71//  }
72//
73//  var (
74//      pool *redis.Pool
75//      redisServer = flag.String("redisServer", ":6379", "")
76//      redisPassword = flag.String("redisPassword", "", "")
77//  )
78//
79//  func main() {
80//      flag.Parse()
81//      pool = newPool(*redisServer, *redisPassword)
82//      ...
83//  }
84//
85// A request handler gets a connection from the pool and closes the connection
86// when the handler is done:
87//
88//  func serveHome(w http.ResponseWriter, r *http.Request) {
89//      conn := pool.Get()
90//      defer conn.Close()
91//      ....
92//  }
93//
94type Pool struct {
95
96	// Dial is an application supplied function for creating and configuring a
97	// connection
98	Dial func() (Conn, error)
99
100	// TestOnBorrow is an optional application supplied function for checking
101	// the health of an idle connection before the connection is used again by
102	// the application. Argument t is the time that the connection was returned
103	// to the pool. If the function returns an error, then the connection is
104	// closed.
105	TestOnBorrow func(c Conn, t time.Time) error
106
107	// Maximum number of idle connections in the pool.
108	MaxIdle int
109
110	// Maximum number of connections allocated by the pool at a given time.
111	// When zero, there is no limit on the number of connections in the pool.
112	MaxActive int
113
114	// Close connections after remaining idle for this duration. If the value
115	// is zero, then idle connections are not closed. Applications should set
116	// the timeout to a value less than the server's timeout.
117	IdleTimeout time.Duration
118
119	// If Wait is true and the pool is at the MaxIdle limit, then Get() waits
120	// for a connection to be returned to the pool before returning.
121	Wait bool
122
123	// mu protects fields defined below.
124	mu     sync.Mutex
125	cond   *sync.Cond
126	closed bool
127	active int
128
129	// Stack of idleConn with most recently used at the front.
130	idle list.List
131}
132
133type idleConn struct {
134	c Conn
135	t time.Time
136}
137
138// NewPool creates a new pool. This function is deprecated. Applications should
139// initialize the Pool fields directly as shown in example.
140func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
141	return &Pool{Dial: newFn, MaxIdle: maxIdle}
142}
143
144// Get gets a connection. The application must close the returned connection.
145// This method always returns a valid connection so that applications can defer
146// error handling to the first use of the connection. If there is an error
147// getting an underlying connection, then the connection Err, Do, Send, Flush
148// and Receive methods return that error.
149func (p *Pool) Get() Conn {
150	c, err := p.get()
151	if err != nil {
152		return errorConnection{err}
153	}
154	return &pooledConnection{p: p, c: c}
155}
156
157// ActiveCount returns the number of active connections in the pool.
158func (p *Pool) ActiveCount() int {
159	p.mu.Lock()
160	active := p.active
161	p.mu.Unlock()
162	return active
163}
164
165// Close releases the resources used by the pool.
166func (p *Pool) Close() error {
167	p.mu.Lock()
168	idle := p.idle
169	p.idle.Init()
170	p.closed = true
171	p.active -= idle.Len()
172	if p.cond != nil {
173		p.cond.Broadcast()
174	}
175	p.mu.Unlock()
176	for e := idle.Front(); e != nil; e = e.Next() {
177		e.Value.(idleConn).c.Close()
178	}
179	return nil
180}
181
182// release decrements the active count and signals waiters. The caller must
183// hold p.mu during the call.
184func (p *Pool) release() {
185	p.active -= 1
186	if p.cond != nil {
187		p.cond.Signal()
188	}
189}
190
// get prunes stale connections and returns a connection from the idle list or
// creates a new connection.
//
// A connection returned on success counts toward p.active: idle connections
// are already counted, and a freshly dialed one is counted before Dial runs.
// Idle connections that fail TestOnBorrow are closed and discarded. If the
// pool is closed, get returns an error. When MaxActive is reached, get returns
// ErrPoolExhausted unless Wait is set. With Wait set, it blocks on p.cond until
// another goroutine signals (a connection is returned or released, or the pool
// is closed) and then retries.
//
// p.mu is acquired on entry and is always released before get returns. It is
// dropped temporarily around Close, TestOnBorrow and Dial so those calls do not
// block other pool users.
func (p *Pool) get() (Conn, error) {
	p.mu.Lock()

	// Prune stale connections.
	// The idle list is a stack with the most recently used connection at the
	// front, so the oldest entries sit at the back.

	if timeout := p.IdleTimeout; timeout > 0 {
		// n is a snapshot of the length: p.mu is released while closing, so
		// the list may change between iterations.
		for i, n := 0, p.idle.Len(); i < n; i++ {
			e := p.idle.Back()
			if e == nil {
				break
			}
			ic := e.Value.(idleConn)
			// The oldest remaining entry is still fresh, so all newer ones are too.
			if ic.t.Add(timeout).After(nowFunc()) {
				break
			}
			p.idle.Remove(e)
			p.release()
			p.mu.Unlock()
			ic.c.Close()
			p.mu.Lock()
		}
	}

	for {

		// Get idle connection.
		// Take from the front (most recently used) and validate it outside
		// the lock with TestOnBorrow, if configured.

		for i, n := 0, p.idle.Len(); i < n; i++ {
			e := p.idle.Front()
			if e == nil {
				break
			}
			ic := e.Value.(idleConn)
			p.idle.Remove(e)
			test := p.TestOnBorrow
			p.mu.Unlock()
			if test == nil || test(ic.c, ic.t) == nil {
				return ic.c, nil
			}
			// Failed the health check: close it and give back its slot.
			ic.c.Close()
			p.mu.Lock()
			p.release()
		}

		// Check for pool closed before dialing a new connection.

		if p.closed {
			p.mu.Unlock()
			return nil, errors.New("redigo: get on closed pool")
		}

		// Dial new connection if under limit.
		// The slot is reserved (active incremented) before unlocking so that
		// concurrent callers cannot overshoot MaxActive while Dial runs.

		if p.MaxActive == 0 || p.active < p.MaxActive {
			dial := p.Dial
			p.active += 1
			p.mu.Unlock()
			c, err := dial()
			if err != nil {
				// Dial failed: give back the reserved slot and never return a
				// connection alongside an error.
				p.mu.Lock()
				p.release()
				p.mu.Unlock()
				c = nil
			}
			return c, err
		}

		if !p.Wait {
			p.mu.Unlock()
			return nil, ErrPoolExhausted
		}

		// The condition variable is created lazily; put, release and Close
		// check for nil before signaling it.
		if p.cond == nil {
			p.cond = sync.NewCond(&p.mu)
		}
		p.cond.Wait()
	}
}
271
272func (p *Pool) put(c Conn, forceClose bool) error {
273	err := c.Err()
274	p.mu.Lock()
275	if !p.closed && err == nil && !forceClose {
276		p.idle.PushFront(idleConn{t: nowFunc(), c: c})
277		if p.idle.Len() > p.MaxIdle {
278			c = p.idle.Remove(p.idle.Back()).(idleConn).c
279		} else {
280			c = nil
281		}
282	}
283
284	if c == nil {
285		if p.cond != nil {
286			p.cond.Signal()
287		}
288		p.mu.Unlock()
289		return nil
290	}
291
292	p.release()
293	p.mu.Unlock()
294	return c.Close()
295}
296
297type pooledConnection struct {
298	p     *Pool
299	c     Conn
300	state int
301}
302
var (
	// sentinel is a random value that Close asks the server to ECHO, so it
	// can tell where the pending subscription replies end.
	sentinel []byte
	// sentinelOnce guards the lazy initialization of sentinel.
	sentinelOnce sync.Once
)

// initSentinel fills sentinel with 64 random bytes. If the system random
// source fails, it falls back to the SHA-1 digest of a fixed string followed
// by the current time in nanoseconds.
func initSentinel() {
	buf := make([]byte, 64)
	if _, err := rand.Read(buf); err != nil {
		digest := sha1.New()
		io.WriteString(digest, "Oops, rand failed. Use time instead.")
		io.WriteString(digest, strconv.FormatInt(time.Now().UnixNano(), 10))
		sentinel = digest.Sum(nil)
		return
	}
	sentinel = buf
}
319
// Close hands the underlying connection back to the pool and always returns
// nil. Calling Close again on the same pooledConnection is a no-op, and after
// Close every other method reports errConnClosed.
//
// Before handing the connection back, Close tries to undo server-side state
// recorded in pc.state:
//   - an open MULTI is discarded;
//   - otherwise, WATCHed keys are unwatched;
//   - subscriptions are cancelled, and replies are drained up to an ECHOed
//     sentinel.
// If any state bit is still set afterwards (for example, because draining hit
// a receive error), put is told to close the connection instead of keeping it
// idle. Errors from these best-effort reset commands and from put are
// ignored.
func (pc *pooledConnection) Close() error {
	c := pc.c
	if _, ok := c.(errorConnection); ok {
		// Already closed: pc.c was swapped for an errorConnection below.
		return nil
	}
	pc.c = errorConnection{errConnClosed}

	// The code treats DISCARD as also clearing the WATCH state, so no
	// separate UNWATCH is sent in that case.
	if pc.state&internal.MultiState != 0 {
		c.Send("DISCARD")
		pc.state &^= (internal.MultiState | internal.WatchState)
	} else if pc.state&internal.WatchState != 0 {
		c.Send("UNWATCH")
		pc.state &^= internal.WatchState
	}
	if pc.state&internal.SubscribeState != 0 {
		c.Send("UNSUBSCRIBE")
		c.Send("PUNSUBSCRIBE")
		// To detect the end of the message stream, ask the server to echo
		// a sentinel value and read until we see that value.
		sentinelOnce.Do(initSentinel)
		c.Send("ECHO", sentinel)
		c.Flush()
		for {
			p, err := c.Receive()
			if err != nil {
				// SubscribeState stays set, so put will close the connection.
				break
			}
			if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
				pc.state &^= internal.SubscribeState
				break
			}
		}
	}
	// Do with an empty command name; presumably this flushes pending commands
	// and reads their replies in this package's Conn implementation — confirm.
	c.Do("")
	pc.p.put(c, pc.state != 0)
	return nil
}
357
358func (pc *pooledConnection) Err() error {
359	return pc.c.Err()
360}
361
362func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
363	ci := internal.LookupCommandInfo(commandName)
364	pc.state = (pc.state | ci.Set) &^ ci.Clear
365	return pc.c.Do(commandName, args...)
366}
367
368func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
369	ci := internal.LookupCommandInfo(commandName)
370	pc.state = (pc.state | ci.Set) &^ ci.Clear
371	return pc.c.Send(commandName, args...)
372}
373
374func (pc *pooledConnection) Flush() error {
375	return pc.c.Flush()
376}
377
378func (pc *pooledConnection) Receive() (reply interface{}, err error) {
379	return pc.c.Receive()
380}
381
// errorConnection is a Conn whose methods all fail with the same stored
// error. Pool.Get returns one when it cannot obtain a connection, and a
// pooledConnection swaps one in when it is closed.
type errorConnection struct{ err error }

// Err returns the stored error.
func (ec errorConnection) Err() error {
	return ec.err
}

// Do returns a nil reply and the stored error.
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) {
	return nil, ec.err
}

// Send returns the stored error.
func (ec errorConnection) Send(string, ...interface{}) error {
	return ec.err
}

// Flush returns the stored error.
func (ec errorConnection) Flush() error {
	return ec.err
}

// Receive returns a nil reply and the stored error.
func (ec errorConnection) Receive() (interface{}, error) {
	return nil, ec.err
}

// Close returns the stored error.
func (ec errorConnection) Close() error {
	return ec.err
}
390