1package redis
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"sync/atomic"
8	"time"
9
10	"github.com/go-redis/redis/v8/internal"
11	"github.com/go-redis/redis/v8/internal/pool"
12	"github.com/go-redis/redis/v8/internal/proto"
13	"go.opentelemetry.io/otel/attribute"
14	"go.opentelemetry.io/otel/trace"
15)
16
17// Nil reply returned by Redis when key does not exist.
18const Nil = proto.Nil
19
20func SetLogger(logger internal.Logging) {
21	internal.Logger = logger
22}
23
24//------------------------------------------------------------------------------
25
26type Hook interface {
27	BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
28	AfterProcess(ctx context.Context, cmd Cmder) error
29
30	BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
31	AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
32}
33
34type hooks struct {
35	hooks []Hook
36}
37
38func (hs *hooks) lock() {
39	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
40}
41
42func (hs hooks) clone() hooks {
43	clone := hs
44	clone.lock()
45	return clone
46}
47
48func (hs *hooks) AddHook(hook Hook) {
49	hs.hooks = append(hs.hooks, hook)
50}
51
52func (hs hooks) process(
53	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
54) error {
55	if len(hs.hooks) == 0 {
56		err := hs.withContext(ctx, func() error {
57			return fn(ctx, cmd)
58		})
59		cmd.SetErr(err)
60		return err
61	}
62
63	var hookIndex int
64	var retErr error
65
66	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
67		ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
68		if retErr != nil {
69			cmd.SetErr(retErr)
70		}
71	}
72
73	if retErr == nil {
74		retErr = hs.withContext(ctx, func() error {
75			return fn(ctx, cmd)
76		})
77		cmd.SetErr(retErr)
78	}
79
80	for hookIndex--; hookIndex >= 0; hookIndex-- {
81		if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
82			retErr = err
83			cmd.SetErr(retErr)
84		}
85	}
86
87	return retErr
88}
89
90func (hs hooks) processPipeline(
91	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
92) error {
93	if len(hs.hooks) == 0 {
94		err := hs.withContext(ctx, func() error {
95			return fn(ctx, cmds)
96		})
97		return err
98	}
99
100	var hookIndex int
101	var retErr error
102
103	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
104		ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
105		if retErr != nil {
106			setCmdsErr(cmds, retErr)
107		}
108	}
109
110	if retErr == nil {
111		retErr = hs.withContext(ctx, func() error {
112			return fn(ctx, cmds)
113		})
114	}
115
116	for hookIndex--; hookIndex >= 0; hookIndex-- {
117		if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
118			retErr = err
119			setCmdsErr(cmds, retErr)
120		}
121	}
122
123	return retErr
124}
125
126func (hs hooks) processTxPipeline(
127	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
128) error {
129	cmds = wrapMultiExec(ctx, cmds)
130	return hs.processPipeline(ctx, cmds, fn)
131}
132
133func (hs hooks) withContext(ctx context.Context, fn func() error) error {
134	return fn()
135}
136
137//------------------------------------------------------------------------------
138
139type baseClient struct {
140	opt      *Options
141	connPool pool.Pooler
142
143	onClose func() error // hook called when client is closed
144}
145
146func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
147	return &baseClient{
148		opt:      opt,
149		connPool: connPool,
150	}
151}
152
153func (c *baseClient) clone() *baseClient {
154	clone := *c
155	return &clone
156}
157
158func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
159	opt := c.opt.clone()
160	opt.ReadTimeout = timeout
161	opt.WriteTimeout = timeout
162
163	clone := c.clone()
164	clone.opt = opt
165
166	return clone
167}
168
169func (c *baseClient) String() string {
170	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
171}
172
173func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
174	cn, err := c.connPool.NewConn(ctx)
175	if err != nil {
176		return nil, err
177	}
178
179	err = c.initConn(ctx, cn)
180	if err != nil {
181		_ = c.connPool.CloseConn(cn)
182		return nil, err
183	}
184
185	return cn, nil
186}
187
188func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
189	if c.opt.Limiter != nil {
190		err := c.opt.Limiter.Allow()
191		if err != nil {
192			return nil, err
193		}
194	}
195
196	cn, err := c._getConn(ctx)
197	if err != nil {
198		if c.opt.Limiter != nil {
199			c.opt.Limiter.ReportResult(err)
200		}
201		return nil, err
202	}
203
204	return cn, nil
205}
206
207func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
208	cn, err := c.connPool.Get(ctx)
209	if err != nil {
210		return nil, err
211	}
212
213	if cn.Inited {
214		return cn, nil
215	}
216
217	err = internal.WithSpan(ctx, "redis.init_conn", func(ctx context.Context, span trace.Span) error {
218		return c.initConn(ctx, cn)
219	})
220	if err != nil {
221		c.connPool.Remove(ctx, cn, err)
222		if err := errors.Unwrap(err); err != nil {
223			return nil, err
224		}
225		return nil, err
226	}
227
228	return cn, nil
229}
230
231func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
232	if cn.Inited {
233		return nil
234	}
235	cn.Inited = true
236
237	if c.opt.Password == "" &&
238		c.opt.DB == 0 &&
239		!c.opt.readOnly &&
240		c.opt.OnConnect == nil {
241		return nil
242	}
243
244	connPool := pool.NewSingleConnPool(c.connPool, cn)
245	conn := newConn(ctx, c.opt, connPool)
246
247	_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
248		if c.opt.Password != "" {
249			if c.opt.Username != "" {
250				pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
251			} else {
252				pipe.Auth(ctx, c.opt.Password)
253			}
254		}
255
256		if c.opt.DB > 0 {
257			pipe.Select(ctx, c.opt.DB)
258		}
259
260		if c.opt.readOnly {
261			pipe.ReadOnly(ctx)
262		}
263
264		return nil
265	})
266	if err != nil {
267		return err
268	}
269
270	if c.opt.OnConnect != nil {
271		return c.opt.OnConnect(ctx, conn)
272	}
273	return nil
274}
275
276func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
277	if c.opt.Limiter != nil {
278		c.opt.Limiter.ReportResult(err)
279	}
280
281	if isBadConn(err, false) {
282		c.connPool.Remove(ctx, cn, err)
283	} else {
284		c.connPool.Put(ctx, cn)
285	}
286}
287
288func (c *baseClient) withConn(
289	ctx context.Context, fn func(context.Context, *pool.Conn) error,
290) error {
291	return internal.WithSpan(ctx, "redis.with_conn", func(ctx context.Context, span trace.Span) error {
292		cn, err := c.getConn(ctx)
293		if err != nil {
294			return err
295		}
296
297		if span.IsRecording() {
298			if remoteAddr := cn.RemoteAddr(); remoteAddr != nil {
299				span.SetAttributes(attribute.String("net.peer.ip", remoteAddr.String()))
300			}
301		}
302
303		defer func() {
304			c.releaseConn(ctx, cn, err)
305		}()
306
307		done := ctx.Done()
308		if done == nil {
309			err = fn(ctx, cn)
310			return err
311		}
312
313		errc := make(chan error, 1)
314		go func() { errc <- fn(ctx, cn) }()
315
316		select {
317		case <-done:
318			_ = cn.Close()
319			// Wait for the goroutine to finish and send something.
320			<-errc
321
322			err = ctx.Err()
323			return err
324		case err = <-errc:
325			return err
326		}
327	})
328}
329
330func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
331	var lastErr error
332	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
333		attempt := attempt
334
335		var retry bool
336		err := internal.WithSpan(ctx, "redis.process", func(ctx context.Context, span trace.Span) error {
337			if attempt > 0 {
338				if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
339					return err
340				}
341			}
342
343			retryTimeout := uint32(1)
344			err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
345				err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
346					return writeCmd(wr, cmd)
347				})
348				if err != nil {
349					return err
350				}
351
352				err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
353				if err != nil {
354					if cmd.readTimeout() == nil {
355						atomic.StoreUint32(&retryTimeout, 1)
356					}
357					return err
358				}
359
360				return nil
361			})
362			if err == nil {
363				return nil
364			}
365			retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
366			return err
367		})
368		if err == nil || !retry {
369			return err
370		}
371		lastErr = err
372	}
373	return lastErr
374}
375
376func (c *baseClient) retryBackoff(attempt int) time.Duration {
377	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
378}
379
380func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
381	if timeout := cmd.readTimeout(); timeout != nil {
382		t := *timeout
383		if t == 0 {
384			return 0
385		}
386		return t + 10*time.Second
387	}
388	return c.opt.ReadTimeout
389}
390
391// Close closes the client, releasing any open resources.
392//
393// It is rare to Close a Client, as the Client is meant to be
394// long-lived and shared between many goroutines.
395func (c *baseClient) Close() error {
396	var firstErr error
397	if c.onClose != nil {
398		if err := c.onClose(); err != nil {
399			firstErr = err
400		}
401	}
402	if err := c.connPool.Close(); err != nil && firstErr == nil {
403		firstErr = err
404	}
405	return firstErr
406}
407
408func (c *baseClient) getAddr() string {
409	return c.opt.Addr
410}
411
412func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
413	return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
414}
415
416func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
417	return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
418}
419
420type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
421
422func (c *baseClient) generalProcessPipeline(
423	ctx context.Context, cmds []Cmder, p pipelineProcessor,
424) error {
425	err := c._generalProcessPipeline(ctx, cmds, p)
426	if err != nil {
427		setCmdsErr(cmds, err)
428		return err
429	}
430	return cmdsFirstErr(cmds)
431}
432
433func (c *baseClient) _generalProcessPipeline(
434	ctx context.Context, cmds []Cmder, p pipelineProcessor,
435) error {
436	var lastErr error
437	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
438		if attempt > 0 {
439			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
440				return err
441			}
442		}
443
444		var canRetry bool
445		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
446			var err error
447			canRetry, err = p(ctx, cn, cmds)
448			return err
449		})
450		if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
451			return lastErr
452		}
453	}
454	return lastErr
455}
456
457func (c *baseClient) pipelineProcessCmds(
458	ctx context.Context, cn *pool.Conn, cmds []Cmder,
459) (bool, error) {
460	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
461		return writeCmds(wr, cmds)
462	})
463	if err != nil {
464		return true, err
465	}
466
467	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
468		return pipelineReadCmds(rd, cmds)
469	})
470	return true, err
471}
472
473func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
474	for _, cmd := range cmds {
475		err := cmd.readReply(rd)
476		cmd.SetErr(err)
477		if err != nil && !isRedisError(err) {
478			return err
479		}
480	}
481	return nil
482}
483
484func (c *baseClient) txPipelineProcessCmds(
485	ctx context.Context, cn *pool.Conn, cmds []Cmder,
486) (bool, error) {
487	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
488		return writeCmds(wr, cmds)
489	})
490	if err != nil {
491		return true, err
492	}
493
494	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
495		statusCmd := cmds[0].(*StatusCmd)
496		// Trim multi and exec.
497		cmds = cmds[1 : len(cmds)-1]
498
499		err := txPipelineReadQueued(rd, statusCmd, cmds)
500		if err != nil {
501			return err
502		}
503
504		return pipelineReadCmds(rd, cmds)
505	})
506	return false, err
507}
508
509func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
510	if len(cmds) == 0 {
511		panic("not reached")
512	}
513	cmdCopy := make([]Cmder, len(cmds)+2)
514	cmdCopy[0] = NewStatusCmd(ctx, "multi")
515	copy(cmdCopy[1:], cmds)
516	cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
517	return cmdCopy
518}
519
520func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
521	// Parse queued replies.
522	if err := statusCmd.readReply(rd); err != nil {
523		return err
524	}
525
526	for range cmds {
527		if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
528			return err
529		}
530	}
531
532	// Parse number of replies.
533	line, err := rd.ReadLine()
534	if err != nil {
535		if err == Nil {
536			err = TxFailedErr
537		}
538		return err
539	}
540
541	switch line[0] {
542	case proto.ErrorReply:
543		return proto.ParseErrorReply(line)
544	case proto.ArrayReply:
545		// ok
546	default:
547		err := fmt.Errorf("redis: expected '*', but got line %q", line)
548		return err
549	}
550
551	return nil
552}
553
554//------------------------------------------------------------------------------
555
556// Client is a Redis client representing a pool of zero or more
557// underlying connections. It's safe for concurrent use by multiple
558// goroutines.
559type Client struct {
560	*baseClient
561	cmdable
562	hooks
563	ctx context.Context
564}
565
566// NewClient returns a client to the Redis Server specified by Options.
567func NewClient(opt *Options) *Client {
568	opt.init()
569
570	c := Client{
571		baseClient: newBaseClient(opt, newConnPool(opt)),
572		ctx:        context.Background(),
573	}
574	c.cmdable = c.Process
575
576	return &c
577}
578
579func (c *Client) clone() *Client {
580	clone := *c
581	clone.cmdable = clone.Process
582	clone.hooks.lock()
583	return &clone
584}
585
586func (c *Client) WithTimeout(timeout time.Duration) *Client {
587	clone := c.clone()
588	clone.baseClient = c.baseClient.withTimeout(timeout)
589	return clone
590}
591
592func (c *Client) Context() context.Context {
593	return c.ctx
594}
595
596func (c *Client) WithContext(ctx context.Context) *Client {
597	if ctx == nil {
598		panic("nil context")
599	}
600	clone := c.clone()
601	clone.ctx = ctx
602	return clone
603}
604
605func (c *Client) Conn(ctx context.Context) *Conn {
606	return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
607}
608
609// Do creates a Cmd from the args and processes the cmd.
610func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
611	cmd := NewCmd(ctx, args...)
612	_ = c.Process(ctx, cmd)
613	return cmd
614}
615
616func (c *Client) Process(ctx context.Context, cmd Cmder) error {
617	return c.hooks.process(ctx, cmd, c.baseClient.process)
618}
619
620func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
621	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
622}
623
624func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
625	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
626}
627
628// Options returns read-only Options that were used to create the client.
629func (c *Client) Options() *Options {
630	return c.opt
631}
632
633type PoolStats pool.Stats
634
635// PoolStats returns connection pool stats.
636func (c *Client) PoolStats() *PoolStats {
637	stats := c.connPool.Stats()
638	return (*PoolStats)(stats)
639}
640
641func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
642	return c.Pipeline().Pipelined(ctx, fn)
643}
644
645func (c *Client) Pipeline() Pipeliner {
646	pipe := Pipeline{
647		ctx:  c.ctx,
648		exec: c.processPipeline,
649	}
650	pipe.init()
651	return &pipe
652}
653
654func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
655	return c.TxPipeline().Pipelined(ctx, fn)
656}
657
658// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
659func (c *Client) TxPipeline() Pipeliner {
660	pipe := Pipeline{
661		ctx:  c.ctx,
662		exec: c.processTxPipeline,
663	}
664	pipe.init()
665	return &pipe
666}
667
668func (c *Client) pubSub() *PubSub {
669	pubsub := &PubSub{
670		opt: c.opt,
671
672		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
673			return c.newConn(ctx)
674		},
675		closeConn: c.connPool.CloseConn,
676	}
677	pubsub.init()
678	return pubsub
679}
680
681// Subscribe subscribes the client to the specified channels.
682// Channels can be omitted to create empty subscription.
683// Note that this method does not wait on a response from Redis, so the
684// subscription may not be active immediately. To force the connection to wait,
685// you may call the Receive() method on the returned *PubSub like so:
686//
687//    sub := client.Subscribe(queryResp)
688//    iface, err := sub.Receive()
689//    if err != nil {
690//        // handle error
691//    }
692//
693//    // Should be *Subscription, but others are possible if other actions have been
694//    // taken on sub since it was created.
695//    switch iface.(type) {
696//    case *Subscription:
697//        // subscribe succeeded
698//    case *Message:
699//        // received first message
700//    case *Pong:
701//        // pong received
702//    default:
703//        // handle error
704//    }
705//
706//    ch := sub.Channel()
707func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
708	pubsub := c.pubSub()
709	if len(channels) > 0 {
710		_ = pubsub.Subscribe(ctx, channels...)
711	}
712	return pubsub
713}
714
715// PSubscribe subscribes the client to the given patterns.
716// Patterns can be omitted to create empty subscription.
717func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
718	pubsub := c.pubSub()
719	if len(channels) > 0 {
720		_ = pubsub.PSubscribe(ctx, channels...)
721	}
722	return pubsub
723}
724
725//------------------------------------------------------------------------------
726
727type conn struct {
728	baseClient
729	cmdable
730	statefulCmdable
731	hooks // TODO: inherit hooks
732}
733
734// Conn is like Client, but its pool contains single connection.
735type Conn struct {
736	*conn
737	ctx context.Context
738}
739
740func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
741	c := Conn{
742		conn: &conn{
743			baseClient: baseClient{
744				opt:      opt,
745				connPool: connPool,
746			},
747		},
748		ctx: ctx,
749	}
750	c.cmdable = c.Process
751	c.statefulCmdable = c.Process
752	return &c
753}
754
755func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
756	return c.hooks.process(ctx, cmd, c.baseClient.process)
757}
758
759func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
760	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
761}
762
763func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
764	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
765}
766
767func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
768	return c.Pipeline().Pipelined(ctx, fn)
769}
770
771func (c *Conn) Pipeline() Pipeliner {
772	pipe := Pipeline{
773		ctx:  c.ctx,
774		exec: c.processPipeline,
775	}
776	pipe.init()
777	return &pipe
778}
779
780func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
781	return c.TxPipeline().Pipelined(ctx, fn)
782}
783
784// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
785func (c *Conn) TxPipeline() Pipeliner {
786	pipe := Pipeline{
787		ctx:  c.ctx,
788		exec: c.processTxPipeline,
789	}
790	pipe.init()
791	return &pipe
792}
793