1package redis 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "sync/atomic" 8 "time" 9 10 "github.com/go-redis/redis/v8/internal" 11 "github.com/go-redis/redis/v8/internal/pool" 12 "github.com/go-redis/redis/v8/internal/proto" 13) 14 15// Nil reply returned by Redis when key does not exist. 16const Nil = proto.Nil 17 18func SetLogger(logger internal.Logging) { 19 internal.Logger = logger 20} 21 22//------------------------------------------------------------------------------ 23 24type Hook interface { 25 BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) 26 AfterProcess(ctx context.Context, cmd Cmder) error 27 28 BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) 29 AfterProcessPipeline(ctx context.Context, cmds []Cmder) error 30} 31 32type hooks struct { 33 hooks []Hook 34} 35 36func (hs *hooks) lock() { 37 hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] 38} 39 40func (hs hooks) clone() hooks { 41 clone := hs 42 clone.lock() 43 return clone 44} 45 46func (hs *hooks) AddHook(hook Hook) { 47 hs.hooks = append(hs.hooks, hook) 48} 49 50func (hs hooks) process( 51 ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, 52) error { 53 if len(hs.hooks) == 0 { 54 err := hs.withContext(ctx, func() error { 55 return fn(ctx, cmd) 56 }) 57 cmd.SetErr(err) 58 return err 59 } 60 61 var hookIndex int 62 var retErr error 63 64 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 65 ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) 66 if retErr != nil { 67 cmd.SetErr(retErr) 68 } 69 } 70 71 if retErr == nil { 72 retErr = hs.withContext(ctx, func() error { 73 return fn(ctx, cmd) 74 }) 75 cmd.SetErr(retErr) 76 } 77 78 for hookIndex--; hookIndex >= 0; hookIndex-- { 79 if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { 80 retErr = err 81 cmd.SetErr(retErr) 82 } 83 } 84 85 return retErr 86} 87 88func (hs hooks) processPipeline( 89 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 90) error { 91 if len(hs.hooks) == 0 { 92 err := hs.withContext(ctx, func() error { 93 return fn(ctx, cmds) 94 }) 95 return err 96 } 97 98 var hookIndex int 99 var retErr error 100 101 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 102 ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) 103 if retErr != nil { 104 setCmdsErr(cmds, retErr) 105 } 106 } 107 108 if retErr == nil { 109 retErr = hs.withContext(ctx, func() error { 110 return fn(ctx, cmds) 111 }) 112 } 113 114 for hookIndex--; hookIndex >= 0; hookIndex-- { 115 if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { 116 retErr = err 117 setCmdsErr(cmds, retErr) 118 } 119 } 120 121 return retErr 122} 123 124func (hs hooks) processTxPipeline( 125 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 126) error { 127 cmds = wrapMultiExec(ctx, cmds) 128 return hs.processPipeline(ctx, cmds, fn) 129} 130 131func (hs hooks) withContext(ctx context.Context, fn func() error) error { 132 return fn() 133} 134 135//------------------------------------------------------------------------------ 136 137type baseClient struct { 138 opt *Options 139 connPool pool.Pooler 140 141 onClose func() error // hook called when client is closed 142} 143 144func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { 145 return &baseClient{ 146 opt: opt, 147 connPool: connPool, 148 } 149} 150 151func (c *baseClient) clone() *baseClient { 152 clone := *c 153 return &clone 154} 155 156func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { 157 opt := c.opt.clone() 158 opt.ReadTimeout = timeout 159 opt.WriteTimeout = timeout 160 161 clone := c.clone() 162 clone.opt = opt 163 164 return clone 165} 166 167func (c *baseClient) String() string { 168 return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) 169} 170 171func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { 172 cn, err := c.connPool.NewConn(ctx) 173 if err != nil { 174 return nil, err 175 } 176 177 err = c.initConn(ctx, cn) 178 if err != nil { 179 _ = c.connPool.CloseConn(cn) 180 return nil, err 181 } 182 183 return cn, nil 184} 185 186func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { 187 if c.opt.Limiter != nil { 188 err := c.opt.Limiter.Allow() 189 if err != nil { 190 return nil, err 191 } 192 } 193 194 cn, err := c._getConn(ctx) 195 if err != nil { 196 if c.opt.Limiter != nil { 197 c.opt.Limiter.ReportResult(err) 198 } 199 return nil, err 200 } 201 202 return cn, nil 203} 204 205func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { 206 cn, err := c.connPool.Get(ctx) 207 if err != nil { 208 return nil, err 209 } 210 211 if cn.Inited { 212 return cn, nil 213 } 214 215 if err := c.initConn(ctx, cn); err != nil { 216 c.connPool.Remove(ctx, cn, err) 217 if err := errors.Unwrap(err); err != nil { 218 return nil, err 219 } 220 return nil, err 221 } 222 223 return cn, nil 224} 225 226func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { 227 if cn.Inited { 228 return nil 229 } 230 cn.Inited = true 231 232 if c.opt.Password == "" && 233 c.opt.DB == 0 && 234 !c.opt.readOnly && 235 c.opt.OnConnect == nil { 236 return nil 237 } 238 239 connPool := pool.NewSingleConnPool(c.connPool, cn) 240 conn := newConn(ctx, c.opt, connPool) 241 242 _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { 243 if c.opt.Password != "" { 244 if c.opt.Username != "" { 245 pipe.AuthACL(ctx, c.opt.Username, c.opt.Password) 246 } else { 247 pipe.Auth(ctx, c.opt.Password) 248 } 249 } 250 251 if c.opt.DB > 0 { 252 pipe.Select(ctx, c.opt.DB) 253 } 254 255 if c.opt.readOnly { 256 pipe.ReadOnly(ctx) 257 } 258 259 return nil 260 }) 261 if err != nil { 262 return err 263 } 264 265 if c.opt.OnConnect != nil { 266 return c.opt.OnConnect(ctx, conn) 267 } 268 return nil 269} 270 271func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { 272 if c.opt.Limiter != nil { 273 c.opt.Limiter.ReportResult(err) 274 } 275 276 if isBadConn(err, false) { 277 c.connPool.Remove(ctx, cn, err) 278 } else { 279 c.connPool.Put(ctx, cn) 280 } 281} 282 283func (c *baseClient) withConn( 284 ctx context.Context, fn func(context.Context, *pool.Conn) error, 285) error { 286 cn, err := c.getConn(ctx) 287 if err != nil { 288 return err 289 } 290 291 defer func() { 292 c.releaseConn(ctx, cn, err) 293 }() 294 295 done := ctx.Done() //nolint:ifshort 296 297 if done == nil { 298 err = fn(ctx, cn) 299 return err 300 } 301 302 errc := make(chan error, 1) 303 go func() { errc <- fn(ctx, cn) }() 304 305 select { 306 case <-done: 307 _ = cn.Close() 308 // Wait for the goroutine to finish and send something. 309 <-errc 310 311 err = ctx.Err() 312 return err 313 case err = <-errc: 314 return err 315 } 316} 317 318func (c *baseClient) process(ctx context.Context, cmd Cmder) error { 319 var lastErr error 320 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 321 attempt := attempt 322 323 retry, err := c._process(ctx, cmd, attempt) 324 if err == nil || !retry { 325 return err 326 } 327 328 lastErr = err 329 } 330 return lastErr 331} 332 333func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) { 334 if attempt > 0 { 335 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 336 return false, err 337 } 338 } 339 340 retryTimeout := uint32(1) 341 err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 342 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 343 return writeCmd(wr, cmd) 344 }) 345 if err != nil { 346 return err 347 } 348 349 err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) 350 if err != nil { 351 if cmd.readTimeout() == nil { 352 atomic.StoreUint32(&retryTimeout, 1) 353 } 354 return err 355 } 356 357 return nil 358 }) 359 if err == nil { 360 return false, nil 361 } 362 363 retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) 364 return retry, err 365} 366 367func (c *baseClient) retryBackoff(attempt int) time.Duration { 368 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) 369} 370 371func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { 372 if timeout := cmd.readTimeout(); timeout != nil { 373 t := *timeout 374 if t == 0 { 375 return 0 376 } 377 return t + 10*time.Second 378 } 379 return c.opt.ReadTimeout 380} 381 382// Close closes the client, releasing any open resources. 383// 384// It is rare to Close a Client, as the Client is meant to be 385// long-lived and shared between many goroutines. 386func (c *baseClient) Close() error { 387 var firstErr error 388 if c.onClose != nil { 389 if err := c.onClose(); err != nil { 390 firstErr = err 391 } 392 } 393 if err := c.connPool.Close(); err != nil && firstErr == nil { 394 firstErr = err 395 } 396 return firstErr 397} 398 399func (c *baseClient) getAddr() string { 400 return c.opt.Addr 401} 402 403func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { 404 return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) 405} 406 407func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { 408 return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) 409} 410 411type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) 412 413func (c *baseClient) generalProcessPipeline( 414 ctx context.Context, cmds []Cmder, p pipelineProcessor, 415) error { 416 err := c._generalProcessPipeline(ctx, cmds, p) 417 if err != nil { 418 setCmdsErr(cmds, err) 419 return err 420 } 421 return cmdsFirstErr(cmds) 422} 423 424func (c *baseClient) _generalProcessPipeline( 425 ctx context.Context, cmds []Cmder, p pipelineProcessor, 426) error { 427 var lastErr error 428 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 429 if attempt > 0 { 430 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 431 return err 432 } 433 } 434 435 var canRetry bool 436 lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 437 var err error 438 canRetry, err = p(ctx, cn, cmds) 439 return err 440 }) 441 if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { 442 return lastErr 443 } 444 } 445 return lastErr 446} 447 448func (c *baseClient) pipelineProcessCmds( 449 ctx context.Context, cn *pool.Conn, cmds []Cmder, 450) (bool, error) { 451 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 452 return writeCmds(wr, cmds) 453 }) 454 if err != nil { 455 return true, err 456 } 457 458 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 459 return pipelineReadCmds(rd, cmds) 460 }) 461 return true, err 462} 463 464func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { 465 for _, cmd := range cmds { 466 err := cmd.readReply(rd) 467 cmd.SetErr(err) 468 if err != nil && !isRedisError(err) { 469 return err 470 } 471 } 472 return nil 473} 474 475func (c *baseClient) txPipelineProcessCmds( 476 ctx context.Context, cn *pool.Conn, cmds []Cmder, 477) (bool, error) { 478 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 479 return writeCmds(wr, cmds) 480 }) 481 if err != nil { 482 return true, err 483 } 484 485 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 486 statusCmd := cmds[0].(*StatusCmd) 487 // Trim multi and exec. 488 cmds = cmds[1 : len(cmds)-1] 489 490 err := txPipelineReadQueued(rd, statusCmd, cmds) 491 if err != nil { 492 return err 493 } 494 495 return pipelineReadCmds(rd, cmds) 496 }) 497 return false, err 498} 499 500func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { 501 if len(cmds) == 0 { 502 panic("not reached") 503 } 504 cmdCopy := make([]Cmder, len(cmds)+2) 505 cmdCopy[0] = NewStatusCmd(ctx, "multi") 506 copy(cmdCopy[1:], cmds) 507 cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") 508 return cmdCopy 509} 510 511func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { 512 // Parse queued replies. 513 if err := statusCmd.readReply(rd); err != nil { 514 return err 515 } 516 517 for range cmds { 518 if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { 519 return err 520 } 521 } 522 523 // Parse number of replies. 524 line, err := rd.ReadLine() 525 if err != nil { 526 if err == Nil { 527 err = TxFailedErr 528 } 529 return err 530 } 531 532 switch line[0] { 533 case proto.ErrorReply: 534 return proto.ParseErrorReply(line) 535 case proto.ArrayReply: 536 // ok 537 default: 538 err := fmt.Errorf("redis: expected '*', but got line %q", line) 539 return err 540 } 541 542 return nil 543} 544 545//------------------------------------------------------------------------------ 546 547// Client is a Redis client representing a pool of zero or more 548// underlying connections. It's safe for concurrent use by multiple 549// goroutines. 550type Client struct { 551 *baseClient 552 cmdable 553 hooks 554 ctx context.Context 555} 556 557// NewClient returns a client to the Redis Server specified by Options. 558func NewClient(opt *Options) *Client { 559 opt.init() 560 561 c := Client{ 562 baseClient: newBaseClient(opt, newConnPool(opt)), 563 ctx: context.Background(), 564 } 565 c.cmdable = c.Process 566 567 return &c 568} 569 570func (c *Client) clone() *Client { 571 clone := *c 572 clone.cmdable = clone.Process 573 clone.hooks.lock() 574 return &clone 575} 576 577func (c *Client) WithTimeout(timeout time.Duration) *Client { 578 clone := c.clone() 579 clone.baseClient = c.baseClient.withTimeout(timeout) 580 return clone 581} 582 583func (c *Client) Context() context.Context { 584 return c.ctx 585} 586 587func (c *Client) WithContext(ctx context.Context) *Client { 588 if ctx == nil { 589 panic("nil context") 590 } 591 clone := c.clone() 592 clone.ctx = ctx 593 return clone 594} 595 596func (c *Client) Conn(ctx context.Context) *Conn { 597 return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool)) 598} 599 600// Do creates a Cmd from the args and processes the cmd. 601func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { 602 cmd := NewCmd(ctx, args...) 603 _ = c.Process(ctx, cmd) 604 return cmd 605} 606 607func (c *Client) Process(ctx context.Context, cmd Cmder) error { 608 return c.hooks.process(ctx, cmd, c.baseClient.process) 609} 610 611func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { 612 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 613} 614 615func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { 616 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 617} 618 619// Options returns read-only Options that were used to create the client. 620func (c *Client) Options() *Options { 621 return c.opt 622} 623 624type PoolStats pool.Stats 625 626// PoolStats returns connection pool stats. 627func (c *Client) PoolStats() *PoolStats { 628 stats := c.connPool.Stats() 629 return (*PoolStats)(stats) 630} 631 632func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 633 return c.Pipeline().Pipelined(ctx, fn) 634} 635 636func (c *Client) Pipeline() Pipeliner { 637 pipe := Pipeline{ 638 ctx: c.ctx, 639 exec: c.processPipeline, 640 } 641 pipe.init() 642 return &pipe 643} 644 645func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 646 return c.TxPipeline().Pipelined(ctx, fn) 647} 648 649// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 650func (c *Client) TxPipeline() Pipeliner { 651 pipe := Pipeline{ 652 ctx: c.ctx, 653 exec: c.processTxPipeline, 654 } 655 pipe.init() 656 return &pipe 657} 658 659func (c *Client) pubSub() *PubSub { 660 pubsub := &PubSub{ 661 opt: c.opt, 662 663 newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { 664 return c.newConn(ctx) 665 }, 666 closeConn: c.connPool.CloseConn, 667 } 668 pubsub.init() 669 return pubsub 670} 671 672// Subscribe subscribes the client to the specified channels. 673// Channels can be omitted to create empty subscription. 674// Note that this method does not wait on a response from Redis, so the 675// subscription may not be active immediately. To force the connection to wait, 676// you may call the Receive() method on the returned *PubSub like so: 677// 678// sub := client.Subscribe(queryResp) 679// iface, err := sub.Receive() 680// if err != nil { 681// // handle error 682// } 683// 684// // Should be *Subscription, but others are possible if other actions have been 685// // taken on sub since it was created. 686// switch iface.(type) { 687// case *Subscription: 688// // subscribe succeeded 689// case *Message: 690// // received first message 691// case *Pong: 692// // pong received 693// default: 694// // handle error 695// } 696// 697// ch := sub.Channel() 698func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub { 699 pubsub := c.pubSub() 700 if len(channels) > 0 { 701 _ = pubsub.Subscribe(ctx, channels...) 702 } 703 return pubsub 704} 705 706// PSubscribe subscribes the client to the given patterns. 707// Patterns can be omitted to create empty subscription. 708func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { 709 pubsub := c.pubSub() 710 if len(channels) > 0 { 711 _ = pubsub.PSubscribe(ctx, channels...) 712 } 713 return pubsub 714} 715 716//------------------------------------------------------------------------------ 717 718type conn struct { 719 baseClient 720 cmdable 721 statefulCmdable 722 hooks // TODO: inherit hooks 723} 724 725// Conn is like Client, but its pool contains single connection. 726type Conn struct { 727 *conn 728 ctx context.Context 729} 730 731func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { 732 c := Conn{ 733 conn: &conn{ 734 baseClient: baseClient{ 735 opt: opt, 736 connPool: connPool, 737 }, 738 }, 739 ctx: ctx, 740 } 741 c.cmdable = c.Process 742 c.statefulCmdable = c.Process 743 return &c 744} 745 746func (c *Conn) Process(ctx context.Context, cmd Cmder) error { 747 return c.hooks.process(ctx, cmd, c.baseClient.process) 748} 749 750func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error { 751 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 752} 753 754func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error { 755 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 756} 757 758func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 759 return c.Pipeline().Pipelined(ctx, fn) 760} 761 762func (c *Conn) Pipeline() Pipeliner { 763 pipe := Pipeline{ 764 ctx: c.ctx, 765 exec: c.processPipeline, 766 } 767 pipe.init() 768 return &pipe 769} 770 771func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 772 return c.TxPipeline().Pipelined(ctx, fn) 773} 774 775// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 776func (c *Conn) TxPipeline() Pipeliner { 777 pipe := Pipeline{ 778 ctx: c.ctx, 779 exec: c.processTxPipeline, 780 } 781 pipe.init() 782 return &pipe 783} 784