1package redis 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync/atomic" 8 "time" 9 10 "github.com/go-redis/redis/v8/internal" 11 "github.com/go-redis/redis/v8/internal/pool" 12 "github.com/go-redis/redis/v8/internal/proto" 13 "go.opentelemetry.io/otel/attribute" 14 "go.opentelemetry.io/otel/trace" 15) 16 17// Nil reply returned by Redis when key does not exist. 18const Nil = proto.Nil 19 20func SetLogger(logger internal.Logging) { 21 internal.Logger = logger 22} 23 24//------------------------------------------------------------------------------ 25 26type Hook interface { 27 BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) 28 AfterProcess(ctx context.Context, cmd Cmder) error 29 30 BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) 31 AfterProcessPipeline(ctx context.Context, cmds []Cmder) error 32} 33 34type hooks struct { 35 hooks []Hook 36} 37 38func (hs *hooks) lock() { 39 hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] 40} 41 42func (hs hooks) clone() hooks { 43 clone := hs 44 clone.lock() 45 return clone 46} 47 48func (hs *hooks) AddHook(hook Hook) { 49 hs.hooks = append(hs.hooks, hook) 50} 51 52func (hs hooks) process( 53 ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, 54) error { 55 if len(hs.hooks) == 0 { 56 err := hs.withContext(ctx, func() error { 57 return fn(ctx, cmd) 58 }) 59 cmd.SetErr(err) 60 return err 61 } 62 63 var hookIndex int 64 var retErr error 65 66 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 67 ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) 68 if retErr != nil { 69 cmd.SetErr(retErr) 70 } 71 } 72 73 if retErr == nil { 74 retErr = hs.withContext(ctx, func() error { 75 return fn(ctx, cmd) 76 }) 77 cmd.SetErr(retErr) 78 } 79 80 for hookIndex--; hookIndex >= 0; hookIndex-- { 81 if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { 82 retErr = err 83 cmd.SetErr(retErr) 84 } 85 } 86 87 return retErr 88} 89 90func (hs hooks) processPipeline( 91 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 92) error { 93 if len(hs.hooks) == 0 { 94 err := hs.withContext(ctx, func() error { 95 return fn(ctx, cmds) 96 }) 97 return err 98 } 99 100 var hookIndex int 101 var retErr error 102 103 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 104 ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) 105 if retErr != nil { 106 setCmdsErr(cmds, retErr) 107 } 108 } 109 110 if retErr == nil { 111 retErr = hs.withContext(ctx, func() error { 112 return fn(ctx, cmds) 113 }) 114 } 115 116 for hookIndex--; hookIndex >= 0; hookIndex-- { 117 if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { 118 retErr = err 119 setCmdsErr(cmds, retErr) 120 } 121 } 122 123 return retErr 124} 125 126func (hs hooks) processTxPipeline( 127 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 128) error { 129 cmds = wrapMultiExec(ctx, cmds) 130 return hs.processPipeline(ctx, cmds, fn) 131} 132 133func (hs hooks) withContext(ctx context.Context, fn func() error) error { 134 return fn() 135} 136 137//------------------------------------------------------------------------------ 138 139type baseClient struct { 140 opt *Options 141 connPool pool.Pooler 142 143 onClose func() error // hook called when client is closed 144} 145 146func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { 147 return &baseClient{ 148 opt: opt, 149 connPool: connPool, 150 } 151} 152 153func (c *baseClient) clone() *baseClient { 154 clone := *c 155 return &clone 156} 157 158func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { 159 opt := c.opt.clone() 160 opt.ReadTimeout = timeout 161 opt.WriteTimeout = timeout 162 163 clone := c.clone() 164 clone.opt = opt 165 166 return clone 167} 168 169func (c *baseClient) String() string { 170 return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) 171} 172 173func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { 174 cn, err := c.connPool.NewConn(ctx) 175 if err != nil { 176 return nil, err 177 } 178 179 err = c.initConn(ctx, cn) 180 if err != nil { 181 _ = c.connPool.CloseConn(cn) 182 return nil, err 183 } 184 185 return cn, nil 186} 187 188func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { 189 if c.opt.Limiter != nil { 190 err := c.opt.Limiter.Allow() 191 if err != nil { 192 return nil, err 193 } 194 } 195 196 cn, err := c._getConn(ctx) 197 if err != nil { 198 if c.opt.Limiter != nil { 199 c.opt.Limiter.ReportResult(err) 200 } 201 return nil, err 202 } 203 204 return cn, nil 205} 206 207func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { 208 cn, err := c.connPool.Get(ctx) 209 if err != nil { 210 return nil, err 211 } 212 213 if cn.Inited { 214 return cn, nil 215 } 216 217 err = internal.WithSpan(ctx, "redis.init_conn", func(ctx context.Context, span trace.Span) error { 218 return c.initConn(ctx, cn) 219 }) 220 if err != nil { 221 c.connPool.Remove(ctx, cn, err) 222 if err := errors.Unwrap(err); err != nil { 223 return nil, err 224 } 225 return nil, err 226 } 227 228 return cn, nil 229} 230 231func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { 232 if cn.Inited { 233 return nil 234 } 235 cn.Inited = true 236 237 if c.opt.Password == "" && 238 c.opt.DB == 0 && 239 !c.opt.readOnly && 240 c.opt.OnConnect == nil { 241 return nil 242 } 243 244 connPool := pool.NewSingleConnPool(c.connPool, cn) 245 conn := newConn(ctx, c.opt, connPool) 246 247 _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { 248 if c.opt.Password != "" { 249 if c.opt.Username != "" { 250 pipe.AuthACL(ctx, c.opt.Username, c.opt.Password) 251 } else { 252 pipe.Auth(ctx, c.opt.Password) 253 } 254 } 255 256 if c.opt.DB > 0 { 257 pipe.Select(ctx, c.opt.DB) 258 } 259 260 if c.opt.readOnly { 261 pipe.ReadOnly(ctx) 262 } 263 264 return nil 265 }) 266 if err != nil { 267 return err 268 } 269 270 if c.opt.OnConnect != nil { 271 return c.opt.OnConnect(ctx, conn) 272 } 273 return nil 274} 275 276func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { 277 if c.opt.Limiter != nil { 278 c.opt.Limiter.ReportResult(err) 279 } 280 281 if isBadConn(err, false) { 282 c.connPool.Remove(ctx, cn, err) 283 } else { 284 c.connPool.Put(ctx, cn) 285 } 286} 287 288func (c *baseClient) withConn( 289 ctx context.Context, fn func(context.Context, *pool.Conn) error, 290) error { 291 return internal.WithSpan(ctx, "redis.with_conn", func(ctx context.Context, span trace.Span) error { 292 cn, err := c.getConn(ctx) 293 if err != nil { 294 return err 295 } 296 297 if span.IsRecording() { 298 if remoteAddr := cn.RemoteAddr(); remoteAddr != nil { 299 span.SetAttributes(attribute.String("net.peer.ip", remoteAddr.String())) 300 } 301 } 302 303 defer func() { 304 c.releaseConn(ctx, cn, err) 305 }() 306 307 done := ctx.Done() 308 if done == nil { 309 err = fn(ctx, cn) 310 return err 311 } 312 313 errc := make(chan error, 1) 314 go func() { errc <- fn(ctx, cn) }() 315 316 select { 317 case <-done: 318 _ = cn.Close() 319 // Wait for the goroutine to finish and send something. 320 <-errc 321 322 err = ctx.Err() 323 return err 324 case err = <-errc: 325 return err 326 } 327 }) 328} 329 330func (c *baseClient) process(ctx context.Context, cmd Cmder) error { 331 var lastErr error 332 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 333 attempt := attempt 334 335 var retry bool 336 err := internal.WithSpan(ctx, "redis.process", func(ctx context.Context, span trace.Span) error { 337 if attempt > 0 { 338 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 339 return err 340 } 341 } 342 343 retryTimeout := uint32(1) 344 err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 345 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 346 return writeCmd(wr, cmd) 347 }) 348 if err != nil { 349 return err 350 } 351 352 err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) 353 if err != nil { 354 if cmd.readTimeout() == nil { 355 atomic.StoreUint32(&retryTimeout, 1) 356 } 357 return err 358 } 359 360 return nil 361 }) 362 if err == nil { 363 return nil 364 } 365 retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) 366 return err 367 }) 368 if err == nil || !retry { 369 return err 370 } 371 lastErr = err 372 } 373 return lastErr 374} 375 376func (c *baseClient) retryBackoff(attempt int) time.Duration { 377 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) 378} 379 380func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { 381 if timeout := cmd.readTimeout(); timeout != nil { 382 t := *timeout 383 if t == 0 { 384 return 0 385 } 386 return t + 10*time.Second 387 } 388 return c.opt.ReadTimeout 389} 390 391// Close closes the client, releasing any open resources. 392// 393// It is rare to Close a Client, as the Client is meant to be 394// long-lived and shared between many goroutines. 395func (c *baseClient) Close() error { 396 var firstErr error 397 if c.onClose != nil { 398 if err := c.onClose(); err != nil { 399 firstErr = err 400 } 401 } 402 if err := c.connPool.Close(); err != nil && firstErr == nil { 403 firstErr = err 404 } 405 return firstErr 406} 407 408func (c *baseClient) getAddr() string { 409 return c.opt.Addr 410} 411 412func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { 413 return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) 414} 415 416func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { 417 return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) 418} 419 420type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) 421 422func (c *baseClient) generalProcessPipeline( 423 ctx context.Context, cmds []Cmder, p pipelineProcessor, 424) error { 425 err := c._generalProcessPipeline(ctx, cmds, p) 426 if err != nil { 427 setCmdsErr(cmds, err) 428 return err 429 } 430 return cmdsFirstErr(cmds) 431} 432 433func (c *baseClient) _generalProcessPipeline( 434 ctx context.Context, cmds []Cmder, p pipelineProcessor, 435) error { 436 var lastErr error 437 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 438 if attempt > 0 { 439 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 440 return err 441 } 442 } 443 444 var canRetry bool 445 lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 446 var err error 447 canRetry, err = p(ctx, cn, cmds) 448 return err 449 }) 450 if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { 451 return lastErr 452 } 453 } 454 return lastErr 455} 456 457func (c *baseClient) pipelineProcessCmds( 458 ctx context.Context, cn *pool.Conn, cmds []Cmder, 459) (bool, error) { 460 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 461 return writeCmds(wr, cmds) 462 }) 463 if err != nil { 464 return true, err 465 } 466 467 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 468 return pipelineReadCmds(rd, cmds) 469 }) 470 return true, err 471} 472 473func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { 474 for _, cmd := range cmds { 475 err := cmd.readReply(rd) 476 cmd.SetErr(err) 477 if err != nil && !isRedisError(err) { 478 return err 479 } 480 } 481 return nil 482} 483 484func (c *baseClient) txPipelineProcessCmds( 485 ctx context.Context, cn *pool.Conn, cmds []Cmder, 486) (bool, error) { 487 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 488 return writeCmds(wr, cmds) 489 }) 490 if err != nil { 491 return true, err 492 } 493 494 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 495 statusCmd := cmds[0].(*StatusCmd) 496 // Trim multi and exec. 497 cmds = cmds[1 : len(cmds)-1] 498 499 err := txPipelineReadQueued(rd, statusCmd, cmds) 500 if err != nil { 501 return err 502 } 503 504 return pipelineReadCmds(rd, cmds) 505 }) 506 return false, err 507} 508 509func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { 510 if len(cmds) == 0 { 511 panic("not reached") 512 } 513 cmdCopy := make([]Cmder, len(cmds)+2) 514 cmdCopy[0] = NewStatusCmd(ctx, "multi") 515 copy(cmdCopy[1:], cmds) 516 cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") 517 return cmdCopy 518} 519 520func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { 521 // Parse queued replies. 522 if err := statusCmd.readReply(rd); err != nil { 523 return err 524 } 525 526 for range cmds { 527 if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { 528 return err 529 } 530 } 531 532 // Parse number of replies. 533 line, err := rd.ReadLine() 534 if err != nil { 535 if err == Nil { 536 err = TxFailedErr 537 } 538 return err 539 } 540 541 switch line[0] { 542 case proto.ErrorReply: 543 return proto.ParseErrorReply(line) 544 case proto.ArrayReply: 545 // ok 546 default: 547 err := fmt.Errorf("redis: expected '*', but got line %q", line) 548 return err 549 } 550 551 return nil 552} 553 554//------------------------------------------------------------------------------ 555 556// Client is a Redis client representing a pool of zero or more 557// underlying connections. It's safe for concurrent use by multiple 558// goroutines. 559type Client struct { 560 *baseClient 561 cmdable 562 hooks 563 ctx context.Context 564} 565 566// NewClient returns a client to the Redis Server specified by Options. 567func NewClient(opt *Options) *Client { 568 opt.init() 569 570 c := Client{ 571 baseClient: newBaseClient(opt, newConnPool(opt)), 572 ctx: context.Background(), 573 } 574 c.cmdable = c.Process 575 576 return &c 577} 578 579func (c *Client) clone() *Client { 580 clone := *c 581 clone.cmdable = clone.Process 582 clone.hooks.lock() 583 return &clone 584} 585 586func (c *Client) WithTimeout(timeout time.Duration) *Client { 587 clone := c.clone() 588 clone.baseClient = c.baseClient.withTimeout(timeout) 589 return clone 590} 591 592func (c *Client) Context() context.Context { 593 return c.ctx 594} 595 596func (c *Client) WithContext(ctx context.Context) *Client { 597 if ctx == nil { 598 panic("nil context") 599 } 600 clone := c.clone() 601 clone.ctx = ctx 602 return clone 603} 604 605func (c *Client) Conn(ctx context.Context) *Conn { 606 return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool)) 607} 608 609// Do creates a Cmd from the args and processes the cmd. 610func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { 611 cmd := NewCmd(ctx, args...) 612 _ = c.Process(ctx, cmd) 613 return cmd 614} 615 616func (c *Client) Process(ctx context.Context, cmd Cmder) error { 617 return c.hooks.process(ctx, cmd, c.baseClient.process) 618} 619 620func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { 621 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 622} 623 624func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { 625 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 626} 627 628// Options returns read-only Options that were used to create the client. 629func (c *Client) Options() *Options { 630 return c.opt 631} 632 633type PoolStats pool.Stats 634 635// PoolStats returns connection pool stats. 636func (c *Client) PoolStats() *PoolStats { 637 stats := c.connPool.Stats() 638 return (*PoolStats)(stats) 639} 640 641func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 642 return c.Pipeline().Pipelined(ctx, fn) 643} 644 645func (c *Client) Pipeline() Pipeliner { 646 pipe := Pipeline{ 647 ctx: c.ctx, 648 exec: c.processPipeline, 649 } 650 pipe.init() 651 return &pipe 652} 653 654func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 655 return c.TxPipeline().Pipelined(ctx, fn) 656} 657 658// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 659func (c *Client) TxPipeline() Pipeliner { 660 pipe := Pipeline{ 661 ctx: c.ctx, 662 exec: c.processTxPipeline, 663 } 664 pipe.init() 665 return &pipe 666} 667 668func (c *Client) pubSub() *PubSub { 669 pubsub := &PubSub{ 670 opt: c.opt, 671 672 newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { 673 return c.newConn(ctx) 674 }, 675 closeConn: c.connPool.CloseConn, 676 } 677 pubsub.init() 678 return pubsub 679} 680 681// Subscribe subscribes the client to the specified channels. 682// Channels can be omitted to create empty subscription. 683// Note that this method does not wait on a response from Redis, so the 684// subscription may not be active immediately. To force the connection to wait, 685// you may call the Receive() method on the returned *PubSub like so: 686// 687// sub := client.Subscribe(queryResp) 688// iface, err := sub.Receive() 689// if err != nil { 690// // handle error 691// } 692// 693// // Should be *Subscription, but others are possible if other actions have been 694// // taken on sub since it was created. 695// switch iface.(type) { 696// case *Subscription: 697// // subscribe succeeded 698// case *Message: 699// // received first message 700// case *Pong: 701// // pong received 702// default: 703// // handle error 704// } 705// 706// ch := sub.Channel() 707func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub { 708 pubsub := c.pubSub() 709 if len(channels) > 0 { 710 _ = pubsub.Subscribe(ctx, channels...) 711 } 712 return pubsub 713} 714 715// PSubscribe subscribes the client to the given patterns. 716// Patterns can be omitted to create empty subscription. 717func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { 718 pubsub := c.pubSub() 719 if len(channels) > 0 { 720 _ = pubsub.PSubscribe(ctx, channels...) 721 } 722 return pubsub 723} 724 725//------------------------------------------------------------------------------ 726 727type conn struct { 728 baseClient 729 cmdable 730 statefulCmdable 731 hooks // TODO: inherit hooks 732} 733 734// Conn is like Client, but its pool contains single connection. 735type Conn struct { 736 *conn 737 ctx context.Context 738} 739 740func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { 741 c := Conn{ 742 conn: &conn{ 743 baseClient: baseClient{ 744 opt: opt, 745 connPool: connPool, 746 }, 747 }, 748 ctx: ctx, 749 } 750 c.cmdable = c.Process 751 c.statefulCmdable = c.Process 752 return &c 753} 754 755func (c *Conn) Process(ctx context.Context, cmd Cmder) error { 756 return c.hooks.process(ctx, cmd, c.baseClient.process) 757} 758 759func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error { 760 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 761} 762 763func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error { 764 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 765} 766 767func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 768 return c.Pipeline().Pipelined(ctx, fn) 769} 770 771func (c *Conn) Pipeline() Pipeliner { 772 pipe := Pipeline{ 773 ctx: c.ctx, 774 exec: c.processPipeline, 775 } 776 pipe.init() 777 return &pipe 778} 779 780func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 781 return c.TxPipeline().Pipelined(ctx, fn) 782} 783 784// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 785func (c *Conn) TxPipeline() Pipeliner { 786 pipe := Pipeline{ 787 ctx: c.ctx, 788 exec: c.processTxPipeline, 789 } 790 pipe.init() 791 return &pipe 792} 793