1package redis
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"sync/atomic"
8	"time"
9
10	"github.com/go-redis/redis/v8/internal"
11	"github.com/go-redis/redis/v8/internal/pool"
12	"github.com/go-redis/redis/v8/internal/proto"
13	"go.opentelemetry.io/otel/attribute"
14)
15
// Nil is the reply returned by Redis when a key does not exist.
const Nil = proto.Nil
18
// SetLogger replaces the package-wide logger used by internal code.
func SetLogger(logger internal.Logging) {
	internal.Logger = logger
}
22
23//------------------------------------------------------------------------------
24
// Hook intercepts command and pipeline processing. Before* hooks run prior
// to execution and may derive a new context; After* hooks run once the
// command (or batch) has completed.
type Hook interface {
	BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
	AfterProcess(ctx context.Context, cmd Cmder) error

	BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
	AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
}
32
// hooks holds the ordered Hook chain embedded by Client and Conn.
type hooks struct {
	hooks []Hook
}
36
// lock caps the slice's capacity at its length (three-index slice) so that
// a subsequent AddHook on a clone appends into fresh backing storage
// instead of mutating the hook list shared with the original.
func (hs *hooks) lock() {
	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}
40
41func (hs hooks) clone() hooks {
42	clone := hs
43	clone.lock()
44	return clone
45}
46
// AddHook appends hook to the end of the hook chain.
func (hs *hooks) AddHook(hook Hook) {
	hs.hooks = append(hs.hooks, hook)
}
50
// process runs cmd through the hook chain: BeforeProcess hooks in order,
// then fn (the actual execution), then AfterProcess hooks in reverse order.
// AfterProcess runs only for hooks whose BeforeProcess was invoked.
func (hs hooks) process(
	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		// Fast path: no hooks registered.
		err := hs.withContext(ctx, func() error {
			return fn(ctx, cmd)
		})
		cmd.SetErr(err)
		return err
	}

	var hookIndex int
	var retErr error

	// Run Before hooks until one fails; hookIndex ends one past the last
	// hook whose BeforeProcess was invoked.
	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
		if retErr != nil {
			cmd.SetErr(retErr)
		}
	}

	if retErr == nil {
		retErr = hs.withContext(ctx, func() error {
			return fn(ctx, cmd)
		})
		cmd.SetErr(retErr)
	}

	// Unwind After hooks in reverse order; the last error seen wins.
	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
			retErr = err
			cmd.SetErr(retErr)
		}
	}

	return retErr
}
88
// processPipeline runs a command batch through the pipeline hook chain:
// BeforeProcessPipeline hooks in order, then fn, then AfterProcessPipeline
// hooks in reverse order for every hook whose Before hook was invoked.
func (hs hooks) processPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		// Fast path: no hooks registered.
		err := hs.withContext(ctx, func() error {
			return fn(ctx, cmds)
		})
		return err
	}

	var hookIndex int
	var retErr error

	// Run Before hooks until one fails; hookIndex ends one past the last
	// hook whose Before hook was invoked.
	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
		if retErr != nil {
			setCmdsErr(cmds, retErr)
		}
	}

	if retErr == nil {
		retErr = hs.withContext(ctx, func() error {
			return fn(ctx, cmds)
		})
	}

	// Unwind After hooks in reverse order; the last error seen wins.
	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
			retErr = err
			setCmdsErr(cmds, retErr)
		}
	}

	return retErr
}
124
125func (hs hooks) processTxPipeline(
126	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
127) error {
128	cmds = wrapMultiExec(ctx, cmds)
129	return hs.processPipeline(ctx, cmds, fn)
130}
131
// withContext invokes fn; the ctx parameter is currently unused.
func (hs hooks) withContext(ctx context.Context, fn func() error) error {
	return fn()
}
135
136//------------------------------------------------------------------------------
137
// baseClient implements connection management and command execution shared
// by Client and Conn.
type baseClient struct {
	opt      *Options
	connPool pool.Pooler

	onClose func() error // hook called when client is closed
}
144
145func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
146	return &baseClient{
147		opt:      opt,
148		connPool: connPool,
149	}
150}
151
152func (c *baseClient) clone() *baseClient {
153	clone := *c
154	return &clone
155}
156
157func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
158	opt := c.opt.clone()
159	opt.ReadTimeout = timeout
160	opt.WriteTimeout = timeout
161
162	clone := c.clone()
163	clone.opt = opt
164
165	return clone
166}
167
// String implements fmt.Stringer, reporting the server address and DB.
func (c *baseClient) String() string {
	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
171
172func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
173	cn, err := c.connPool.NewConn(ctx)
174	if err != nil {
175		return nil, err
176	}
177
178	err = c.initConn(ctx, cn)
179	if err != nil {
180		_ = c.connPool.CloseConn(cn)
181		return nil, err
182	}
183
184	return cn, nil
185}
186
187func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
188	if c.opt.Limiter != nil {
189		err := c.opt.Limiter.Allow()
190		if err != nil {
191			return nil, err
192		}
193	}
194
195	cn, err := c._getConn(ctx)
196	if err != nil {
197		if c.opt.Limiter != nil {
198			c.opt.Limiter.ReportResult(err)
199		}
200		return nil, err
201	}
202
203	return cn, nil
204}
205
// _getConn takes a connection from the pool and initializes it on first
// use. On init failure the connection is removed from the pool and, when
// the error wraps an underlying cause, that cause is surfaced instead.
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
	cn, err := c.connPool.Get(ctx)
	if err != nil {
		return nil, err
	}

	if cn.Inited {
		return cn, nil
	}

	if err := c.initConn(ctx, cn); err != nil {
		c.connPool.Remove(ctx, cn, err)
		// Prefer the wrapped cause when one is present.
		if err := errors.Unwrap(err); err != nil {
			return nil, err
		}
		return nil, err
	}

	return cn, nil
}
226
// initConn performs the connection handshake — AUTH, SELECT and READONLY
// as dictated by the options — followed by the user OnConnect callback.
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
	if cn.Inited {
		return nil
	}
	// Mark as initialized up-front: the helper client below executes over
	// this same connection and must not re-enter initConn.
	cn.Inited = true

	// Nothing to do with all-default options.
	if c.opt.Password == "" &&
		c.opt.DB == 0 &&
		!c.opt.readOnly &&
		c.opt.OnConnect == nil {
		return nil
	}

	ctx, span := internal.StartSpan(ctx, "redis.init_conn")
	defer span.End()

	// Wrap cn in a single-connection client so the handshake commands are
	// guaranteed to run on this exact connection.
	connPool := pool.NewSingleConnPool(c.connPool, cn)
	conn := newConn(ctx, c.opt, connPool)

	_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
		if c.opt.Password != "" {
			// AUTH with username requires the ACL form.
			if c.opt.Username != "" {
				pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
			} else {
				pipe.Auth(ctx, c.opt.Password)
			}
		}

		if c.opt.DB > 0 {
			pipe.Select(ctx, c.opt.DB)
		}

		if c.opt.readOnly {
			pipe.ReadOnly(ctx)
		}

		return nil
	})
	if err != nil {
		return err
	}

	if c.opt.OnConnect != nil {
		return c.opt.OnConnect(ctx, conn)
	}
	return nil
}
274
275func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
276	if c.opt.Limiter != nil {
277		c.opt.Limiter.ReportResult(err)
278	}
279
280	if isBadConn(err, false) {
281		c.connPool.Remove(ctx, cn, err)
282	} else {
283		c.connPool.Put(ctx, cn)
284	}
285}
286
// withConn runs fn on a pooled connection and releases the connection
// afterwards. When ctx is cancellable, fn runs in a goroutine so that
// cancellation can interrupt it by closing the connection.
func (c *baseClient) withConn(
	ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
	ctx, span := internal.StartSpan(ctx, "redis.with_conn")
	defer span.End()

	cn, err := c.getConn(ctx)
	if err != nil {
		return err
	}

	if span.IsRecording() {
		if remoteAddr := cn.RemoteAddr(); remoteAddr != nil {
			span.SetAttributes(attribute.String("net.peer.ip", remoteAddr.String()))
		}
	}

	// The deferred release reads the local err, so every return path below
	// assigns err first; its value decides pool Put vs Remove.
	defer func() {
		c.releaseConn(ctx, cn, err)
	}()

	done := ctx.Done() //nolint:ifshort

	if done == nil {
		// Context can never be canceled: run fn inline.
		err = fn(ctx, cn)
		return err
	}

	errc := make(chan error, 1)
	go func() { errc <- fn(ctx, cn) }()

	select {
	case <-done:
		// Close the connection to unblock fn.
		_ = cn.Close()
		// Wait for the goroutine to finish and send something.
		<-errc

		err = ctx.Err()
		return err
	case err = <-errc:
		return err
	}
}
330
331func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
332	var lastErr error
333	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
334		attempt := attempt
335
336		retry, err := c._process(ctx, cmd, attempt)
337		if err == nil || !retry {
338			return err
339		}
340
341		lastErr = err
342	}
343	return lastErr
344}
345
346func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) {
347	if attempt > 0 {
348		if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
349			return false, err
350		}
351	}
352
353	retryTimeout := uint32(1)
354	err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
355		err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
356			return writeCmd(wr, cmd)
357		})
358		if err != nil {
359			return err
360		}
361
362		err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
363		if err != nil {
364			if cmd.readTimeout() == nil {
365				atomic.StoreUint32(&retryTimeout, 1)
366			}
367			return err
368		}
369
370		return nil
371	})
372	if err == nil {
373		return false, nil
374	}
375
376	retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
377	return retry, err
378}
379
// retryBackoff returns the sleep duration before the given retry attempt,
// bounded by the configured min/max backoff options.
func (c *baseClient) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
383
384func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
385	if timeout := cmd.readTimeout(); timeout != nil {
386		t := *timeout
387		if t == 0 {
388			return 0
389		}
390		return t + 10*time.Second
391	}
392	return c.opt.ReadTimeout
393}
394
395// Close closes the client, releasing any open resources.
396//
397// It is rare to Close a Client, as the Client is meant to be
398// long-lived and shared between many goroutines.
399func (c *baseClient) Close() error {
400	var firstErr error
401	if c.onClose != nil {
402		if err := c.onClose(); err != nil {
403			firstErr = err
404		}
405	}
406	if err := c.connPool.Close(); err != nil && firstErr == nil {
407		firstErr = err
408	}
409	return firstErr
410}
411
// getAddr returns the server address from the client options.
func (c *baseClient) getAddr() string {
	return c.opt.Addr
}
415
// processPipeline executes cmds as a plain (non-transactional) pipeline.
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
}
419
// processTxPipeline executes cmds as a MULTI/EXEC transaction pipeline.
// cmds is expected to already include the wrapping MULTI and EXEC.
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}
423
// pipelineProcessor writes cmds on cn and reads the replies, reporting
// whether the batch may be retried on error.
type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
425
426func (c *baseClient) generalProcessPipeline(
427	ctx context.Context, cmds []Cmder, p pipelineProcessor,
428) error {
429	err := c._generalProcessPipeline(ctx, cmds, p)
430	if err != nil {
431		setCmdsErr(cmds, err)
432		return err
433	}
434	return cmdsFirstErr(cmds)
435}
436
// _generalProcessPipeline executes p with up to MaxRetries retries and
// exponential backoff, stopping early on success, when p forbids retrying,
// or when the error itself is not retryable.
func (c *baseClient) _generalProcessPipeline(
	ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			// Back off before retrying; abort if ctx is canceled first.
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		var canRetry bool
		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
			var err error
			canRetry, err = p(ctx, cn, cmds)
			return err
		})
		if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
			return lastErr
		}
	}
	return lastErr
}
460
461func (c *baseClient) pipelineProcessCmds(
462	ctx context.Context, cn *pool.Conn, cmds []Cmder,
463) (bool, error) {
464	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
465		return writeCmds(wr, cmds)
466	})
467	if err != nil {
468		return true, err
469	}
470
471	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
472		return pipelineReadCmds(rd, cmds)
473	})
474	return true, err
475}
476
477func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
478	for _, cmd := range cmds {
479		err := cmd.readReply(rd)
480		cmd.SetErr(err)
481		if err != nil && !isRedisError(err) {
482			return err
483		}
484	}
485	return nil
486}
487
// txPipelineProcessCmds writes a MULTI/EXEC-wrapped batch and reads the
// queued acknowledgements followed by the EXEC replies. A failed write is
// retryable; once the read phase has started it is not (false result).
func (c *baseClient) txPipelineProcessCmds(
	ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmds(wr, cmds)
	})
	if err != nil {
		return true, err
	}

	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
		// cmds[0] is the MULTI command added by wrapMultiExec.
		statusCmd := cmds[0].(*StatusCmd)
		// Trim multi and exec.
		cmds = cmds[1 : len(cmds)-1]

		err := txPipelineReadQueued(rd, statusCmd, cmds)
		if err != nil {
			return err
		}

		return pipelineReadCmds(rd, cmds)
	})
	return false, err
}
512
513func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
514	if len(cmds) == 0 {
515		panic("not reached")
516	}
517	cmdCopy := make([]Cmder, len(cmds)+2)
518	cmdCopy[0] = NewStatusCmd(ctx, "multi")
519	copy(cmdCopy[1:], cmds)
520	cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
521	return cmdCopy
522}
523
// txPipelineReadQueued consumes the reply to MULTI and the per-command
// QUEUED acks, then validates the header of the EXEC reply array. The
// array elements themselves are left for the caller to read.
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
	// Parse the reply to MULTI.
	if err := statusCmd.readReply(rd); err != nil {
		return err
	}

	// Parse one QUEUED ack per command. Redis-level errors are tolerated
	// here; they surface later through the EXEC replies.
	for range cmds {
		if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
			return err
		}
	}

	// Parse number of replies.
	line, err := rd.ReadLine()
	if err != nil {
		// A nil EXEC reply means the transaction was aborted
		// (e.g. a WATCH-ed key changed).
		if err == Nil {
			err = TxFailedErr
		}
		return err
	}

	switch line[0] {
	case proto.ErrorReply:
		return proto.ParseErrorReply(line)
	case proto.ArrayReply:
		// ok
	default:
		err := fmt.Errorf("redis: expected '*', but got line %q", line)
		return err
	}

	return nil
}
557
558//------------------------------------------------------------------------------
559
// Client is a Redis client representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
type Client struct {
	*baseClient
	cmdable
	hooks
	ctx context.Context // default context used by Pipeline/TxPipeline
}
569
570// NewClient returns a client to the Redis Server specified by Options.
571func NewClient(opt *Options) *Client {
572	opt.init()
573
574	c := Client{
575		baseClient: newBaseClient(opt, newConnPool(opt)),
576		ctx:        context.Background(),
577	}
578	c.cmdable = c.Process
579
580	return &c
581}
582
583func (c *Client) clone() *Client {
584	clone := *c
585	clone.cmdable = clone.Process
586	clone.hooks.lock()
587	return &clone
588}
589
590func (c *Client) WithTimeout(timeout time.Duration) *Client {
591	clone := c.clone()
592	clone.baseClient = c.baseClient.withTimeout(timeout)
593	return clone
594}
595
// Context returns the client's default context.
func (c *Client) Context() context.Context {
	return c.ctx
}
599
600func (c *Client) WithContext(ctx context.Context) *Client {
601	if ctx == nil {
602		panic("nil context")
603	}
604	clone := c.clone()
605	clone.ctx = ctx
606	return clone
607}
608
// Conn returns a client bound to a single sticky connection drawn from the
// pool.
func (c *Client) Conn(ctx context.Context) *Conn {
	return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
}
612
// Do creates a Cmd from the args and processes the cmd.
// The outcome is carried by the returned Cmd (its Err/Result), so the
// Process error is intentionally discarded here.
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(ctx, args...)
	_ = c.Process(ctx, cmd)
	return cmd
}
619
// Process executes cmd through the hook chain and the base client.
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}
623
// processPipeline executes cmds as a pipeline through the hook chain.
func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
627
// processTxPipeline executes cmds as a MULTI/EXEC pipeline through the
// hook chain.
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}
631
// Options returns read-only Options that were used to create the client.
// Callers must not mutate the returned value.
func (c *Client) Options() *Options {
	return c.opt
}
636
// PoolStats contains connection pool statistics.
type PoolStats pool.Stats

// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
	stats := c.connPool.Stats()
	return (*PoolStats)(stats)
}
644
// Pipelined executes the commands queued by fn on a fresh Pipeline.
func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}
648
649func (c *Client) Pipeline() Pipeliner {
650	pipe := Pipeline{
651		ctx:  c.ctx,
652		exec: c.processPipeline,
653	}
654	pipe.init()
655	return &pipe
656}
657
// TxPipelined executes the commands queued by fn on a fresh TxPipeline
// (wrapped in MULTI/EXEC).
func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}
661
662// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
663func (c *Client) TxPipeline() Pipeliner {
664	pipe := Pipeline{
665		ctx:  c.ctx,
666		exec: c.processTxPipeline,
667	}
668	pipe.init()
669	return &pipe
670}
671
672func (c *Client) pubSub() *PubSub {
673	pubsub := &PubSub{
674		opt: c.opt,
675
676		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
677			return c.newConn(ctx)
678		},
679		closeConn: c.connPool.CloseConn,
680	}
681	pubsub.init()
682	return pubsub
683}
684
685// Subscribe subscribes the client to the specified channels.
686// Channels can be omitted to create empty subscription.
687// Note that this method does not wait on a response from Redis, so the
688// subscription may not be active immediately. To force the connection to wait,
689// you may call the Receive() method on the returned *PubSub like so:
690//
691//    sub := client.Subscribe(queryResp)
692//    iface, err := sub.Receive()
693//    if err != nil {
694//        // handle error
695//    }
696//
697//    // Should be *Subscription, but others are possible if other actions have been
698//    // taken on sub since it was created.
699//    switch iface.(type) {
700//    case *Subscription:
701//        // subscribe succeeded
702//    case *Message:
703//        // received first message
704//    case *Pong:
705//        // pong received
706//    default:
707//        // handle error
708//    }
709//
710//    ch := sub.Channel()
711func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
712	pubsub := c.pubSub()
713	if len(channels) > 0 {
714		_ = pubsub.Subscribe(ctx, channels...)
715	}
716	return pubsub
717}
718
719// PSubscribe subscribes the client to the given patterns.
720// Patterns can be omitted to create empty subscription.
721func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
722	pubsub := c.pubSub()
723	if len(channels) > 0 {
724		_ = pubsub.PSubscribe(ctx, channels...)
725	}
726	return pubsub
727}
728
729//------------------------------------------------------------------------------
730
// conn backs Conn with the base client plus both the regular and the
// stateful command sets.
type conn struct {
	baseClient
	cmdable
	statefulCmdable
	hooks // TODO: inherit hooks
}

