1package redis_test
2
3import (
4	"context"
5	"crypto/rand"
6	"fmt"
7	"net"
8	"strconv"
9	"sync"
10	"time"
11
12	"github.com/go-redis/redis/v7"
13
14	. "github.com/onsi/ginkgo"
15	. "github.com/onsi/gomega"
16)
17
18var _ = Describe("Redis Ring", func() {
19	const heartbeat = 100 * time.Millisecond
20
21	var ring *redis.Ring
22
23	setRingKeys := func() {
24		for i := 0; i < 100; i++ {
25			err := ring.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
26			Expect(err).NotTo(HaveOccurred())
27		}
28	}
29
30	BeforeEach(func() {
31		opt := redisRingOptions()
32		opt.HeartbeatFrequency = heartbeat
33		ring = redis.NewRing(opt)
34
35		err := ring.ForEachShard(func(cl *redis.Client) error {
36			return cl.FlushDB().Err()
37		})
38		Expect(err).NotTo(HaveOccurred())
39	})
40
41	AfterEach(func() {
42		Expect(ring.Close()).NotTo(HaveOccurred())
43	})
44
45	It("supports WithContext", func() {
46		c, cancel := context.WithCancel(context.Background())
47		cancel()
48
49		err := ring.WithContext(c).Ping().Err()
50		Expect(err).To(MatchError("context canceled"))
51	})
52
53	It("distributes keys", func() {
54		setRingKeys()
55
56		// Both shards should have some keys now.
57		Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57"))
58		Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
59	})
60
61	It("distributes keys when using EVAL", func() {
62		script := redis.NewScript(`
63			local r = redis.call('SET', KEYS[1], ARGV[1])
64			return r
65		`)
66
67		var key string
68		for i := 0; i < 100; i++ {
69			key = fmt.Sprintf("key%d", i)
70			err := script.Run(ring, []string{key}, "value").Err()
71			Expect(err).NotTo(HaveOccurred())
72		}
73
74		Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57"))
75		Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
76	})
77
78	It("uses single shard when one of the shards is down", func() {
79		// Stop ringShard2.
80		Expect(ringShard2.Close()).NotTo(HaveOccurred())
81
82		Eventually(func() int {
83			return ring.Len()
84		}, "30s").Should(Equal(1))
85
86		setRingKeys()
87
88		// RingShard1 should have all keys.
89		Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
90
91		// Start ringShard2.
92		var err error
93		ringShard2, err = startRedis(ringShard2Port)
94		Expect(err).NotTo(HaveOccurred())
95
96		Eventually(func() int {
97			return ring.Len()
98		}, "30s").Should(Equal(2))
99
100		setRingKeys()
101
102		// RingShard2 should have its keys.
103		Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
104	})
105
106	It("supports hash tags", func() {
107		for i := 0; i < 100; i++ {
108			err := ring.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
109			Expect(err).NotTo(HaveOccurred())
110		}
111
112		Expect(ringShard1.Info("keyspace").Val()).ToNot(ContainSubstring("keys="))
113		Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
114	})
115
116	Describe("pipeline", func() {
117		It("distributes keys", func() {
118			pipe := ring.Pipeline()
119			for i := 0; i < 100; i++ {
120				err := pipe.Set(fmt.Sprintf("key%d", i), "value", 0).Err()
121				Expect(err).NotTo(HaveOccurred())
122			}
123			cmds, err := pipe.Exec()
124			Expect(err).NotTo(HaveOccurred())
125			Expect(cmds).To(HaveLen(100))
126			Expect(pipe.Close()).NotTo(HaveOccurred())
127
128			for _, cmd := range cmds {
129				Expect(cmd.Err()).NotTo(HaveOccurred())
130				Expect(cmd.(*redis.StatusCmd).Val()).To(Equal("OK"))
131			}
132
133			// Both shards should have some keys now.
134			Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57"))
135			Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43"))
136		})
137
138		It("is consistent with ring", func() {
139			var keys []string
140			for i := 0; i < 100; i++ {
141				key := make([]byte, 64)
142				_, err := rand.Read(key)
143				Expect(err).NotTo(HaveOccurred())
144				keys = append(keys, string(key))
145			}
146
147			_, err := ring.Pipelined(func(pipe redis.Pipeliner) error {
148				for _, key := range keys {
149					pipe.Set(key, "value", 0).Err()
150				}
151				return nil
152			})
153			Expect(err).NotTo(HaveOccurred())
154
155			for _, key := range keys {
156				val, err := ring.Get(key).Result()
157				Expect(err).NotTo(HaveOccurred())
158				Expect(val).To(Equal("value"))
159			}
160		})
161
162		It("supports hash tags", func() {
163			_, err := ring.Pipelined(func(pipe redis.Pipeliner) error {
164				for i := 0; i < 100; i++ {
165					pipe.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
166				}
167				return nil
168			})
169			Expect(err).NotTo(HaveOccurred())
170
171			Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys="))
172			Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100"))
173		})
174	})
175
176	Describe("shard passwords", func() {
177		It("can be initialized with a single password, used for all shards", func() {
178			opts := redisRingOptions()
179			opts.Password = "password"
180			ring = redis.NewRing(opts)
181
182			err := ring.Ping().Err()
183			Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
184		})
185
186		It("can be initialized with a passwords map, one for each shard", func() {
187			opts := redisRingOptions()
188			opts.Passwords = map[string]string{
189				"ringShardOne": "password1",
190				"ringShardTwo": "password2",
191			}
192			ring = redis.NewRing(opts)
193
194			err := ring.Ping().Err()
195			Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
196		})
197	})
198
199	It("supports Process hook", func() {
200		err := ring.Ping().Err()
201		Expect(err).NotTo(HaveOccurred())
202
203		var stack []string
204
205		ring.AddHook(&hook{
206			beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
207				Expect(cmd.String()).To(Equal("ping: "))
208				stack = append(stack, "ring.BeforeProcess")
209				return ctx, nil
210			},
211			afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
212				Expect(cmd.String()).To(Equal("ping: PONG"))
213				stack = append(stack, "ring.AfterProcess")
214				return nil
215			},
216		})
217
218		ring.ForEachShard(func(shard *redis.Client) error {
219			shard.AddHook(&hook{
220				beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
221					Expect(cmd.String()).To(Equal("ping: "))
222					stack = append(stack, "shard.BeforeProcess")
223					return ctx, nil
224				},
225				afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
226					Expect(cmd.String()).To(Equal("ping: PONG"))
227					stack = append(stack, "shard.AfterProcess")
228					return nil
229				},
230			})
231			return nil
232		})
233
234		err = ring.Ping().Err()
235		Expect(err).NotTo(HaveOccurred())
236		Expect(stack).To(Equal([]string{
237			"ring.BeforeProcess",
238			"shard.BeforeProcess",
239			"shard.AfterProcess",
240			"ring.AfterProcess",
241		}))
242	})
243
244	It("supports Pipeline hook", func() {
245		err := ring.Ping().Err()
246		Expect(err).NotTo(HaveOccurred())
247
248		var stack []string
249
250		ring.AddHook(&hook{
251			beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
252				Expect(cmds).To(HaveLen(1))
253				Expect(cmds[0].String()).To(Equal("ping: "))
254				stack = append(stack, "ring.BeforeProcessPipeline")
255				return ctx, nil
256			},
257			afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
258				Expect(cmds).To(HaveLen(1))
259				Expect(cmds[0].String()).To(Equal("ping: PONG"))
260				stack = append(stack, "ring.AfterProcessPipeline")
261				return nil
262			},
263		})
264
265		ring.ForEachShard(func(shard *redis.Client) error {
266			shard.AddHook(&hook{
267				beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
268					Expect(cmds).To(HaveLen(1))
269					Expect(cmds[0].String()).To(Equal("ping: "))
270					stack = append(stack, "shard.BeforeProcessPipeline")
271					return ctx, nil
272				},
273				afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
274					Expect(cmds).To(HaveLen(1))
275					Expect(cmds[0].String()).To(Equal("ping: PONG"))
276					stack = append(stack, "shard.AfterProcessPipeline")
277					return nil
278				},
279			})
280			return nil
281		})
282
283		_, err = ring.Pipelined(func(pipe redis.Pipeliner) error {
284			pipe.Ping()
285			return nil
286		})
287		Expect(err).NotTo(HaveOccurred())
288		Expect(stack).To(Equal([]string{
289			"ring.BeforeProcessPipeline",
290			"shard.BeforeProcessPipeline",
291			"shard.AfterProcessPipeline",
292			"ring.AfterProcessPipeline",
293		}))
294	})
295
296	It("supports TxPipeline hook", func() {
297		err := ring.Ping().Err()
298		Expect(err).NotTo(HaveOccurred())
299
300		var stack []string
301
302		ring.AddHook(&hook{
303			beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
304				Expect(cmds).To(HaveLen(1))
305				Expect(cmds[0].String()).To(Equal("ping: "))
306				stack = append(stack, "ring.BeforeProcessPipeline")
307				return ctx, nil
308			},
309			afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
310				Expect(cmds).To(HaveLen(1))
311				Expect(cmds[0].String()).To(Equal("ping: PONG"))
312				stack = append(stack, "ring.AfterProcessPipeline")
313				return nil
314			},
315		})
316
317		ring.ForEachShard(func(shard *redis.Client) error {
318			shard.AddHook(&hook{
319				beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
320					Expect(cmds).To(HaveLen(3))
321					Expect(cmds[1].String()).To(Equal("ping: "))
322					stack = append(stack, "shard.BeforeProcessPipeline")
323					return ctx, nil
324				},
325				afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
326					Expect(cmds).To(HaveLen(3))
327					Expect(cmds[1].String()).To(Equal("ping: PONG"))
328					stack = append(stack, "shard.AfterProcessPipeline")
329					return nil
330				},
331			})
332			return nil
333		})
334
335		_, err = ring.TxPipelined(func(pipe redis.Pipeliner) error {
336			pipe.Ping()
337			return nil
338		})
339		Expect(err).NotTo(HaveOccurred())
340		Expect(stack).To(Equal([]string{
341			"ring.BeforeProcessPipeline",
342			"shard.BeforeProcessPipeline",
343			"shard.AfterProcessPipeline",
344			"ring.AfterProcessPipeline",
345		}))
346	})
347})
348
349var _ = Describe("empty Redis Ring", func() {
350	var ring *redis.Ring
351
352	BeforeEach(func() {
353		ring = redis.NewRing(&redis.RingOptions{})
354	})
355
356	AfterEach(func() {
357		Expect(ring.Close()).NotTo(HaveOccurred())
358	})
359
360	It("returns an error", func() {
361		err := ring.Ping().Err()
362		Expect(err).To(MatchError("redis: all ring shards are down"))
363	})
364
365	It("pipeline returns an error", func() {
366		_, err := ring.Pipelined(func(pipe redis.Pipeliner) error {
367			pipe.Ping()
368			return nil
369		})
370		Expect(err).To(MatchError("redis: all ring shards are down"))
371	})
372})
373
374var _ = Describe("Ring watch", func() {
375	const heartbeat = 100 * time.Millisecond
376
377	var ring *redis.Ring
378
379	BeforeEach(func() {
380		opt := redisRingOptions()
381		opt.HeartbeatFrequency = heartbeat
382		ring = redis.NewRing(opt)
383
384		err := ring.ForEachShard(func(cl *redis.Client) error {
385			return cl.FlushDB().Err()
386		})
387		Expect(err).NotTo(HaveOccurred())
388	})
389
390	AfterEach(func() {
391		Expect(ring.Close()).NotTo(HaveOccurred())
392	})
393
394	It("should Watch", func() {
395		var incr func(string) error
396
397		// Transactionally increments key using GET and SET commands.
398		incr = func(key string) error {
399			err := ring.Watch(func(tx *redis.Tx) error {
400				n, err := tx.Get(key).Int64()
401				if err != nil && err != redis.Nil {
402					return err
403				}
404
405				_, err = tx.TxPipelined(func(pipe redis.Pipeliner) error {
406					pipe.Set(key, strconv.FormatInt(n+1, 10), 0)
407					return nil
408				})
409				return err
410			}, key)
411			if err == redis.TxFailedErr {
412				return incr(key)
413			}
414			return err
415		}
416
417		var wg sync.WaitGroup
418		for i := 0; i < 100; i++ {
419			wg.Add(1)
420			go func() {
421				defer GinkgoRecover()
422				defer wg.Done()
423
424				err := incr("key")
425				Expect(err).NotTo(HaveOccurred())
426			}()
427		}
428		wg.Wait()
429
430		n, err := ring.Get("key").Int64()
431		Expect(err).NotTo(HaveOccurred())
432		Expect(n).To(Equal(int64(100)))
433	})
434
435	It("should discard", func() {
436		err := ring.Watch(func(tx *redis.Tx) error {
437			cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
438				pipe.Set("key1", "hello1", 0)
439				pipe.Discard()
440				pipe.Set("key2", "hello2", 0)
441				return nil
442			})
443			Expect(err).NotTo(HaveOccurred())
444			Expect(cmds).To(HaveLen(1))
445			return err
446		}, "key1", "key2")
447		Expect(err).NotTo(HaveOccurred())
448
449		get := ring.Get("key1")
450		Expect(get.Err()).To(Equal(redis.Nil))
451		Expect(get.Val()).To(Equal(""))
452
453		get = ring.Get("key2")
454		Expect(get.Err()).NotTo(HaveOccurred())
455		Expect(get.Val()).To(Equal("hello2"))
456	})
457
458	It("returns no error when there are no commands", func() {
459		err := ring.Watch(func(tx *redis.Tx) error {
460			_, err := tx.TxPipelined(func(redis.Pipeliner) error { return nil })
461			return err
462		}, "key")
463		Expect(err).NotTo(HaveOccurred())
464
465		v, err := ring.Ping().Result()
466		Expect(err).NotTo(HaveOccurred())
467		Expect(v).To(Equal("PONG"))
468	})
469
470	It("should exec bulks", func() {
471		const N = 20000
472
473		err := ring.Watch(func(tx *redis.Tx) error {
474			cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
475				for i := 0; i < N; i++ {
476					pipe.Incr("key")
477				}
478				return nil
479			})
480			Expect(err).NotTo(HaveOccurred())
481			Expect(len(cmds)).To(Equal(N))
482			for _, cmd := range cmds {
483				Expect(cmd.Err()).NotTo(HaveOccurred())
484			}
485			return err
486		}, "key")
487		Expect(err).NotTo(HaveOccurred())
488
489		num, err := ring.Get("key").Int64()
490		Expect(err).NotTo(HaveOccurred())
491		Expect(num).To(Equal(int64(N)))
492	})
493
494	It("should Watch/Unwatch", func() {
495		var C, N int
496
497		err := ring.Set("key", "0", 0).Err()
498		Expect(err).NotTo(HaveOccurred())
499
500		perform(C, func(id int) {
501			for i := 0; i < N; i++ {
502				err := ring.Watch(func(tx *redis.Tx) error {
503					val, err := tx.Get("key").Result()
504					Expect(err).NotTo(HaveOccurred())
505					Expect(val).NotTo(Equal(redis.Nil))
506
507					num, err := strconv.ParseInt(val, 10, 64)
508					Expect(err).NotTo(HaveOccurred())
509
510					cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
511						pipe.Set("key", strconv.FormatInt(num+1, 10), 0)
512						return nil
513					})
514					Expect(cmds).To(HaveLen(1))
515					return err
516				}, "key")
517				if err == redis.TxFailedErr {
518					i--
519					continue
520				}
521				Expect(err).NotTo(HaveOccurred())
522			}
523		})
524
525		val, err := ring.Get("key").Int64()
526		Expect(err).NotTo(HaveOccurred())
527		Expect(val).To(Equal(int64(C * N)))
528	})
529
530	It("should close Tx without closing the client", func() {
531		err := ring.Watch(func(tx *redis.Tx) error {
532			_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
533				pipe.Ping()
534				return nil
535			})
536			return err
537		}, "key")
538		Expect(err).NotTo(HaveOccurred())
539
540		Expect(ring.Ping().Err()).NotTo(HaveOccurred())
541	})
542
543	It("respects max size on multi", func() {
544		perform(1000, func(id int) {
545			var ping *redis.StatusCmd
546
547			err := ring.Watch(func(tx *redis.Tx) error {
548				cmds, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
549					ping = pipe.Ping()
550					return nil
551				})
552				Expect(err).NotTo(HaveOccurred())
553				Expect(cmds).To(HaveLen(1))
554				return err
555			}, "key")
556			Expect(err).NotTo(HaveOccurred())
557
558			Expect(ping.Err()).NotTo(HaveOccurred())
559			Expect(ping.Val()).To(Equal("PONG"))
560		})
561
562		ring.ForEachShard(func(cl *redis.Client) error {
563			defer GinkgoRecover()
564
565			pool := cl.Pool()
566			Expect(pool.Len()).To(BeNumerically("<=", 10))
567			Expect(pool.IdleLen()).To(BeNumerically("<=", 10))
568			Expect(pool.Len()).To(Equal(pool.IdleLen()))
569
570			return nil
571		})
572	})
573})
574
575var _ = Describe("Ring Tx timeout", func() {
576	const heartbeat = 100 * time.Millisecond
577
578	var ring *redis.Ring
579
580	AfterEach(func() {
581		_ = ring.Close()
582	})
583
584	testTimeout := func() {
585		It("Tx timeouts", func() {
586			err := ring.Watch(func(tx *redis.Tx) error {
587				return tx.Ping().Err()
588			}, "foo")
589			Expect(err).To(HaveOccurred())
590			Expect(err.(net.Error).Timeout()).To(BeTrue())
591		})
592
593		It("Tx Pipeline timeouts", func() {
594			err := ring.Watch(func(tx *redis.Tx) error {
595				_, err := tx.TxPipelined(func(pipe redis.Pipeliner) error {
596					pipe.Ping()
597					return nil
598				})
599				return err
600			}, "foo")
601			Expect(err).To(HaveOccurred())
602			Expect(err.(net.Error).Timeout()).To(BeTrue())
603		})
604	}
605
606	const pause = 5 * time.Second
607
608	Context("read/write timeout", func() {
609		BeforeEach(func() {
610			opt := redisRingOptions()
611			opt.ReadTimeout = 250 * time.Millisecond
612			opt.WriteTimeout = 250 * time.Millisecond
613			opt.HeartbeatFrequency = heartbeat
614			ring = redis.NewRing(opt)
615
616			err := ring.ForEachShard(func(client *redis.Client) error {
617				return client.ClientPause(pause).Err()
618			})
619			Expect(err).NotTo(HaveOccurred())
620		})
621
622		AfterEach(func() {
623			_ = ring.ForEachShard(func(client *redis.Client) error {
624				defer GinkgoRecover()
625				Eventually(func() error {
626					return client.Ping().Err()
627				}, 2*pause).ShouldNot(HaveOccurred())
628				return nil
629			})
630		})
631
632		testTimeout()
633	})
634})
635