package redis

import (
	"context"
	"fmt"
	"time"

	"github.com/go-redis/redis/v8/internal"
	"github.com/go-redis/redis/v8/internal/pool"
	"github.com/go-redis/redis/v8/internal/proto"
	"go.opentelemetry.io/otel/api/trace"
	"go.opentelemetry.io/otel/label"
)

// Nil is the reply returned by Redis when a key does not exist.
const Nil = proto.Nil
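
// The function below is an illustrative sketch added for documentation, not
// part of the original file: it shows the usual way callers distinguish a
// missing key (Nil) from a real error. The key name and the passed-in *Client
// are placeholder assumptions.
func exampleNilCheck(ctx context.Context, rdb *Client) {
	val, err := rdb.Get(ctx, "missing-key").Result()
	switch {
	case err == Nil:
		fmt.Println("key does not exist")
	case err != nil:
		fmt.Println("unexpected error:", err)
	default:
		fmt.Println("value:", val)
	}
}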

func SetLogger(logger internal.Logging) {
	internal.Logger = logger
}

//------------------------------------------------------------------------------

type Hook interface {
	BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
	AfterProcess(ctx context.Context, cmd Cmder) error

	BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
	AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
}
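
// The hook below is an illustrative sketch added for documentation, not part
// of the original file: a minimal Hook that logs each finished command. It
// would be registered with client.AddHook(exampleLoggingHook{}); the type name
// and the log output are assumptions.
type exampleLoggingHook struct{}

func (exampleLoggingHook) BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
	return ctx, nil
}

func (exampleLoggingHook) AfterProcess(ctx context.Context, cmd Cmder) error {
	// cmd.Err is set by the time AfterProcess runs; nil means success.
	fmt.Println("redis command finished:", cmd.Name(), cmd.Err())
	return nil
}

func (exampleLoggingHook) BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
	return ctx, nil
}

func (exampleLoggingHook) AfterProcessPipeline(ctx context.Context, cmds []Cmder) error {
	fmt.Println("redis pipeline finished:", len(cmds), "commands")
	return nil
}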

type hooks struct {
	hooks []Hook
}

func (hs *hooks) lock() {
	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}

func (hs hooks) clone() hooks {
	clone := hs
	clone.lock()
	return clone
}

func (hs *hooks) AddHook(hook Hook) {
	hs.hooks = append(hs.hooks, hook)
}

func (hs hooks) process(
	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		err := hs.withContext(ctx, func() error {
			return fn(ctx, cmd)
		})
		cmd.SetErr(err)
		return err
	}

	var hookIndex int
	var retErr error

	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
		if retErr != nil {
			cmd.SetErr(retErr)
		}
	}

	if retErr == nil {
		retErr = hs.withContext(ctx, func() error {
			return fn(ctx, cmd)
		})
		cmd.SetErr(retErr)
	}

	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
			retErr = err
			cmd.SetErr(retErr)
		}
	}

	return retErr
}

func (hs hooks) processPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		err := hs.withContext(ctx, func() error {
			return fn(ctx, cmds)
		})
		return err
	}

	var hookIndex int
	var retErr error

	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
		if retErr != nil {
			setCmdsErr(cmds, retErr)
		}
	}

	if retErr == nil {
		retErr = hs.withContext(ctx, func() error {
			return fn(ctx, cmds)
		})
	}

	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
			retErr = err
			setCmdsErr(cmds, retErr)
		}
	}

	return retErr
}

func (hs hooks) processTxPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	cmds = wrapMultiExec(ctx, cmds)
	return hs.processPipeline(ctx, cmds, fn)
}

func (hs hooks) withContext(ctx context.Context, fn func() error) error {
	done := ctx.Done()
	if done == nil {
		return fn()
	}

	errc := make(chan error, 1)
	go func() { errc <- fn() }()

	select {
	case <-done:
		return ctx.Err()
	case err := <-errc:
		return err
	}
}

//------------------------------------------------------------------------------

type baseClient struct {
	opt      *Options
	connPool pool.Pooler

	onClose func() error // hook called when client is closed
}

func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
	return &baseClient{
		opt:      opt,
		connPool: connPool,
	}
}

func (c *baseClient) clone() *baseClient {
	clone := *c
	return &clone
}

func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
	opt := c.opt.clone()
	opt.ReadTimeout = timeout
	opt.WriteTimeout = timeout

	clone := c.clone()
	clone.opt = opt

	return clone
}

func (c *baseClient) String() string {
	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}

func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
	cn, err := c.connPool.NewConn(ctx)
	if err != nil {
		return nil, err
	}

	err = c.initConn(ctx, cn)
	if err != nil {
		_ = c.connPool.CloseConn(cn)
		return nil, err
	}

	return cn, nil
}

func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
	if c.opt.Limiter != nil {
		err := c.opt.Limiter.Allow()
		if err != nil {
			return nil, err
		}
	}

	cn, err := c._getConn(ctx)
	if err != nil {
		if c.opt.Limiter != nil {
			c.opt.Limiter.ReportResult(err)
		}
		return nil, err
	}

	return cn, nil
}

func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
	cn, err := c.connPool.Get(ctx)
	if err != nil {
		return nil, err
	}

	if cn.Inited {
		return cn, nil
	}

	err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context, span trace.Span) error {
		return c.initConn(ctx, cn)
	})
	if err != nil {
		c.connPool.Remove(ctx, cn, err)
		if err := internal.Unwrap(err); err != nil {
			return nil, err
		}
		return nil, err
	}

	return cn, nil
}

func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
	if cn.Inited {
		return nil
	}
	cn.Inited = true

	if c.opt.Password == "" &&
		c.opt.DB == 0 &&
		!c.opt.readOnly &&
		c.opt.OnConnect == nil {
		return nil
	}

	connPool := pool.NewSingleConnPool(c.connPool, cn)
	conn := newConn(ctx, c.opt, connPool)

	_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
		if c.opt.Password != "" {
			if c.opt.Username != "" {
				pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
			} else {
				pipe.Auth(ctx, c.opt.Password)
			}
		}

		if c.opt.DB > 0 {
			pipe.Select(ctx, c.opt.DB)
		}

		if c.opt.readOnly {
			pipe.ReadOnly(ctx)
		}

		return nil
	})
	if err != nil {
		return err
	}

	if c.opt.OnConnect != nil {
		return c.opt.OnConnect(ctx, conn)
	}
	return nil
}

func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
	if c.opt.Limiter != nil {
		c.opt.Limiter.ReportResult(err)
	}

	if isBadConn(err, false) {
		c.connPool.Remove(ctx, cn, err)
	} else {
		c.connPool.Put(ctx, cn)
	}
}

func (c *baseClient) withConn(
	ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
	return internal.WithSpan(ctx, "with_conn", func(ctx context.Context, span trace.Span) error {
		cn, err := c.getConn(ctx)
		if err != nil {
			return err
		}

		if span.IsRecording() {
			if remoteAddr := cn.RemoteAddr(); remoteAddr != nil {
				span.SetAttributes(label.String("net.peer.ip", remoteAddr.String()))
			}
		}

		defer func() {
			c.releaseConn(ctx, cn, err)
		}()

		err = fn(ctx, cn)
		return err
	})
}

func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		attempt := attempt

		var retry bool
		err := internal.WithSpan(ctx, "process", func(ctx context.Context, span trace.Span) error {
			if attempt > 0 {
				if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
					return err
				}
			}

			retryTimeout := true
			err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
				err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
					return writeCmd(wr, cmd)
				})
				if err != nil {
					return err
				}

				err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
				if err != nil {
					retryTimeout = cmd.readTimeout() == nil
					return err
				}

				return nil
			})
			if err == nil {
				return nil
			}
			retry = shouldRetry(err, retryTimeout)
			return err
		})
		if err == nil || !retry {
			return err
		}
		lastErr = err
	}
	return lastErr
}

func (c *baseClient) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
	if timeout := cmd.readTimeout(); timeout != nil {
		t := *timeout
		if t == 0 {
			return 0
		}
		return t + 10*time.Second
	}
	return c.opt.ReadTimeout
}

// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
	var firstErr error
	if c.onClose != nil {
		if err := c.onClose(); err != nil {
			firstErr = err
		}
	}
	if err := c.connPool.Close(); err != nil && firstErr == nil {
		firstErr = err
	}
	return firstErr
}

func (c *baseClient) getAddr() string {
	return c.opt.Addr
}

func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
}

func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}

type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)

func (c *baseClient) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
	err := c._generalProcessPipeline(ctx, cmds, p)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}
	return cmdsFirstErr(cmds)
}

func (c *baseClient) _generalProcessPipeline(
	ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		var canRetry bool
		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
			var err error
			canRetry, err = p(ctx, cn, cmds)
			return err
		})
		if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
			return lastErr
		}
	}
	return lastErr
}

func (c *baseClient) pipelineProcessCmds(
	ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmds(wr, cmds)
	})
	if err != nil {
		return true, err
	}

	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
		return pipelineReadCmds(rd, cmds)
	})
	return true, err
}

func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
	for _, cmd := range cmds {
		err := cmd.readReply(rd)
		cmd.SetErr(err)
		if err != nil && !isRedisError(err) {
			return err
		}
	}
	return nil
}

func (c *baseClient) txPipelineProcessCmds(
	ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmds(wr, cmds)
	})
	if err != nil {
		return true, err
	}

	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
		statusCmd := cmds[0].(*StatusCmd)
		// Trim multi and exec.
		cmds = cmds[1 : len(cmds)-1]

		err := txPipelineReadQueued(rd, statusCmd, cmds)
		if err != nil {
			return err
		}

		return pipelineReadCmds(rd, cmds)
	})
	return false, err
}

func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
	if len(cmds) == 0 {
		panic("not reached")
	}
	cmdCopy := make([]Cmder, len(cmds)+2)
	cmdCopy[0] = NewStatusCmd(ctx, "multi")
	copy(cmdCopy[1:], cmds)
	cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
	return cmdCopy
}

func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
	// Parse queued replies.
	if err := statusCmd.readReply(rd); err != nil {
		return err
	}

	for range cmds {
		if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
			return err
		}
	}

	// Parse number of replies.
	line, err := rd.ReadLine()
	if err != nil {
		if err == Nil {
			err = TxFailedErr
		}
		return err
	}

	switch line[0] {
	case proto.ErrorReply:
		return proto.ParseErrorReply(line)
	case proto.ArrayReply:
		// ok
	default:
		err := fmt.Errorf("redis: expected '*', but got line %q", line)
		return err
	}

	return nil
}

//------------------------------------------------------------------------------

// Client is a Redis client representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
type Client struct {
	*baseClient
	cmdable
	hooks
	ctx context.Context
}

// NewClient returns a client to the Redis Server specified by Options.
func NewClient(opt *Options) *Client {
	opt.init()

	c := Client{
		baseClient: newBaseClient(opt, newConnPool(opt)),
		ctx:        context.Background(),
	}
	c.cmdable = c.Process

	return &c
}
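
// The function below is an illustrative sketch added for documentation, not
// part of the original file: a typical way to construct a client and verify
// connectivity. The address, password and DB values are placeholder
// assumptions.
func exampleNewClient(ctx context.Context) (*Client, error) {
	rdb := NewClient(&Options{
		Addr:     "localhost:6379",
		Password: "", // no password set
		DB:       0,  // default DB
	})
	// PING confirms the server is reachable before the client is shared.
	if err := rdb.Ping(ctx).Err(); err != nil {
		return nil, err
	}
	return rdb, nil
}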

func (c *Client) clone() *Client {
	clone := *c
	clone.cmdable = clone.Process
	clone.hooks.lock()
	return &clone
}

func (c *Client) WithTimeout(timeout time.Duration) *Client {
	clone := c.clone()
	clone.baseClient = c.baseClient.withTimeout(timeout)
	return clone
}

func (c *Client) Context() context.Context {
	return c.ctx
}

func (c *Client) WithContext(ctx context.Context) *Client {
	if ctx == nil {
		panic("nil context")
	}
	clone := c.clone()
	clone.ctx = ctx
	return clone
}

func (c *Client) Conn(ctx context.Context) *Conn {
	return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
}
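
// The function below is an illustrative sketch added for documentation, not
// part of the original file: Conn pins a single connection, which is required
// for connection-scoped state such as CLIENT SETNAME. The connection name is
// a placeholder assumption.
func exampleConn(ctx context.Context, rdb *Client) error {
	conn := rdb.Conn(ctx)
	defer conn.Close()

	if err := conn.ClientSetName(ctx, "my-conn").Err(); err != nil {
		return err
	}
	name, err := conn.ClientGetName(ctx).Result()
	if err != nil {
		return err
	}
	fmt.Println("connection name:", name)
	return nil
}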

// Do creates a Cmd from the args and processes the cmd.
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(ctx, args...)
	_ = c.Process(ctx, cmd)
	return cmd
}
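
// The function below is an illustrative sketch added for documentation, not
// part of the original file: Do is handy for commands that have no typed
// helper. The command and key name are placeholder assumptions.
func exampleDo(ctx context.Context, rdb *Client) {
	// The reply of an arbitrary command is available as an interface{}.
	val, err := rdb.Do(ctx, "object", "encoding", "some-key").Result()
	if err == Nil {
		fmt.Println("key does not exist")
		return
	}
	if err != nil {
		fmt.Println("command failed:", err)
		return
	}
	fmt.Println("encoding:", val)
}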

func (c *Client) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}

func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}

func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}

// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
	return c.opt
}

type PoolStats pool.Stats

// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
	stats := c.connPool.Stats()
	return (*PoolStats)(stats)
}

func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Client) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}
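
// The function below is an illustrative sketch added for documentation, not
// part of the original file: it batches several commands into a single round
// trip with Pipelined. The key name and TTL are placeholder assumptions.
func examplePipeline(ctx context.Context, rdb *Client) error {
	var incr *IntCmd
	_, err := rdb.Pipelined(ctx, func(pipe Pipeliner) error {
		incr = pipe.Incr(ctx, "pipeline_counter")
		pipe.Expire(ctx, "pipeline_counter", time.Hour)
		return nil
	})
	if err != nil {
		return err
	}
	// Queued command results are populated only after Pipelined returns.
	fmt.Println("counter is now", incr.Val())
	return nil
}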

func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Client) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}
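
// The function below is an illustrative sketch added for documentation, not
// part of the original file: the same batching as Pipelined, but wrapped in
// MULTI/EXEC so the queued commands run atomically on the server. The key name
// is a placeholder assumption.
func exampleTxPipeline(ctx context.Context, rdb *Client) error {
	_, err := rdb.TxPipelined(ctx, func(pipe Pipeliner) error {
		pipe.Incr(ctx, "tx_counter")
		pipe.Expire(ctx, "tx_counter", time.Hour)
		return nil
	})
	return err
}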

func (c *Client) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,

		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
			return c.newConn(ctx)
		},
		closeConn: c.connPool.CloseConn,
	}
	pubsub.init()
	return pubsub
}

// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create an empty subscription.
// Note that this method does not wait on a response from Redis, so the
// subscription may not be active immediately. To force the connection to wait,
// you may call the Receive() method on the returned *PubSub like so:
//
//    sub := client.Subscribe(ctx, "mychannel")
//    iface, err := sub.Receive(ctx)
//    if err != nil {
//        // handle error
//    }
//
//    // Should be a *Subscription, but others are possible if other actions
//    // have been taken on sub since it was created.
//    switch iface.(type) {
//    case *Subscription:
//        // subscribe succeeded
//    case *Message:
//        // received first message
//    case *Pong:
//        // pong received
//    default:
//        // handle error
//    }
//
//    ch := sub.Channel()
func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.Subscribe(ctx, channels...)
	}
	return pubsub
}
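
// The function below is an illustrative sketch added for documentation, not
// part of the original file: it consumes messages from a subscription through
// Channel(). The channel name is a placeholder assumption; Close unsubscribes
// and releases the connection.
func exampleSubscribe(ctx context.Context, rdb *Client) {
	pubsub := rdb.Subscribe(ctx, "mychannel")
	defer pubsub.Close()

	// The loop ends when the PubSub is closed and its channel is drained.
	for msg := range pubsub.Channel() {
		fmt.Println(msg.Channel, msg.Payload)
	}
}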

// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create an empty subscription.
func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.PSubscribe(ctx, channels...)
	}
	return pubsub
}

//------------------------------------------------------------------------------

type conn struct {
	baseClient
	cmdable
	statefulCmdable
}

// Conn is like Client, but its pool contains a single connection.
type Conn struct {
	*conn
	ctx context.Context
}

func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
	c := Conn{
		conn: &conn{
			baseClient: baseClient{
				opt:      opt,
				connPool: connPool,
			},
		},
		ctx: ctx,
	}
	c.cmdable = c.Process
	c.statefulCmdable = c.Process
	return &c
}

func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
	return c.baseClient.process(ctx, cmd)
}

func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Conn) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Conn) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}