1// Copyright 2012 Gary Burd
2//
3// Licensed under the Apache License, Version 2.0 (the "License"): you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations
13// under the License.
14
15package redis
16
17import (
18	"bytes"
19	"context"
20	"crypto/rand"
21	"crypto/sha1"
22	"errors"
23	"io"
24	"strconv"
25	"sync"
26	"time"
27)
28
29var (
30	_ ConnWithTimeout = (*activeConn)(nil)
31	_ ConnWithTimeout = (*errorConn)(nil)
32)
33
// nowFunc returns the current time. Tests replace it to control the clock
// used for idle timeouts and connection lifetimes.
var nowFunc = time.Now

// ErrPoolExhausted is returned when the maximum number of database
// connections in the pool (MaxActive) has been reached and the pool is not
// configured to Wait. GetContext returns it directly; Get returns a
// connection whose methods (Do, Send, Receive, Flush, Err) report it.
var ErrPoolExhausted = errors.New("redigo: connection pool exhausted")

// errConnClosed is reported by the methods of a pooled connection after the
// application has closed it.
var errConnClosed = errors.New("redigo: connection closed")
44
45// Pool maintains a pool of connections. The application calls the Get method
46// to get a connection from the pool and the connection's Close method to
47// return the connection's resources to the pool.
48//
49// The following example shows how to use a pool in a web application. The
50// application creates a pool at application startup and makes it available to
51// request handlers using a package level variable. The pool configuration used
52// here is an example, not a recommendation.
53//
54//  func newPool(addr string) *redis.Pool {
55//    return &redis.Pool{
56//      MaxIdle: 3,
57//      IdleTimeout: 240 * time.Second,
58//      // Dial or DialContext must be set. When both are set, DialContext takes precedence over Dial.
59//      Dial: func () (redis.Conn, error) { return redis.Dial("tcp", addr) },
60//    }
61//  }
62//
63//  var (
64//    pool *redis.Pool
65//    redisServer = flag.String("redisServer", ":6379", "")
66//  )
67//
68//  func main() {
69//    flag.Parse()
70//    pool = newPool(*redisServer)
71//    ...
72//  }
73//
74// A request handler gets a connection from the pool and closes the connection
75// when the handler is done:
76//
77//  func serveHome(w http.ResponseWriter, r *http.Request) {
78//      conn := pool.Get()
79//      defer conn.Close()
80//      ...
81//  }
82//
83// Use the Dial function to authenticate connections with the AUTH command or
84// select a database with the SELECT command:
85//
86//  pool := &redis.Pool{
87//    // Other pool configuration not shown in this example.
88//    Dial: func () (redis.Conn, error) {
89//      c, err := redis.Dial("tcp", server)
90//      if err != nil {
91//        return nil, err
92//      }
93//      if _, err := c.Do("AUTH", password); err != nil {
94//        c.Close()
95//        return nil, err
96//      }
97//      if _, err := c.Do("SELECT", db); err != nil {
98//        c.Close()
99//        return nil, err
100//      }
101//      return c, nil
102//    },
103//  }
104//
105// Use the TestOnBorrow function to check the health of an idle connection
106// before the connection is returned to the application. This example PINGs
107// connections that have been idle more than a minute:
108//
109//  pool := &redis.Pool{
110//    // Other pool configuration not shown in this example.
111//    TestOnBorrow: func(c redis.Conn, t time.Time) error {
112//      if time.Since(t) < time.Minute {
113//        return nil
114//      }
115//      _, err := c.Do("PING")
116//      return err
117//    },
118//  }
119//
120type Pool struct {
121	// Dial is an application supplied function for creating and configuring a
122	// connection.
123	//
124	// The connection returned from Dial must not be in a special state
125	// (subscribed to pubsub channel, transaction started, ...).
126	Dial func() (Conn, error)
127
128	// DialContext is an application supplied function for creating and configuring a
129	// connection with the given context.
130	//
131	// The connection returned from Dial must not be in a special state
132	// (subscribed to pubsub channel, transaction started, ...).
133	DialContext func(ctx context.Context) (Conn, error)
134
135	// TestOnBorrow is an optional application supplied function for checking
136	// the health of an idle connection before the connection is used again by
137	// the application. Argument t is the time that the connection was returned
138	// to the pool. If the function returns an error, then the connection is
139	// closed.
140	TestOnBorrow func(c Conn, t time.Time) error
141
142	// Maximum number of idle connections in the pool.
143	MaxIdle int
144
145	// Maximum number of connections allocated by the pool at a given time.
146	// When zero, there is no limit on the number of connections in the pool.
147	MaxActive int
148
149	// Close connections after remaining idle for this duration. If the value
150	// is zero, then idle connections are not closed. Applications should set
151	// the timeout to a value less than the server's timeout.
152	IdleTimeout time.Duration
153
154	// If Wait is true and the pool is at the MaxActive limit, then Get() waits
155	// for a connection to be returned to the pool before returning.
156	Wait bool
157
158	// Close connections older than this duration. If the value is zero, then
159	// the pool does not close connections based on age.
160	MaxConnLifetime time.Duration
161
162	mu           sync.Mutex    // mu protects the following fields
163	closed       bool          // set to true when the pool is closed.
164	active       int           // the number of open connections in the pool
165	initOnce     sync.Once     // the init ch once func
166	ch           chan struct{} // limits open connections when p.Wait is true
167	idle         idleList      // idle connections
168	waitCount    int64         // total number of connections waited for.
169	waitDuration time.Duration // total time waited for new connections.
170}
171
172// NewPool creates a new pool.
173//
174// Deprecated: Initialize the Pool directly as shown in the example.
175func NewPool(newFn func() (Conn, error), maxIdle int) *Pool {
176	return &Pool{Dial: newFn, MaxIdle: maxIdle}
177}
178
179// Get gets a connection. The application must close the returned connection.
180// This method always returns a valid connection so that applications can defer
181// error handling to the first use of the connection. If there is an error
182// getting an underlying connection, then the connection Err, Do, Send, Flush
183// and Receive methods return that error.
184func (p *Pool) Get() Conn {
185	// GetContext returns errorConn in the first argument when an error occurs.
186	c, _ := p.GetContext(context.Background())
187	return c
188}
189
190// GetContext gets a connection using the provided context.
191//
192// The provided Context must be non-nil. If the context expires before the
193// connection is complete, an error is returned. Any expiration on the context
194// will not affect the returned connection.
195//
196// If the function completes without error, then the application must close the
197// returned connection.
198func (p *Pool) GetContext(ctx context.Context) (Conn, error) {
199	// Wait until there is a vacant connection in the pool.
200	waited, err := p.waitVacantConn(ctx)
201	if err != nil {
202		return errorConn{err}, err
203	}
204
205	p.mu.Lock()
206
207	if waited > 0 {
208		p.waitCount++
209		p.waitDuration += waited
210	}
211
212	// Prune stale connections at the back of the idle list.
213	if p.IdleTimeout > 0 {
214		n := p.idle.count
215		for i := 0; i < n && p.idle.back != nil && p.idle.back.t.Add(p.IdleTimeout).Before(nowFunc()); i++ {
216			pc := p.idle.back
217			p.idle.popBack()
218			p.mu.Unlock()
219			pc.c.Close()
220			p.mu.Lock()
221			p.active--
222		}
223	}
224
225	// Get idle connection from the front of idle list.
226	for p.idle.front != nil {
227		pc := p.idle.front
228		p.idle.popFront()
229		p.mu.Unlock()
230		if (p.TestOnBorrow == nil || p.TestOnBorrow(pc.c, pc.t) == nil) &&
231			(p.MaxConnLifetime == 0 || nowFunc().Sub(pc.created) < p.MaxConnLifetime) {
232			return &activeConn{p: p, pc: pc}, nil
233		}
234		pc.c.Close()
235		p.mu.Lock()
236		p.active--
237	}
238
239	// Check for pool closed before dialing a new connection.
240	if p.closed {
241		p.mu.Unlock()
242		err := errors.New("redigo: get on closed pool")
243		return errorConn{err}, err
244	}
245
246	// Handle limit for p.Wait == false.
247	if !p.Wait && p.MaxActive > 0 && p.active >= p.MaxActive {
248		p.mu.Unlock()
249		return errorConn{ErrPoolExhausted}, ErrPoolExhausted
250	}
251
252	p.active++
253	p.mu.Unlock()
254	c, err := p.dial(ctx)
255	if err != nil {
256		p.mu.Lock()
257		p.active--
258		if p.ch != nil && !p.closed {
259			p.ch <- struct{}{}
260		}
261		p.mu.Unlock()
262		return errorConn{err}, err
263	}
264	return &activeConn{p: p, pc: &poolConn{c: c, created: nowFunc()}}, nil
265}
266
// PoolStats is a point-in-time snapshot of a Pool's counters, as returned by
// Pool.Stats.
type PoolStats struct {
	// ActiveCount is the number of connections in the pool, counting both
	// idle connections and connections currently in use.
	ActiveCount int
	// IdleCount is the number of idle connections in the pool.
	IdleCount int

	// WaitCount is the total number of connections waited for.
	// This value is currently not guaranteed to be 100% accurate.
	WaitCount int64

	// WaitDuration is the total time blocked waiting for a new connection.
	// This value is currently not guaranteed to be 100% accurate.
	WaitDuration time.Duration
}
283
284// Stats returns pool's statistics.
285func (p *Pool) Stats() PoolStats {
286	p.mu.Lock()
287	stats := PoolStats{
288		ActiveCount:  p.active,
289		IdleCount:    p.idle.count,
290		WaitCount:    p.waitCount,
291		WaitDuration: p.waitDuration,
292	}
293	p.mu.Unlock()
294
295	return stats
296}
297
298// ActiveCount returns the number of connections in the pool. The count
299// includes idle connections and connections in use.
300func (p *Pool) ActiveCount() int {
301	p.mu.Lock()
302	active := p.active
303	p.mu.Unlock()
304	return active
305}
306
307// IdleCount returns the number of idle connections in the pool.
308func (p *Pool) IdleCount() int {
309	p.mu.Lock()
310	idle := p.idle.count
311	p.mu.Unlock()
312	return idle
313}
314
315// Close releases the resources used by the pool.
316func (p *Pool) Close() error {
317	p.mu.Lock()
318	if p.closed {
319		p.mu.Unlock()
320		return nil
321	}
322	p.closed = true
323	p.active -= p.idle.count
324	pc := p.idle.front
325	p.idle.count = 0
326	p.idle.front, p.idle.back = nil, nil
327	if p.ch != nil {
328		close(p.ch)
329	}
330	p.mu.Unlock()
331	for ; pc != nil; pc = pc.next {
332		pc.c.Close()
333	}
334	return nil
335}
336
337func (p *Pool) lazyInit() {
338	p.initOnce.Do(func() {
339		p.ch = make(chan struct{}, p.MaxActive)
340		if p.closed {
341			close(p.ch)
342		} else {
343			for i := 0; i < p.MaxActive; i++ {
344				p.ch <- struct{}{}
345			}
346		}
347	})
348}
349
350// waitVacantConn waits for a vacant connection in pool if waiting
351// is enabled and pool size is limited, otherwise returns instantly.
352// If ctx expires before that, an error is returned.
353//
354// If there were no vacant connection in the pool right away it returns the time spent waiting
355// for that connection to appear in the pool.
356func (p *Pool) waitVacantConn(ctx context.Context) (waited time.Duration, err error) {
357	if !p.Wait || p.MaxActive <= 0 {
358		// No wait or no connection limit.
359		return 0, nil
360	}
361
362	p.lazyInit()
363
364	// wait indicates if we believe it will block so its not 100% accurate
365	// however for stats it should be good enough.
366	wait := len(p.ch) == 0
367	var start time.Time
368	if wait {
369		start = time.Now()
370	}
371
372	select {
373	case <-p.ch:
374		// Additionally check that context hasn't expired while we were waiting,
375		// because `select` picks a random `case` if several of them are "ready".
376		select {
377		case <-ctx.Done():
378			p.ch <- struct{}{}
379			return 0, ctx.Err()
380		default:
381		}
382	case <-ctx.Done():
383		return 0, ctx.Err()
384	}
385
386	if wait {
387		return time.Since(start), nil
388	}
389	return 0, nil
390}
391
392func (p *Pool) dial(ctx context.Context) (Conn, error) {
393	if p.DialContext != nil {
394		return p.DialContext(ctx)
395	}
396	if p.Dial != nil {
397		return p.Dial()
398	}
399	return nil, errors.New("redigo: must pass Dial or DialContext to pool")
400}
401
402func (p *Pool) put(pc *poolConn, forceClose bool) error {
403	p.mu.Lock()
404	if !p.closed && !forceClose {
405		pc.t = nowFunc()
406		p.idle.pushFront(pc)
407		if p.idle.count > p.MaxIdle {
408			pc = p.idle.back
409			p.idle.popBack()
410		} else {
411			pc = nil
412		}
413	}
414
415	if pc != nil {
416		p.mu.Unlock()
417		pc.c.Close()
418		p.mu.Lock()
419		p.active--
420	}
421
422	if p.ch != nil && !p.closed {
423		p.ch <- struct{}{}
424	}
425	p.mu.Unlock()
426	return nil
427}
428
429type activeConn struct {
430	p     *Pool
431	pc    *poolConn
432	state int
433}
434
// sentinel is the value ECHOed to a pubsub connection on Close. Reading
// replies until it comes back marks the end of the pending message stream.
// It is initialized at most once through sentinelOnce.
var sentinel []byte