// Conn is like Client, but its pool contains single connection.
type Conn struct {
	*conn
	ctx context.Context
}
743
744func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
745	c := Conn{
746		conn: &conn{
747			baseClient: baseClient{
748				opt:      opt,
749				connPool: connPool,
750			},
751		},
752		ctx: ctx,
753	}
754	c.cmdable = c.Process
755	c.statefulCmdable = c.Process
756	return &c
757}
758
// Process executes cmd through the hook chain and the base client.
func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}
762
// processPipeline executes cmds as a pipeline through the hook chain.
func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
766
// processTxPipeline executes cmds as a MULTI/EXEC pipeline through the
// hook chain.
func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}
770
// Pipelined executes the commands queued by fn on a fresh Pipeline.
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}
774
775func (c *Conn) Pipeline() Pipeliner {
776	pipe := Pipeline{
777		ctx:  c.ctx,
778		exec: c.processPipeline,
779	}
780	pipe.init()
781	return &pipe
782}
783
// TxPipelined executes the commands queued by fn on a fresh TxPipeline
// (wrapped in MULTI/EXEC).
func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}
787
788// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
789func (c *Conn) TxPipeline() Pipeliner {
790	pipe := Pipeline{
791		ctx:  c.ctx,
792		exec: c.processTxPipeline,
793	}
794	pipe.init()
795	return &pipe
796}
797