1// Copyright 2011 Gary Burd
2//
3// Licensed under the Apache License, Version 2.0 (the "License"): you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations
13// under the License.
14
15package redis_test
16
17import (
18	"errors"
19	"io"
20	"reflect"
21	"sync"
22	"testing"
23	"time"
24
25	"github.com/gomodule/redigo/redis"
26)
27
28type poolTestConn struct {
29	d   *poolDialer
30	err error
31	redis.Conn
32}
33
34func (c *poolTestConn) Close() error {
35	c.d.mu.Lock()
36	c.d.open -= 1
37	c.d.mu.Unlock()
38	return c.Conn.Close()
39}
40
41func (c *poolTestConn) Err() error { return c.err }
42
43func (c *poolTestConn) Do(commandName string, args ...interface{}) (interface{}, error) {
44	if commandName == "ERR" {
45		c.err = args[0].(error)
46		commandName = "PING"
47	}
48	if commandName != "" {
49		c.d.commands = append(c.d.commands, commandName)
50	}
51	return c.Conn.Do(commandName, args...)
52}
53
54func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
55	c.d.commands = append(c.d.commands, commandName)
56	return c.Conn.Send(commandName, args...)
57}
58
// poolDialer supplies the Dial function for pools under test and counts what
// happens to the connections it creates. Its check method compares those
// counts, and the pool's Stats, against expected values.
type poolDialer struct {
	t *testing.T

	mu      sync.Mutex // held by dial, Close and check around the counters
	dialed  int        // dial attempts, including ones that failed
	open    int        // successfully dialed connections not yet closed
	dialErr error      // when non-nil, dial returns it instead of connecting

	commands []string // command names recorded by poolTestConn Do and Send
}
67
68func (d *poolDialer) dial() (redis.Conn, error) {
69	d.mu.Lock()
70	d.dialed += 1
71	dialErr := d.dialErr
72	d.mu.Unlock()
73	if dialErr != nil {
74		return nil, d.dialErr
75	}
76	c, err := redis.DialDefaultServer()
77	if err != nil {
78		return nil, err
79	}
80	d.mu.Lock()
81	d.open += 1
82	d.mu.Unlock()
83	return &poolTestConn{d: d, Conn: c}, nil
84}
85
86func (d *poolDialer) check(message string, p *redis.Pool, dialed, open, inuse int) {
87	d.mu.Lock()
88	if d.dialed != dialed {
89		d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
90	}
91	if d.open != open {
92		d.t.Errorf("%s: open=%d, want %d", message, d.open, open)
93	}
94
95	stats := p.Stats()
96
97	if stats.ActiveCount != open {
98		d.t.Errorf("%s: active=%d, want %d", message, stats.ActiveCount, open)
99	}
100	if stats.IdleCount != open-inuse {
101		d.t.Errorf("%s: idle=%d, want %d", message, stats.IdleCount, open-inuse)
102	}
103
104	d.mu.Unlock()
105}
106
107func TestPoolReuse(t *testing.T) {
108	d := poolDialer{t: t}
109	p := &redis.Pool{
110		MaxIdle: 2,
111		Dial:    d.dial,
112	}
113
114	for i := 0; i < 10; i++ {
115		c1 := p.Get()
116		c1.Do("PING")
117		c2 := p.Get()
118		c2.Do("PING")
119		c1.Close()
120		c2.Close()
121	}
122
123	d.check("before close", p, 2, 2, 0)
124	p.Close()
125	d.check("after close", p, 2, 0, 0)
126}
127
128func TestPoolMaxIdle(t *testing.T) {
129	d := poolDialer{t: t}
130	p := &redis.Pool{
131		MaxIdle: 2,
132		Dial:    d.dial,
133	}
134	defer p.Close()
135
136	for i := 0; i < 10; i++ {
137		c1 := p.Get()
138		c1.Do("PING")
139		c2 := p.Get()
140		c2.Do("PING")
141		c3 := p.Get()
142		c3.Do("PING")
143		c1.Close()
144		c2.Close()
145		c3.Close()
146	}
147	d.check("before close", p, 12, 2, 0)
148	p.Close()
149	d.check("after close", p, 12, 0, 0)
150}
151
152func TestPoolError(t *testing.T) {
153	d := poolDialer{t: t}
154	p := &redis.Pool{
155		MaxIdle: 2,
156		Dial:    d.dial,
157	}
158	defer p.Close()
159
160	c := p.Get()
161	c.Do("ERR", io.EOF)
162	if c.Err() == nil {
163		t.Errorf("expected c.Err() != nil")
164	}
165	c.Close()
166
167	c = p.Get()
168	c.Do("ERR", io.EOF)
169	c.Close()
170
171	d.check(".", p, 2, 0, 0)
172}
173
174func TestPoolClose(t *testing.T) {
175	d := poolDialer{t: t}
176	p := &redis.Pool{
177		MaxIdle: 2,
178		Dial:    d.dial,
179	}
180	defer p.Close()
181
182	c1 := p.Get()
183	c1.Do("PING")
184	c2 := p.Get()
185	c2.Do("PING")
186	c3 := p.Get()
187	c3.Do("PING")
188
189	c1.Close()
190	if _, err := c1.Do("PING"); err == nil {
191		t.Errorf("expected error after connection closed")
192	}
193
194	c2.Close()
195	c2.Close()
196
197	p.Close()
198
199	d.check("after pool close", p, 3, 1, 1)
200
201	if _, err := c1.Do("PING"); err == nil {
202		t.Errorf("expected error after connection and pool closed")
203	}
204
205	c3.Close()
206
207	d.check("after conn close", p, 3, 0, 0)
208
209	c1 = p.Get()
210	if _, err := c1.Do("PING"); err == nil {
211		t.Errorf("expected error after pool closed")
212	}
213}
214
215func TestPoolClosedConn(t *testing.T) {
216	d := poolDialer{t: t}
217	p := &redis.Pool{
218		MaxIdle:     2,
219		IdleTimeout: 300 * time.Second,
220		Dial:        d.dial,
221	}
222	defer p.Close()
223	c := p.Get()
224	if c.Err() != nil {
225		t.Fatal("get failed")
226	}
227	c.Close()
228	if err := c.Err(); err == nil {
229		t.Fatal("Err on closed connection did not return error")
230	}
231	if _, err := c.Do("PING"); err == nil {
232		t.Fatal("Do on closed connection did not return error")
233	}
234	if err := c.Send("PING"); err == nil {
235		t.Fatal("Send on closed connection did not return error")
236	}
237	if err := c.Flush(); err == nil {
238		t.Fatal("Flush on closed connection did not return error")
239	}
240	if _, err := c.Receive(); err == nil {
241		t.Fatal("Receive on closed connection did not return error")
242	}
243}
244
245func TestPoolIdleTimeout(t *testing.T) {
246	d := poolDialer{t: t}
247	p := &redis.Pool{
248		MaxIdle:     2,
249		IdleTimeout: 300 * time.Second,
250		Dial:        d.dial,
251	}
252	defer p.Close()
253
254	now := time.Now()
255	redis.SetNowFunc(func() time.Time { return now })
256	defer redis.SetNowFunc(time.Now)
257
258	c := p.Get()
259	c.Do("PING")
260	c.Close()
261
262	d.check("1", p, 1, 1, 0)
263
264	now = now.Add(p.IdleTimeout + 1)
265
266	c = p.Get()
267	c.Do("PING")
268	c.Close()
269
270	d.check("2", p, 2, 1, 0)
271}
272
273func TestPoolMaxLifetime(t *testing.T) {
274	d := poolDialer{t: t}
275	p := &redis.Pool{
276		MaxIdle:         2,
277		MaxConnLifetime: 300 * time.Second,
278		Dial:            d.dial,
279	}
280	defer p.Close()
281
282	now := time.Now()
283	redis.SetNowFunc(func() time.Time { return now })
284	defer redis.SetNowFunc(time.Now)
285
286	c := p.Get()
287	c.Do("PING")
288	c.Close()
289
290	d.check("1", p, 1, 1, 0)
291
292	now = now.Add(p.MaxConnLifetime + 1)
293
294	c = p.Get()
295	c.Do("PING")
296	c.Close()
297
298	d.check("2", p, 2, 1, 0)
299}
300
301func TestPoolConcurrenSendReceive(t *testing.T) {
302	p := &redis.Pool{
303		Dial: redis.DialDefaultServer,
304	}
305	defer p.Close()
306
307	c := p.Get()
308	done := make(chan error, 1)
309	go func() {
310		_, err := c.Receive()
311		done <- err
312	}()
313	c.Send("PING")
314	c.Flush()
315	err := <-done
316	if err != nil {
317		t.Fatalf("Receive() returned error %v", err)
318	}
319	_, err = c.Do("")
320	if err != nil {
321		t.Fatalf("Do() returned error %v", err)
322	}
323	c.Close()
324}
325
326func TestPoolBorrowCheck(t *testing.T) {
327	d := poolDialer{t: t}
328	p := &redis.Pool{
329		MaxIdle:      2,
330		Dial:         d.dial,
331		TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
332	}
333	defer p.Close()
334
335	for i := 0; i < 10; i++ {
336		c := p.Get()
337		c.Do("PING")
338		c.Close()
339	}
340	d.check("1", p, 10, 1, 0)
341}
342
343func TestPoolMaxActive(t *testing.T) {
344	d := poolDialer{t: t}
345	p := &redis.Pool{
346		MaxIdle:   2,
347		MaxActive: 2,
348		Dial:      d.dial,
349	}
350	defer p.Close()
351
352	c1 := p.Get()
353	c1.Do("PING")
354	c2 := p.Get()
355	c2.Do("PING")
356
357	d.check("1", p, 2, 2, 2)
358
359	c3 := p.Get()
360	if _, err := c3.Do("PING"); err != redis.ErrPoolExhausted {
361		t.Errorf("expected pool exhausted")
362	}
363
364	c3.Close()
365	d.check("2", p, 2, 2, 2)
366	c2.Close()
367	d.check("3", p, 2, 2, 1)
368
369	c3 = p.Get()
370	if _, err := c3.Do("PING"); err != nil {
371		t.Errorf("expected good channel, err=%v", err)
372	}
373	c3.Close()
374
375	d.check("4", p, 2, 2, 1)
376}
377
378func TestPoolMonitorCleanup(t *testing.T) {
379	d := poolDialer{t: t}
380	p := &redis.Pool{
381		MaxIdle:   2,
382		MaxActive: 2,
383		Dial:      d.dial,
384	}
385	defer p.Close()
386
387	c := p.Get()
388	c.Send("MONITOR")
389	c.Close()
390
391	d.check("", p, 1, 0, 0)
392}
393
394func TestPoolPubSubCleanup(t *testing.T) {
395	d := poolDialer{t: t}
396	p := &redis.Pool{
397		MaxIdle:   2,
398		MaxActive: 2,
399		Dial:      d.dial,
400	}
401	defer p.Close()
402
403	c := p.Get()
404	c.Send("SUBSCRIBE", "x")
405	c.Close()
406
407	want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
408	if !reflect.DeepEqual(d.commands, want) {
409		t.Errorf("got commands %v, want %v", d.commands, want)
410	}
411	d.commands = nil
412
413	c = p.Get()
414	c.Send("PSUBSCRIBE", "x*")
415	c.Close()
416
417	want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
418	if !reflect.DeepEqual(d.commands, want) {
419		t.Errorf("got commands %v, want %v", d.commands, want)
420	}
421	d.commands = nil
422}
423
424func TestPoolTransactionCleanup(t *testing.T) {
425	d := poolDialer{t: t}
426	p := &redis.Pool{
427		MaxIdle:   2,
428		MaxActive: 2,
429		Dial:      d.dial,
430	}
431	defer p.Close()
432
433	c := p.Get()
434	c.Do("WATCH", "key")
435	c.Do("PING")
436	c.Close()
437
438	want := []string{"WATCH", "PING", "UNWATCH"}
439	if !reflect.DeepEqual(d.commands, want) {
440		t.Errorf("got commands %v, want %v", d.commands, want)
441	}
442	d.commands = nil
443
444	c = p.Get()
445	c.Do("WATCH", "key")
446	c.Do("UNWATCH")
447	c.Do("PING")
448	c.Close()
449
450	want = []string{"WATCH", "UNWATCH", "PING"}
451	if !reflect.DeepEqual(d.commands, want) {
452		t.Errorf("got commands %v, want %v", d.commands, want)
453	}
454	d.commands = nil
455
456	c = p.Get()
457	c.Do("WATCH", "key")
458	c.Do("MULTI")
459	c.Do("PING")
460	c.Close()
461
462	want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
463	if !reflect.DeepEqual(d.commands, want) {
464		t.Errorf("got commands %v, want %v", d.commands, want)
465	}
466	d.commands = nil
467
468	c = p.Get()
469	c.Do("WATCH", "key")
470	c.Do("MULTI")
471	c.Do("DISCARD")
472	c.Do("PING")
473	c.Close()
474
475	want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
476	if !reflect.DeepEqual(d.commands, want) {
477		t.Errorf("got commands %v, want %v", d.commands, want)
478	}
479	d.commands = nil
480
481	c = p.Get()
482	c.Do("WATCH", "key")
483	c.Do("MULTI")
484	c.Do("EXEC")
485	c.Do("PING")
486	c.Close()
487
488	want = []string{"WATCH", "MULTI", "EXEC", "PING"}
489	if !reflect.DeepEqual(d.commands, want) {
490		t.Errorf("got commands %v, want %v", d.commands, want)
491	}
492	d.commands = nil
493}
494
495func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error {
496	errs := make(chan error, 10)
497	for i := 0; i < cap(errs); i++ {
498		go func() {
499			c := p.Get()
500			_, err := c.Do(cmd, args...)
501			c.Close()
502			errs <- err
503		}()
504	}
505
506	return errs
507}
508
509func TestWaitPool(t *testing.T) {
510	d := poolDialer{t: t}
511	p := &redis.Pool{
512		MaxIdle:   1,
513		MaxActive: 1,
514		Dial:      d.dial,
515		Wait:      true,
516	}
517	defer p.Close()
518
519	c := p.Get()
520	errs := startGoroutines(p, "PING")
521	d.check("before close", p, 1, 1, 1)
522	c.Close()
523	timeout := time.After(2 * time.Second)
524	for i := 0; i < cap(errs); i++ {
525		select {
526		case err := <-errs:
527			if err != nil {
528				t.Fatal(err)
529			}
530		case <-timeout:
531			t.Fatalf("timeout waiting for blocked goroutine %d", i)
532		}
533	}
534	d.check("done", p, 1, 1, 0)
535}
536
537func TestWaitPoolClose(t *testing.T) {
538	d := poolDialer{t: t}
539	p := &redis.Pool{
540		MaxIdle:   1,
541		MaxActive: 1,
542		Dial:      d.dial,
543		Wait:      true,
544	}
545	defer p.Close()
546
547	c := p.Get()
548	if _, err := c.Do("PING"); err != nil {
549		t.Fatal(err)
550	}
551	errs := startGoroutines(p, "PING")
552	d.check("before close", p, 1, 1, 1)
553	p.Close()
554	timeout := time.After(2 * time.Second)
555	for i := 0; i < cap(errs); i++ {
556		select {
557		case err := <-errs:
558			switch err {
559			case nil:
560				t.Fatal("blocked goroutine did not get error")
561			case redis.ErrPoolExhausted:
562				t.Fatal("blocked goroutine got pool exhausted error")
563			}
564		case <-timeout:
565			t.Fatal("timeout waiting for blocked goroutine")
566		}
567	}
568	c.Close()
569	d.check("done", p, 1, 0, 0)
570}
571
572func TestWaitPoolCommandError(t *testing.T) {
573	testErr := errors.New("test")
574	d := poolDialer{t: t}
575	p := &redis.Pool{
576		MaxIdle:   1,
577		MaxActive: 1,
578		Dial:      d.dial,
579		Wait:      true,
580	}
581	defer p.Close()
582
583	c := p.Get()
584	errs := startGoroutines(p, "ERR", testErr)
585	d.check("before close", p, 1, 1, 1)
586	c.Close()
587	timeout := time.After(2 * time.Second)
588	for i := 0; i < cap(errs); i++ {
589		select {
590		case err := <-errs:
591			if err != nil {
592				t.Fatal(err)
593			}
594		case <-timeout:
595			t.Fatalf("timeout waiting for blocked goroutine %d", i)
596		}
597	}
598	d.check("done", p, cap(errs), 0, 0)
599}
600
601func TestWaitPoolDialError(t *testing.T) {
602	testErr := errors.New("test")
603	d := poolDialer{t: t}
604	p := &redis.Pool{
605		MaxIdle:   1,
606		MaxActive: 1,
607		Dial:      d.dial,
608		Wait:      true,
609	}
610	defer p.Close()
611
612	c := p.Get()
613	errs := startGoroutines(p, "ERR", testErr)
614	d.check("before close", p, 1, 1, 1)
615
616	d.dialErr = errors.New("dial")
617	c.Close()
618
619	nilCount := 0
620	errCount := 0
621	timeout := time.After(2 * time.Second)
622	for i := 0; i < cap(errs); i++ {
623		select {
624		case err := <-errs:
625			switch err {
626			case nil:
627				nilCount++
628			case d.dialErr:
629				errCount++
630			default:
631				t.Fatalf("expected dial error or nil, got %v", err)
632			}
633		case <-timeout:
634			t.Fatalf("timeout waiting for blocked goroutine %d", i)
635		}
636	}
637	if nilCount != 1 {
638		t.Errorf("expected one nil error, got %d", nilCount)
639	}
640	if errCount != cap(errs)-1 {
641		t.Errorf("expected %d dial errors, got %d", cap(errs)-1, errCount)
642	}
643	d.check("done", p, cap(errs), 0, 0)
644}
645
646// Borrowing requires us to iterate over the idle connections, unlock the pool,
647// and perform a blocking operation to check the connection still works. If
648// TestOnBorrow fails, we must reacquire the lock and continue iteration. This
649// test ensures that iteration will work correctly if multiple threads are
650// iterating simultaneously.
651func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
652	const count = 100
653
654	// First we'll Create a pool where the pilfering of idle connections fails.
655	d := poolDialer{t: t}
656	p := &redis.Pool{
657		MaxIdle:   count,
658		MaxActive: count,
659		Dial:      d.dial,
660		TestOnBorrow: func(c redis.Conn, t time.Time) error {
661			return errors.New("No way back into the real world.")
662		},
663	}
664	defer p.Close()
665
666	// Fill the pool with idle connections.
667	conns := make([]redis.Conn, count)
668	for i := range conns {
669		conns[i] = p.Get()
670	}
671	for i := range conns {
672		conns[i].Close()
673	}
674
675	// Spawn a bunch of goroutines to thrash the pool.
676	var wg sync.WaitGroup
677	wg.Add(count)
678	for i := 0; i < count; i++ {
679		go func() {
680			c := p.Get()
681			if c.Err() != nil {
682				t.Errorf("pool get failed: %v", c.Err())
683			}
684			c.Close()
685			wg.Done()
686		}()
687	}
688	wg.Wait()
689	if d.dialed != count*2 {
690		t.Errorf("Expected %d dials, got %d", count*2, d.dialed)
691	}
692}
693
694func BenchmarkPoolGet(b *testing.B) {
695	b.StopTimer()
696	p := redis.Pool{Dial: redis.DialDefaultServer, MaxIdle: 2}
697	c := p.Get()
698	if err := c.Err(); err != nil {
699		b.Fatal(err)
700	}
701	c.Close()
702	defer p.Close()
703	b.StartTimer()
704	for i := 0; i < b.N; i++ {
705		c = p.Get()
706		c.Close()
707	}
708}
709
710func BenchmarkPoolGetErr(b *testing.B) {
711	b.StopTimer()
712	p := redis.Pool{Dial: redis.DialDefaultServer, MaxIdle: 2}
713	c := p.Get()
714	if err := c.Err(); err != nil {
715		b.Fatal(err)
716	}
717	c.Close()
718	defer p.Close()
719	b.StartTimer()
720	for i := 0; i < b.N; i++ {
721		c = p.Get()
722		if err := c.Err(); err != nil {
723			b.Fatal(err)
724		}
725		c.Close()
726	}
727}
728
729func BenchmarkPoolGetPing(b *testing.B) {
730	b.StopTimer()
731	p := redis.Pool{Dial: redis.DialDefaultServer, MaxIdle: 2}
732	c := p.Get()
733	if err := c.Err(); err != nil {
734		b.Fatal(err)
735	}
736	c.Close()
737	defer p.Close()
738	b.StartTimer()
739	for i := 0; i < b.N; i++ {
740		c = p.Get()
741		if _, err := c.Do("PING"); err != nil {
742			b.Fatal(err)
743		}
744		c.Close()
745	}
746}
747