// sentinelOnce guards the one-time call to initSentinel.
var sentinelOnce sync.Once

// initSentinel fills sentinel with 64 random bytes. If the system random
// source fails, it falls back to a SHA-1 digest of the current time.
func initSentinel() {
	buf := make([]byte, 64)
	if _, err := rand.Read(buf); err != nil {
		digest := sha1.New()
		io.WriteString(digest, "Oops, rand failed. Use time instead.")       // nolint: errcheck
		io.WriteString(digest, strconv.FormatInt(time.Now().UnixNano(), 10)) // nolint: errcheck
		sentinel = digest.Sum(nil)
		return
	}
	sentinel = buf
}
451
452func (ac *activeConn) firstError(errs ...error) error {
453	for _, err := range errs[:len(errs)-1] {
454		if err != nil {
455			return err
456		}
457	}
458	return errs[len(errs)-1]
459}
460
461func (ac *activeConn) Close() (err error) {
462	pc := ac.pc
463	if pc == nil {
464		return nil
465	}
466	ac.pc = nil
467
468	if ac.state&connectionMultiState != 0 {
469		err = pc.c.Send("DISCARD")
470		ac.state &^= (connectionMultiState | connectionWatchState)
471	} else if ac.state&connectionWatchState != 0 {
472		err = pc.c.Send("UNWATCH")
473		ac.state &^= connectionWatchState
474	}
475	if ac.state&connectionSubscribeState != 0 {
476		err = ac.firstError(err,
477			pc.c.Send("UNSUBSCRIBE"),
478			pc.c.Send("PUNSUBSCRIBE"),
479		)
480		// To detect the end of the message stream, ask the server to echo
481		// a sentinel value and read until we see that value.
482		sentinelOnce.Do(initSentinel)
483		err = ac.firstError(err,
484			pc.c.Send("ECHO", sentinel),
485			pc.c.Flush(),
486		)
487		for {
488			p, err2 := pc.c.Receive()
489			if err2 != nil {
490				err = ac.firstError(err, err2)
491				break
492			}
493			if p, ok := p.([]byte); ok && bytes.Equal(p, sentinel) {
494				ac.state &^= connectionSubscribeState
495				break
496			}
497		}
498	}
499	_, err2 := pc.c.Do("")
500	return ac.firstError(
501		err,
502		err2,
503		ac.p.put(pc, ac.state != 0 || pc.c.Err() != nil),
504	)
505}
506
507func (ac *activeConn) Err() error {
508	pc := ac.pc
509	if pc == nil {
510		return errConnClosed
511	}
512	return pc.c.Err()
513}
514
515func (ac *activeConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) {
516	pc := ac.pc
517	if pc == nil {
518		return nil, errConnClosed
519	}
520	cwt, ok := pc.c.(ConnWithContext)
521	if !ok {
522		return nil, errContextNotSupported
523	}
524	ci := lookupCommandInfo(commandName)
525	ac.state = (ac.state | ci.Set) &^ ci.Clear
526	return cwt.DoContext(ctx, commandName, args...)
527}
528
529func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
530	pc := ac.pc
531	if pc == nil {
532		return nil, errConnClosed
533	}
534	ci := lookupCommandInfo(commandName)
535	ac.state = (ac.state | ci.Set) &^ ci.Clear
536	return pc.c.Do(commandName, args...)
537}
538
539func (ac *activeConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
540	pc := ac.pc
541	if pc == nil {
542		return nil, errConnClosed
543	}
544	cwt, ok := pc.c.(ConnWithTimeout)
545	if !ok {
546		return nil, errTimeoutNotSupported
547	}
548	ci := lookupCommandInfo(commandName)
549	ac.state = (ac.state | ci.Set) &^ ci.Clear
550	return cwt.DoWithTimeout(timeout, commandName, args...)
551}
552
553func (ac *activeConn) Send(commandName string, args ...interface{}) error {
554	pc := ac.pc
555	if pc == nil {
556		return errConnClosed
557	}
558	ci := lookupCommandInfo(commandName)
559	ac.state = (ac.state | ci.Set) &^ ci.Clear
560	return pc.c.Send(commandName, args...)
561}
562
563func (ac *activeConn) Flush() error {
564	pc := ac.pc
565	if pc == nil {
566		return errConnClosed
567	}
568	return pc.c.Flush()
569}
570
571func (ac *activeConn) Receive() (reply interface{}, err error) {
572	pc := ac.pc
573	if pc == nil {
574		return nil, errConnClosed
575	}
576	return pc.c.Receive()
577}
578
579func (ac *activeConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) {
580	pc := ac.pc
581	if pc == nil {
582		return nil, errConnClosed
583	}
584	cwt, ok := pc.c.(ConnWithContext)
585	if !ok {
586		return nil, errContextNotSupported
587	}
588	return cwt.ReceiveContext(ctx)
589}
590
591func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
592	pc := ac.pc
593	if pc == nil {
594		return nil, errConnClosed
595	}
596	cwt, ok := pc.c.(ConnWithTimeout)
597	if !ok {
598		return nil, errTimeoutNotSupported
599	}
600	return cwt.ReceiveWithTimeout(timeout)
601}
602
// errorConn is the connection handed out when the pool cannot supply a real
// one. Every method except Close reports the stored error. Close is a no-op.
type errorConn struct{ err error }

