package redis

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/go-redis/redis/v8/internal"
	"github.com/go-redis/redis/v8/internal/pool"
	"github.com/go-redis/redis/v8/internal/proto"
)

// Nil reply returned by Redis when the key does not exist.
const Nil = proto.Nil

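// SetLogger replaces the logger used for the package's internal
// diagnostics. A minimal sketch, assuming a wrapper around the standard
// library log package (stdLogger is a made-up name for this
// illustration):
//
//    type stdLogger struct{}
//
//    func (stdLogger) Printf(ctx context.Context, format string, v ...interface{}) {
//        log.Printf(format, v...)
//    }
//
//    redis.SetLogger(stdLogger{})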
func SetLogger(logger internal.Logging) {
	internal.Logger = logger
}

//------------------------------------------------------------------------------

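// Hook allows observing commands as they are processed. BeforeProcess
// hooks run in registration order and may derive a new context;
// AfterProcess hooks run in reverse order once the command has
// finished. A sketch of an implementation that logs slow commands
// (startKey and slowLogHook are made-up names for this illustration):
//
//    type startKey struct{}
//
//    type slowLogHook struct{ threshold time.Duration }
//
//    func (h slowLogHook) BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
//        return context.WithValue(ctx, startKey{}, time.Now()), nil
//    }
//
//    func (h slowLogHook) AfterProcess(ctx context.Context, cmd Cmder) error {
//        if start, ok := ctx.Value(startKey{}).(time.Time); ok {
//            if d := time.Since(start); d >= h.threshold {
//                internal.Logger.Printf(ctx, "slow command: %s (%s)", cmd.Name(), d)
//            }
//        }
//        return nil
//    }
//
//    func (h slowLogHook) BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
//        return ctx, nil
//    }
//
//    func (h slowLogHook) AfterProcessPipeline(ctx context.Context, cmds []Cmder) error {
//        return nil
//    }
//
// Hooks are registered with AddHook; register them before the client is
// shared between goroutines.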
type Hook interface {
	BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
	AfterProcess(ctx context.Context, cmd Cmder) error

	BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
	AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
}

type hooks struct {
	hooks []Hook
}

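// lock caps the slice at its current length, so that an AddHook on a
// clone appends into fresh backing storage instead of mutating the
// hooks shared with the original.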
func (hs *hooks) lock() {
	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}

func (hs hooks) clone() hooks {
	clone := hs
	clone.lock()
	return clone
}

func (hs *hooks) AddHook(hook Hook) {
	hs.hooks = append(hs.hooks, hook)
}

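// process runs cmd through the hook chain: BeforeProcess hooks in
// registration order, then fn, then AfterProcess hooks in reverse
// order. AfterProcess runs for every hook whose BeforeProcess ran,
// even if an earlier step already failed.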
func (hs hooks) process(
	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		err := hs.withContext(ctx, func() error {
			return fn(ctx, cmd)
		})
		cmd.SetErr(err)
		return err
	}

	var hookIndex int
	var retErr error

	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
		if retErr != nil {
			cmd.SetErr(retErr)
		}
	}

	if retErr == nil {
		retErr = hs.withContext(ctx, func() error {
			return fn(ctx, cmd)
		})
		cmd.SetErr(retErr)
	}

	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
			retErr = err
			cmd.SetErr(retErr)
		}
	}

	return retErr
}

// processPipeline is the pipeline analog of process: BeforeProcessPipeline
// hooks run forward, fn runs once, AfterProcessPipeline hooks run in
// reverse order.
func (hs hooks) processPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		err := hs.withContext(ctx, func() error {
			return fn(ctx, cmds)
		})
		return err
	}

	var hookIndex int
	var retErr error

	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
		if retErr != nil {
			setCmdsErr(cmds, retErr)
		}
	}

	if retErr == nil {
		retErr = hs.withContext(ctx, func() error {
			return fn(ctx, cmds)
		})
	}

	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
			retErr = err
			setCmdsErr(cmds, retErr)
		}
	}

	return retErr
}

func (hs hooks) processTxPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	cmds = wrapMultiExec(ctx, cmds)
	return hs.processPipeline(ctx, cmds, fn)
}

// withContext simply runs fn; all command and pipeline execution
// funnels through here.
func (hs hooks) withContext(ctx context.Context, fn func() error) error {
	return fn()
}

//------------------------------------------------------------------------------

type baseClient struct {
	opt      *Options
	connPool pool.Pooler

	onClose func() error // hook called when client is closed
}

func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
	return &baseClient{
		opt:      opt,
		connPool: connPool,
	}
}

func (c *baseClient) clone() *baseClient {
	clone := *c
	return &clone
}

func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
	opt := c.opt.clone()
	opt.ReadTimeout = timeout
	opt.WriteTimeout = timeout

	clone := c.clone()
	clone.opt = opt

	return clone
}

func (c *baseClient) String() string {
	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}

func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
	cn, err := c.connPool.NewConn(ctx)
	if err != nil {
		return nil, err
	}

	err = c.initConn(ctx, cn)
	if err != nil {
		_ = c.connPool.CloseConn(cn)
		return nil, err
	}

	return cn, nil
}

func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
	if c.opt.Limiter != nil {
		err := c.opt.Limiter.Allow()
		if err != nil {
			return nil, err
		}
	}

	cn, err := c._getConn(ctx)
	if err != nil {
		if c.opt.Limiter != nil {
			c.opt.Limiter.ReportResult(err)
		}
		return nil, err
	}

	return cn, nil
}

func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
	cn, err := c.connPool.Get(ctx)
	if err != nil {
		return nil, err
	}

	if cn.Inited {
		return cn, nil
	}

	if err := c.initConn(ctx, cn); err != nil {
		c.connPool.Remove(ctx, cn, err)
		if err := errors.Unwrap(err); err != nil {
			return nil, err
		}
		return nil, err
	}

	return cn, nil
}

// initConn performs the connection handshake. AUTH, SELECT and READONLY
// are pipelined together, so a fresh connection is set up in a single
// round trip.
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
	if cn.Inited {
		return nil
	}
	cn.Inited = true

	if c.opt.Password == "" &&
		c.opt.DB == 0 &&
		!c.opt.readOnly &&
		c.opt.OnConnect == nil {
		return nil
	}

	connPool := pool.NewSingleConnPool(c.connPool, cn)
	conn := newConn(ctx, c.opt, connPool)

	_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
		if c.opt.Password != "" {
			if c.opt.Username != "" {
				pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
			} else {
				pipe.Auth(ctx, c.opt.Password)
			}
		}

		if c.opt.DB > 0 {
			pipe.Select(ctx, c.opt.DB)
		}

		if c.opt.readOnly {
			pipe.ReadOnly(ctx)
		}

		return nil
	})
	if err != nil {
		return err
	}

	if c.opt.OnConnect != nil {
		return c.opt.OnConnect(ctx, conn)
	}
	return nil
}

func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
	if c.opt.Limiter != nil {
		c.opt.Limiter.ReportResult(err)
	}

	if isBadConn(err, false) {
		c.connPool.Remove(ctx, cn, err)
	} else {
		c.connPool.Put(ctx, cn)
	}
}

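// withConn runs fn with a connection from the pool. When ctx can be
// canceled, fn runs in its own goroutine so that the connection can be
// closed as soon as ctx is done, instead of waiting for a read or write
// to time out.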
func (c *baseClient) withConn(
	ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
	cn, err := c.getConn(ctx)
	if err != nil {
		return err
	}

	defer func() {
		c.releaseConn(ctx, cn, err)
	}()

	done := ctx.Done() //nolint:ifshort

	if done == nil {
		err = fn(ctx, cn)
		return err
	}

	errc := make(chan error, 1)
	go func() { errc <- fn(ctx, cn) }()

	select {
	case <-done:
		_ = cn.Close()
		// Wait for the goroutine to finish and send something.
		<-errc

		err = ctx.Err()
		return err
	case err = <-errc:
		return err
	}
}

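// process executes cmd, retrying up to MaxRetries times with a backoff
// between attempts, for errors that shouldRetry reports as retryable.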
func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		retry, err := c._process(ctx, cmd, attempt)
		if err == nil || !retry {
			return err
		}

		lastErr = err
	}
	return lastErr
}

func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) {
	if attempt > 0 {
		if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
			return false, err
		}
	}

	retryTimeout := uint32(1)
	err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
		err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
			return writeCmd(wr, cmd)
		})
		if err != nil {
			return err
		}

		err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
		if err != nil {
			// Blocking commands (non-nil readTimeout) manage their own
			// deadline, so a timeout there is not a reason to retry.
			if cmd.readTimeout() == nil {
				atomic.StoreUint32(&retryTimeout, 1)
			} else {
				atomic.StoreUint32(&retryTimeout, 0)
			}
			return err
		}

		return nil
	})
	if err == nil {
		return false, nil
	}

	retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
	return retry, err
}

func (c *baseClient) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

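// cmdTimeout returns the read timeout for cmd. Blocking commands carry
// their own timeout, which is padded by 10 seconds so that the server,
// not the client, is the side that times out.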
func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
	if timeout := cmd.readTimeout(); timeout != nil {
		t := *timeout
		if t == 0 {
			return 0
		}
		return t + 10*time.Second
	}
	return c.opt.ReadTimeout
}

// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
	var firstErr error
	if c.onClose != nil {
		if err := c.onClose(); err != nil {
			firstErr = err
		}
	}
	if err := c.connPool.Close(); err != nil && firstErr == nil {
		firstErr = err
	}
	return firstErr
}

func (c *baseClient) getAddr() string {
	return c.opt.Addr
}

func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
}

func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}

type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)

func (c *baseClient) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
	err := c._generalProcessPipeline(ctx, cmds, p)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}
	return cmdsFirstErr(cmds)
}

func (c *baseClient) _generalProcessPipeline(
	ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		var canRetry bool
		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
			var err error
			canRetry, err = p(ctx, cn, cmds)
			return err
		})
		if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
			return lastErr
		}
	}
	return lastErr
}

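// pipelineProcessCmds writes all commands in a single batch and then
// reads their replies in order. The boolean result reports whether the
// batch may be retried on another connection.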
func (c *baseClient) pipelineProcessCmds(
	ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmds(wr, cmds)
	})
	if err != nil {
		return true, err
	}

	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
		return pipelineReadCmds(rd, cmds)
	})
	return true, err
}

func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
	for _, cmd := range cmds {
		err := cmd.readReply(rd)
		cmd.SetErr(err)
		if err != nil && !isRedisError(err) {
			return err
		}
	}
	return nil
}

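// txPipelineProcessCmds is like pipelineProcessCmds, but expects cmds
// to be wrapped in MULTI and EXEC (see wrapMultiExec). It first
// consumes the queued replies and the EXEC array header, then reads the
// replies of the wrapped commands. A transaction is not retried after
// its replies have been read, so it reports false.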
func (c *baseClient) txPipelineProcessCmds(
	ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
	err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmds(wr, cmds)
	})
	if err != nil {
		return true, err
	}

	err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
		statusCmd := cmds[0].(*StatusCmd)
		// Trim the MULTI and EXEC wrapper commands.
		cmds = cmds[1 : len(cmds)-1]

		err := txPipelineReadQueued(rd, statusCmd, cmds)
		if err != nil {
			return err
		}

		return pipelineReadCmds(rd, cmds)
	})
	return false, err
}

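// wrapMultiExec brackets cmds with MULTI and EXEC so that the server
// executes them as a single transaction.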
func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
	if len(cmds) == 0 {
		panic("not reached")
	}
	cmdCopy := make([]Cmder, len(cmds)+2)
	cmdCopy[0] = NewStatusCmd(ctx, "multi")
	copy(cmdCopy[1:], cmds)
	cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
	return cmdCopy
}

func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
	// Parse +OK for MULTI and +QUEUED for each queued command.
	if err := statusCmd.readReply(rd); err != nil {
		return err
	}

	for range cmds {
		if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
			return err
		}
	}

	// Parse the number of replies.
	line, err := rd.ReadLine()
	if err != nil {
		if err == Nil {
			err = TxFailedErr
		}
		return err
	}

	switch line[0] {
	case proto.ErrorReply:
		return proto.ParseErrorReply(line)
	case proto.ArrayReply:
		// ok
	default:
		err := fmt.Errorf("redis: expected '*', but got line %q", line)
		return err
	}

	return nil
}

//------------------------------------------------------------------------------

// Client is a Redis client representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
type Client struct {
	*baseClient
	cmdable
	hooks
	ctx context.Context
}

// NewClient returns a client to the Redis Server specified by Options.
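//
// A minimal usage sketch (the address is an assumption for this
// example):
//
//    ctx := context.Background()
//
//    rdb := NewClient(&Options{
//        Addr: "localhost:6379",
//    })
//    defer rdb.Close()
//
//    if err := rdb.Set(ctx, "key", "value", 0).Err(); err != nil {
//        // handle error
//    }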
func NewClient(opt *Options) *Client {
	opt.init()

	c := Client{
		baseClient: newBaseClient(opt, newConnPool(opt)),
		ctx:        context.Background(),
	}
	c.cmdable = c.Process

	return &c
}

func (c *Client) clone() *Client {
	clone := *c
	clone.cmdable = clone.Process
	clone.hooks.lock()
	return &clone
}

// WithTimeout returns a copy of the client that uses timeout for both
// reads and writes.
func (c *Client) WithTimeout(timeout time.Duration) *Client {
	clone := c.clone()
	clone.baseClient = c.baseClient.withTimeout(timeout)
	return clone
}

func (c *Client) Context() context.Context {
	return c.ctx
}

func (c *Client) WithContext(ctx context.Context) *Client {
	if ctx == nil {
		panic("nil context")
	}
	clone := c.clone()
	clone.ctx = ctx
	return clone
}

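// Conn returns a *Conn bound to a single underlying connection, for
// stateful commands that must all run on the same connection. A sketch
// (assuming the server supports CLIENT SETNAME):
//
//    conn := rdb.Conn(ctx)
//    defer conn.Close()
//
//    if err := conn.ClientSetName(ctx, "worker-1").Err(); err != nil {
//        // handle error
//    }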
func (c *Client) Conn(ctx context.Context) *Conn {
	return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
}

// Do creates a Cmd from the args and processes the cmd.
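// It is useful for running commands that have no typed helper in this
// package, for example:
//
//    v, err := rdb.Do(ctx, "object", "encoding", "key").Result()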
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(ctx, args...)
	_ = c.Process(ctx, cmd)
	return cmd
}

func (c *Client) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}

func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}

func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}

// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
	return c.opt
}

type PoolStats pool.Stats

// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
	stats := c.connPool.Stats()
	return (*PoolStats)(stats)
}

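// Pipelined executes fn inside a pipeline: the commands queued by fn
// are written to the server in one batch and their replies are read in
// one pass. A sketch:
//
//    cmds, err := rdb.Pipelined(ctx, func(pipe Pipeliner) error {
//        pipe.Incr(ctx, "counter")
//        pipe.Expire(ctx, "counter", time.Hour)
//        return nil
//    })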
func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Client) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}

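// TxPipelined is like Pipelined, but wraps the queued commands with
// MULTI/EXEC so they execute as one transaction:
//
//    cmds, err := rdb.TxPipelined(ctx, func(pipe Pipeliner) error {
//        pipe.Incr(ctx, "counter")
//        return nil
//    })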
func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Client) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Client) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,

		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
			return c.newConn(ctx)
		},
		closeConn: c.connPool.CloseConn,
	}
	pubsub.init()
	return pubsub
}

// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create an empty subscription.
// Note that this method does not wait for a response from Redis, so the
// subscription may not be active immediately. To force the connection to wait,
// you may call the Receive() method on the returned *PubSub like so:
//
//    sub := client.Subscribe(queryResp)
//    iface, err := sub.Receive()
//    if err != nil {
//        // handle error
//    }
//
//    // Should be *Subscription, but others are possible if other actions have been
//    // taken on sub since it was created.
//    switch iface.(type) {
//    case *Subscription:
//        // subscribe succeeded
//    case *Message:
//        // received first message
//    case *Pong:
//        // pong received
//    default:
//        // handle error
//    }
//
//    ch := sub.Channel()
func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.Subscribe(ctx, channels...)
	}
	return pubsub
}

// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create an empty subscription.
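// For example:
//
//    pubsub := client.PSubscribe(ctx, "news.*")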
func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.PSubscribe(ctx, channels...)
	}
	return pubsub
}

//------------------------------------------------------------------------------

type conn struct {
	baseClient
	cmdable
	statefulCmdable
	hooks // TODO: inherit hooks
}

// Conn is like Client, but its pool contains a single connection.
type Conn struct {
	*conn
	ctx context.Context
}

func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
	c := Conn{
		conn: &conn{
			baseClient: baseClient{
				opt:      opt,
				connPool: connPool,
			},
		},
		ctx: ctx,
	}
	c.cmdable = c.Process
	c.statefulCmdable = c.Process
	return &c
}

func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}

func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}

func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}

func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Conn) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Conn) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}