1package redis 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync/atomic" 8 "time" 9 10 "github.com/go-redis/redis/v8/internal" 11 "github.com/go-redis/redis/v8/internal/pool" 12 "github.com/go-redis/redis/v8/internal/proto" 13 "go.opentelemetry.io/otel/attribute" 14) 15 16// Nil reply returned by Redis when key does not exist. 17const Nil = proto.Nil 18 19func SetLogger(logger internal.Logging) { 20 internal.Logger = logger 21} 22 23//------------------------------------------------------------------------------ 24 25type Hook interface { 26 BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) 27 AfterProcess(ctx context.Context, cmd Cmder) error 28 29 BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) 30 AfterProcessPipeline(ctx context.Context, cmds []Cmder) error 31} 32 33type hooks struct { 34 hooks []Hook 35} 36 37func (hs *hooks) lock() { 38 hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] 39} 40 41func (hs hooks) clone() hooks { 42 clone := hs 43 clone.lock() 44 return clone 45} 46 47func (hs *hooks) AddHook(hook Hook) { 48 hs.hooks = append(hs.hooks, hook) 49} 50 51func (hs hooks) process( 52 ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, 53) error { 54 if len(hs.hooks) == 0 { 55 err := hs.withContext(ctx, func() error { 56 return fn(ctx, cmd) 57 }) 58 cmd.SetErr(err) 59 return err 60 } 61 62 var hookIndex int 63 var retErr error 64 65 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 66 ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) 67 if retErr != nil { 68 cmd.SetErr(retErr) 69 } 70 } 71 72 if retErr == nil { 73 retErr = hs.withContext(ctx, func() error { 74 return fn(ctx, cmd) 75 }) 76 cmd.SetErr(retErr) 77 } 78 79 for hookIndex--; hookIndex >= 0; hookIndex-- { 80 if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { 81 retErr = err 82 cmd.SetErr(retErr) 83 } 84 } 85 86 return retErr 87} 88 89func (hs hooks) processPipeline( 90 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 91) error { 92 if len(hs.hooks) == 0 { 93 err := hs.withContext(ctx, func() error { 94 return fn(ctx, cmds) 95 }) 96 return err 97 } 98 99 var hookIndex int 100 var retErr error 101 102 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 103 ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) 104 if retErr != nil { 105 setCmdsErr(cmds, retErr) 106 } 107 } 108 109 if retErr == nil { 110 retErr = hs.withContext(ctx, func() error { 111 return fn(ctx, cmds) 112 }) 113 } 114 115 for hookIndex--; hookIndex >= 0; hookIndex-- { 116 if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { 117 retErr = err 118 setCmdsErr(cmds, retErr) 119 } 120 } 121 122 return retErr 123} 124 125func (hs hooks) processTxPipeline( 126 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 127) error { 128 cmds = wrapMultiExec(ctx, cmds) 129 return hs.processPipeline(ctx, cmds, fn) 130} 131 132func (hs hooks) withContext(ctx context.Context, fn func() error) error { 133 return fn() 134} 135 136//------------------------------------------------------------------------------ 137 138type baseClient struct { 139 opt *Options 140 connPool pool.Pooler 141 142 onClose func() error // hook called when client is closed 143} 144 145func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { 146 return &baseClient{ 147 opt: opt, 148 connPool: connPool, 149 } 150} 151 152func (c *baseClient) clone() *baseClient { 153 clone := *c 154 return &clone 155} 156 157func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { 158 opt := c.opt.clone() 159 opt.ReadTimeout = timeout 160 opt.WriteTimeout = timeout 161 162 clone := c.clone() 163 clone.opt = opt 164 165 return clone 166} 167 168func (c *baseClient) String() string { 169 return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) 170} 171 172func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { 173 cn, err := c.connPool.NewConn(ctx) 174 if err != nil { 175 return nil, err 176 } 177 178 err = c.initConn(ctx, cn) 179 if err != nil { 180 _ = c.connPool.CloseConn(cn) 181 return nil, err 182 } 183 184 return cn, nil 185} 186 187func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { 188 if c.opt.Limiter != nil { 189 err := c.opt.Limiter.Allow() 190 if err != nil { 191 return nil, err 192 } 193 } 194 195 cn, err := c._getConn(ctx) 196 if err != nil { 197 if c.opt.Limiter != nil { 198 c.opt.Limiter.ReportResult(err) 199 } 200 return nil, err 201 } 202 203 return cn, nil 204} 205 206func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { 207 cn, err := c.connPool.Get(ctx) 208 if err != nil { 209 return nil, err 210 } 211 212 if cn.Inited { 213 return cn, nil 214 } 215 216 if err := c.initConn(ctx, cn); err != nil { 217 c.connPool.Remove(ctx, cn, err) 218 if err := errors.Unwrap(err); err != nil { 219 return nil, err 220 } 221 return nil, err 222 } 223 224 return cn, nil 225} 226 227func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { 228 if cn.Inited { 229 return nil 230 } 231 cn.Inited = true 232 233 if c.opt.Password == "" && 234 c.opt.DB == 0 && 235 !c.opt.readOnly && 236 c.opt.OnConnect == nil { 237 return nil 238 } 239 240 ctx, span := internal.StartSpan(ctx, "redis.init_conn") 241 defer span.End() 242 243 connPool := pool.NewSingleConnPool(c.connPool, cn) 244 conn := newConn(ctx, c.opt, connPool) 245 246 _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { 247 if c.opt.Password != "" { 248 if c.opt.Username != "" { 249 pipe.AuthACL(ctx, c.opt.Username, c.opt.Password) 250 } else { 251 pipe.Auth(ctx, c.opt.Password) 252 } 253 } 254 255 if c.opt.DB > 0 { 256 pipe.Select(ctx, c.opt.DB) 257 } 258 259 if c.opt.readOnly { 260 pipe.ReadOnly(ctx) 261 } 262 263 return nil 264 }) 265 if err != nil { 266 return err 267 } 268 269 if c.opt.OnConnect != nil { 270 return c.opt.OnConnect(ctx, conn) 271 } 272 return nil 273} 274 275func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { 276 if c.opt.Limiter != nil { 277 c.opt.Limiter.ReportResult(err) 278 } 279 280 if isBadConn(err, false) { 281 c.connPool.Remove(ctx, cn, err) 282 } else { 283 c.connPool.Put(ctx, cn) 284 } 285} 286 287func (c *baseClient) withConn( 288 ctx context.Context, fn func(context.Context, *pool.Conn) error, 289) error { 290 ctx, span := internal.StartSpan(ctx, "redis.with_conn") 291 defer span.End() 292 293 cn, err := c.getConn(ctx) 294 if err != nil { 295 return err 296 } 297 298 if span.IsRecording() { 299 if remoteAddr := cn.RemoteAddr(); remoteAddr != nil { 300 span.SetAttributes(attribute.String("net.peer.ip", remoteAddr.String())) 301 } 302 } 303 304 defer func() { 305 c.releaseConn(ctx, cn, err) 306 }() 307 308 done := ctx.Done() //nolint:ifshort 309 310 if done == nil { 311 err = fn(ctx, cn) 312 return err 313 } 314 315 errc := make(chan error, 1) 316 go func() { errc <- fn(ctx, cn) }() 317 318 select { 319 case <-done: 320 _ = cn.Close() 321 // Wait for the goroutine to finish and send something. 322 <-errc 323 324 err = ctx.Err() 325 return err 326 case err = <-errc: 327 return err 328 } 329} 330 331func (c *baseClient) process(ctx context.Context, cmd Cmder) error { 332 var lastErr error 333 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 334 attempt := attempt 335 336 retry, err := c._process(ctx, cmd, attempt) 337 if err == nil || !retry { 338 return err 339 } 340 341 lastErr = err 342 } 343 return lastErr 344} 345 346func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) { 347 if attempt > 0 { 348 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 349 return false, err 350 } 351 } 352 353 retryTimeout := uint32(1) 354 err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 355 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 356 return writeCmd(wr, cmd) 357 }) 358 if err != nil { 359 return err 360 } 361 362 err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) 363 if err != nil { 364 if cmd.readTimeout() == nil { 365 atomic.StoreUint32(&retryTimeout, 1) 366 } 367 return err 368 } 369 370 return nil 371 }) 372 if err == nil { 373 return false, nil 374 } 375 376 retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) 377 return retry, err 378} 379 380func (c *baseClient) retryBackoff(attempt int) time.Duration { 381 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) 382} 383 384func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { 385 if timeout := cmd.readTimeout(); timeout != nil { 386 t := *timeout 387 if t == 0 { 388 return 0 389 } 390 return t + 10*time.Second 391 } 392 return c.opt.ReadTimeout 393} 394 395// Close closes the client, releasing any open resources. 396// 397// It is rare to Close a Client, as the Client is meant to be 398// long-lived and shared between many goroutines. 399func (c *baseClient) Close() error { 400 var firstErr error 401 if c.onClose != nil { 402 if err := c.onClose(); err != nil { 403 firstErr = err 404 } 405 } 406 if err := c.connPool.Close(); err != nil && firstErr == nil { 407 firstErr = err 408 } 409 return firstErr 410} 411 412func (c *baseClient) getAddr() string { 413 return c.opt.Addr 414} 415 416func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { 417 return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) 418} 419 420func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { 421 return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) 422} 423 424type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) 425 426func (c *baseClient) generalProcessPipeline( 427 ctx context.Context, cmds []Cmder, p pipelineProcessor, 428) error { 429 err := c._generalProcessPipeline(ctx, cmds, p) 430 if err != nil { 431 setCmdsErr(cmds, err) 432 return err 433 } 434 return cmdsFirstErr(cmds) 435} 436 437func (c *baseClient) _generalProcessPipeline( 438 ctx context.Context, cmds []Cmder, p pipelineProcessor, 439) error { 440 var lastErr error 441 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 442 if attempt > 0 { 443 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 444 return err 445 } 446 } 447 448 var canRetry bool 449 lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 450 var err error 451 canRetry, err = p(ctx, cn, cmds) 452 return err 453 }) 454 if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { 455 return lastErr 456 } 457 } 458 return lastErr 459} 460 461func (c *baseClient) pipelineProcessCmds( 462 ctx context.Context, cn *pool.Conn, cmds []Cmder, 463) (bool, error) { 464 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 465 return writeCmds(wr, cmds) 466 }) 467 if err != nil { 468 return true, err 469 } 470 471 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 472 return pipelineReadCmds(rd, cmds) 473 }) 474 return true, err 475} 476 477func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { 478 for _, cmd := range cmds { 479 err := cmd.readReply(rd) 480 cmd.SetErr(err) 481 if err != nil && !isRedisError(err) { 482 return err 483 } 484 } 485 return nil 486} 487 488func (c *baseClient) txPipelineProcessCmds( 489 ctx context.Context, cn *pool.Conn, cmds []Cmder, 490) (bool, error) { 491 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 492 return writeCmds(wr, cmds) 493 }) 494 if err != nil { 495 return true, err 496 } 497 498 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 499 statusCmd := cmds[0].(*StatusCmd) 500 // Trim multi and exec. 501 cmds = cmds[1 : len(cmds)-1] 502 503 err := txPipelineReadQueued(rd, statusCmd, cmds) 504 if err != nil { 505 return err 506 } 507 508 return pipelineReadCmds(rd, cmds) 509 }) 510 return false, err 511} 512 513func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { 514 if len(cmds) == 0 { 515 panic("not reached") 516 } 517 cmdCopy := make([]Cmder, len(cmds)+2) 518 cmdCopy[0] = NewStatusCmd(ctx, "multi") 519 copy(cmdCopy[1:], cmds) 520 cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") 521 return cmdCopy 522} 523 524func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { 525 // Parse queued replies. 526 if err := statusCmd.readReply(rd); err != nil { 527 return err 528 } 529 530 for range cmds { 531 if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { 532 return err 533 } 534 } 535 536 // Parse number of replies. 537 line, err := rd.ReadLine() 538 if err != nil { 539 if err == Nil { 540 err = TxFailedErr 541 } 542 return err 543 } 544 545 switch line[0] { 546 case proto.ErrorReply: 547 return proto.ParseErrorReply(line) 548 case proto.ArrayReply: 549 // ok 550 default: 551 err := fmt.Errorf("redis: expected '*', but got line %q", line) 552 return err 553 } 554 555 return nil 556} 557 558//------------------------------------------------------------------------------ 559 560// Client is a Redis client representing a pool of zero or more 561// underlying connections. It's safe for concurrent use by multiple 562// goroutines. 563type Client struct { 564 *baseClient 565 cmdable 566 hooks 567 ctx context.Context 568} 569 570// NewClient returns a client to the Redis Server specified by Options. 571func NewClient(opt *Options) *Client { 572 opt.init() 573 574 c := Client{ 575 baseClient: newBaseClient(opt, newConnPool(opt)), 576 ctx: context.Background(), 577 } 578 c.cmdable = c.Process 579 580 return &c 581} 582 583func (c *Client) clone() *Client { 584 clone := *c 585 clone.cmdable = clone.Process 586 clone.hooks.lock() 587 return &clone 588} 589 590func (c *Client) WithTimeout(timeout time.Duration) *Client { 591 clone := c.clone() 592 clone.baseClient = c.baseClient.withTimeout(timeout) 593 return clone 594} 595 596func (c *Client) Context() context.Context { 597 return c.ctx 598} 599 600func (c *Client) WithContext(ctx context.Context) *Client { 601 if ctx == nil { 602 panic("nil context") 603 } 604 clone := c.clone() 605 clone.ctx = ctx 606 return clone 607} 608 609func (c *Client) Conn(ctx context.Context) *Conn { 610 return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool)) 611} 612 613// Do creates a Cmd from the args and processes the cmd. 614func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { 615 cmd := NewCmd(ctx, args...) 616 _ = c.Process(ctx, cmd) 617 return cmd 618} 619 620func (c *Client) Process(ctx context.Context, cmd Cmder) error { 621 return c.hooks.process(ctx, cmd, c.baseClient.process) 622} 623 624func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { 625 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 626} 627 628func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { 629 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 630} 631 632// Options returns read-only Options that were used to create the client. 633func (c *Client) Options() *Options { 634 return c.opt 635} 636 637type PoolStats pool.Stats 638 639// PoolStats returns connection pool stats. 640func (c *Client) PoolStats() *PoolStats { 641 stats := c.connPool.Stats() 642 return (*PoolStats)(stats) 643} 644 645func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 646 return c.Pipeline().Pipelined(ctx, fn) 647} 648 649func (c *Client) Pipeline() Pipeliner { 650 pipe := Pipeline{ 651 ctx: c.ctx, 652 exec: c.processPipeline, 653 } 654 pipe.init() 655 return &pipe 656} 657 658func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 659 return c.TxPipeline().Pipelined(ctx, fn) 660} 661 662// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 663func (c *Client) TxPipeline() Pipeliner { 664 pipe := Pipeline{ 665 ctx: c.ctx, 666 exec: c.processTxPipeline, 667 } 668 pipe.init() 669 return &pipe 670} 671 672func (c *Client) pubSub() *PubSub { 673 pubsub := &PubSub{ 674 opt: c.opt, 675 676 newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { 677 return c.newConn(ctx) 678 }, 679 closeConn: c.connPool.CloseConn, 680 } 681 pubsub.init() 682 return pubsub 683} 684 685// Subscribe subscribes the client to the specified channels. 686// Channels can be omitted to create empty subscription. 687// Note that this method does not wait on a response from Redis, so the 688// subscription may not be active immediately. To force the connection to wait, 689// you may call the Receive() method on the returned *PubSub like so: 690// 691// sub := client.Subscribe(queryResp) 692// iface, err := sub.Receive() 693// if err != nil { 694// // handle error 695// } 696// 697// // Should be *Subscription, but others are possible if other actions have been 698// // taken on sub since it was created. 699// switch iface.(type) { 700// case *Subscription: 701// // subscribe succeeded 702// case *Message: 703// // received first message 704// case *Pong: 705// // pong received 706// default: 707// // handle error 708// } 709// 710// ch := sub.Channel() 711func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub { 712 pubsub := c.pubSub() 713 if len(channels) > 0 { 714 _ = pubsub.Subscribe(ctx, channels...) 715 } 716 return pubsub 717} 718 719// PSubscribe subscribes the client to the given patterns. 720// Patterns can be omitted to create empty subscription. 721func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { 722 pubsub := c.pubSub() 723 if len(channels) > 0 { 724 _ = pubsub.PSubscribe(ctx, channels...) 725 } 726 return pubsub 727} 728 729//------------------------------------------------------------------------------ 730 731type conn struct { 732 baseClient 733 cmdable 734 statefulCmdable 735 hooks // TODO: inherit hooks 736} 737 738// Conn is like Client, but its pool contains single connection. 739type Conn struct { 740 *conn 741 ctx context.Context 742} 743 744func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { 745 c := Conn{ 746 conn: &conn{ 747 baseClient: baseClient{ 748 opt: opt, 749 connPool: connPool, 750 }, 751 }, 752 ctx: ctx, 753 } 754 c.cmdable = c.Process 755 c.statefulCmdable = c.Process 756 return &c 757} 758 759func (c *Conn) Process(ctx context.Context, cmd Cmder) error { 760 return c.hooks.process(ctx, cmd, c.baseClient.process) 761} 762 763func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error { 764 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 765} 766 767func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error { 768 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 769} 770 771func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 772 return c.Pipeline().Pipelined(ctx, fn) 773} 774 775func (c *Conn) Pipeline() Pipeliner { 776 pipe := Pipeline{ 777 ctx: c.ctx, 778 exec: c.processPipeline, 779 } 780 pipe.init() 781 return &pipe 782} 783 784func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 785 return c.TxPipeline().Pipelined(ctx, fn) 786} 787 788// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 789func (c *Conn) TxPipeline() Pipeliner { 790 pipe := Pipeline{ 791 ctx: c.ctx, 792 exec: c.processTxPipeline, 793 } 794 pipe.init() 795 return &pipe 796} 797