func (ec errorConn) Do(string, ...interface{}) (interface{}, error) {
	return nil, ec.err
}

func (ec errorConn) DoContext(context.Context, string, ...interface{}) (interface{}, error) {
	return nil, ec.err
}

func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
	return nil, ec.err
}

func (ec errorConn) Send(string, ...interface{}) error {
	return ec.err
}

func (ec errorConn) Err() error {
	return ec.err
}

func (ec errorConn) Close() error {
	return nil
}

func (ec errorConn) Flush() error {
	return ec.err
}

func (ec errorConn) Receive() (interface{}, error) {
	return nil, ec.err
}

func (ec errorConn) ReceiveContext(context.Context) (interface{}, error) {
	return nil, ec.err
}

func (ec errorConn) ReceiveWithTimeout(time.Duration) (interface{}, error) {
	return nil, ec.err
}
619
620type idleList struct {
621	count       int
622	front, back *poolConn
623}
624
625type poolConn struct {
626	c          Conn
627	t          time.Time
628	created    time.Time
629	next, prev *poolConn
630}
631
632func (l *idleList) pushFront(pc *poolConn) {
633	pc.next = l.front
634	pc.prev = nil
635	if l.count == 0 {
636		l.back = pc
637	} else {
638		l.front.prev = pc
639	}
640	l.front = pc
641	l.count++
642}
643
644func (l *idleList) popFront() {
645	pc := l.front
646	l.count--
647	if l.count == 0 {
648		l.front, l.back = nil, nil
649	} else {
650		pc.next.prev = nil
651		l.front = pc.next
652	}
653	pc.next, pc.prev = nil, nil
654}
655
656func (l *idleList) popBack() {
657	pc := l.back
658	l.count--
659	if l.count == 0 {
660		l.front, l.back = nil, nil
661	} else {
662		pc.prev.next = nil
663		l.back = pc.prev
664	}
665	pc.next, pc.prev = nil, nil
666}
667