1package redis 2 3import ( 4 "context" 5 "fmt" 6 "time" 7 8 "github.com/go-redis/redis/v8/internal" 9 "github.com/go-redis/redis/v8/internal/pool" 10 "github.com/go-redis/redis/v8/internal/proto" 11 "go.opentelemetry.io/otel/api/trace" 12 "go.opentelemetry.io/otel/label" 13) 14 15// Nil reply returned by Redis when key does not exist. 16const Nil = proto.Nil 17 18func SetLogger(logger internal.Logging) { 19 internal.Logger = logger 20} 21 22//------------------------------------------------------------------------------ 23 24type Hook interface { 25 BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) 26 AfterProcess(ctx context.Context, cmd Cmder) error 27 28 BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) 29 AfterProcessPipeline(ctx context.Context, cmds []Cmder) error 30} 31 32type hooks struct { 33 hooks []Hook 34} 35 36func (hs *hooks) lock() { 37 hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] 38} 39 40func (hs hooks) clone() hooks { 41 clone := hs 42 clone.lock() 43 return clone 44} 45 46func (hs *hooks) AddHook(hook Hook) { 47 hs.hooks = append(hs.hooks, hook) 48} 49 50func (hs hooks) process( 51 ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, 52) error { 53 if len(hs.hooks) == 0 { 54 err := hs.withContext(ctx, func() error { 55 return fn(ctx, cmd) 56 }) 57 cmd.SetErr(err) 58 return err 59 } 60 61 var hookIndex int 62 var retErr error 63 64 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 65 ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) 66 if retErr != nil { 67 cmd.SetErr(retErr) 68 } 69 } 70 71 if retErr == nil { 72 retErr = hs.withContext(ctx, func() error { 73 return fn(ctx, cmd) 74 }) 75 cmd.SetErr(retErr) 76 } 77 78 for hookIndex--; hookIndex >= 0; hookIndex-- { 79 if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { 80 retErr = err 81 cmd.SetErr(retErr) 82 } 83 } 84 85 return retErr 86} 87 88func (hs hooks) processPipeline( 89 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 90) error { 91 if len(hs.hooks) == 0 { 92 err := hs.withContext(ctx, func() error { 93 return fn(ctx, cmds) 94 }) 95 return err 96 } 97 98 var hookIndex int 99 var retErr error 100 101 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { 102 ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) 103 if retErr != nil { 104 setCmdsErr(cmds, retErr) 105 } 106 } 107 108 if retErr == nil { 109 retErr = hs.withContext(ctx, func() error { 110 return fn(ctx, cmds) 111 }) 112 } 113 114 for hookIndex--; hookIndex >= 0; hookIndex-- { 115 if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { 116 retErr = err 117 setCmdsErr(cmds, retErr) 118 } 119 } 120 121 return retErr 122} 123 124func (hs hooks) processTxPipeline( 125 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, 126) error { 127 cmds = wrapMultiExec(ctx, cmds) 128 return hs.processPipeline(ctx, cmds, fn) 129} 130 131func (hs hooks) withContext(ctx context.Context, fn func() error) error { 132 done := ctx.Done() 133 if done == nil { 134 return fn() 135 } 136 137 errc := make(chan error, 1) 138 go func() { errc <- fn() }() 139 140 select { 141 case <-done: 142 return ctx.Err() 143 case err := <-errc: 144 return err 145 } 146} 147 148//------------------------------------------------------------------------------ 149 150type baseClient struct { 151 opt *Options 152 connPool pool.Pooler 153 154 onClose func() error // hook called when client is closed 155} 156 157func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { 158 return &baseClient{ 159 opt: opt, 160 connPool: connPool, 161 } 162} 163 164func (c *baseClient) clone() *baseClient { 165 clone := *c 166 return &clone 167} 168 169func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { 170 opt := c.opt.clone() 171 opt.ReadTimeout = timeout 172 opt.WriteTimeout = timeout 173 174 clone := c.clone() 175 clone.opt = opt 176 177 return clone 178} 179 180func (c *baseClient) String() string { 181 return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) 182} 183 184func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { 185 cn, err := c.connPool.NewConn(ctx) 186 if err != nil { 187 return nil, err 188 } 189 190 err = c.initConn(ctx, cn) 191 if err != nil { 192 _ = c.connPool.CloseConn(cn) 193 return nil, err 194 } 195 196 return cn, nil 197} 198 199func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { 200 if c.opt.Limiter != nil { 201 err := c.opt.Limiter.Allow() 202 if err != nil { 203 return nil, err 204 } 205 } 206 207 cn, err := c._getConn(ctx) 208 if err != nil { 209 if c.opt.Limiter != nil { 210 c.opt.Limiter.ReportResult(err) 211 } 212 return nil, err 213 } 214 215 return cn, nil 216} 217 218func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { 219 cn, err := c.connPool.Get(ctx) 220 if err != nil { 221 return nil, err 222 } 223 224 if cn.Inited { 225 return cn, nil 226 } 227 228 err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context, span trace.Span) error { 229 return c.initConn(ctx, cn) 230 }) 231 if err != nil { 232 c.connPool.Remove(ctx, cn, err) 233 if err := internal.Unwrap(err); err != nil { 234 return nil, err 235 } 236 return nil, err 237 } 238 239 return cn, nil 240} 241 242func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { 243 if cn.Inited { 244 return nil 245 } 246 cn.Inited = true 247 248 if c.opt.Password == "" && 249 c.opt.DB == 0 && 250 !c.opt.readOnly && 251 c.opt.OnConnect == nil { 252 return nil 253 } 254 255 connPool := pool.NewSingleConnPool(c.connPool, cn) 256 conn := newConn(ctx, c.opt, connPool) 257 258 _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { 259 if c.opt.Password != "" { 260 if c.opt.Username != "" { 261 pipe.AuthACL(ctx, c.opt.Username, c.opt.Password) 262 } else { 263 pipe.Auth(ctx, c.opt.Password) 264 } 265 } 266 267 if c.opt.DB > 0 { 268 pipe.Select(ctx, c.opt.DB) 269 } 270 271 if c.opt.readOnly { 272 pipe.ReadOnly(ctx) 273 } 274 275 return nil 276 }) 277 if err != nil { 278 return err 279 } 280 281 if c.opt.OnConnect != nil { 282 return c.opt.OnConnect(ctx, conn) 283 } 284 return nil 285} 286 287func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { 288 if c.opt.Limiter != nil { 289 c.opt.Limiter.ReportResult(err) 290 } 291 292 if isBadConn(err, false) { 293 c.connPool.Remove(ctx, cn, err) 294 } else { 295 c.connPool.Put(ctx, cn) 296 } 297} 298 299func (c *baseClient) withConn( 300 ctx context.Context, fn func(context.Context, *pool.Conn) error, 301) error { 302 return internal.WithSpan(ctx, "with_conn", func(ctx context.Context, span trace.Span) error { 303 cn, err := c.getConn(ctx) 304 if err != nil { 305 return err 306 } 307 308 if span.IsRecording() { 309 if remoteAddr := cn.RemoteAddr(); remoteAddr != nil { 310 span.SetAttributes(label.String("net.peer.ip", remoteAddr.String())) 311 } 312 } 313 314 defer func() { 315 c.releaseConn(ctx, cn, err) 316 }() 317 318 err = fn(ctx, cn) 319 return err 320 }) 321} 322 323func (c *baseClient) process(ctx context.Context, cmd Cmder) error { 324 var lastErr error 325 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 326 attempt := attempt 327 328 var retry bool 329 err := internal.WithSpan(ctx, "process", func(ctx context.Context, span trace.Span) error { 330 if attempt > 0 { 331 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 332 return err 333 } 334 } 335 336 retryTimeout := true 337 err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 338 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 339 return writeCmd(wr, cmd) 340 }) 341 if err != nil { 342 return err 343 } 344 345 err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) 346 if err != nil { 347 retryTimeout = cmd.readTimeout() == nil 348 return err 349 } 350 351 return nil 352 }) 353 if err == nil { 354 return nil 355 } 356 retry = shouldRetry(err, retryTimeout) 357 return err 358 }) 359 if err == nil || !retry { 360 return err 361 } 362 lastErr = err 363 } 364 return lastErr 365} 366 367func (c *baseClient) retryBackoff(attempt int) time.Duration { 368 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) 369} 370 371func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { 372 if timeout := cmd.readTimeout(); timeout != nil { 373 t := *timeout 374 if t == 0 { 375 return 0 376 } 377 return t + 10*time.Second 378 } 379 return c.opt.ReadTimeout 380} 381 382// Close closes the client, releasing any open resources. 383// 384// It is rare to Close a Client, as the Client is meant to be 385// long-lived and shared between many goroutines. 386func (c *baseClient) Close() error { 387 var firstErr error 388 if c.onClose != nil { 389 if err := c.onClose(); err != nil { 390 firstErr = err 391 } 392 } 393 if err := c.connPool.Close(); err != nil && firstErr == nil { 394 firstErr = err 395 } 396 return firstErr 397} 398 399func (c *baseClient) getAddr() string { 400 return c.opt.Addr 401} 402 403func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { 404 return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) 405} 406 407func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { 408 return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) 409} 410 411type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) 412 413func (c *baseClient) generalProcessPipeline( 414 ctx context.Context, cmds []Cmder, p pipelineProcessor, 415) error { 416 err := c._generalProcessPipeline(ctx, cmds, p) 417 if err != nil { 418 setCmdsErr(cmds, err) 419 return err 420 } 421 return cmdsFirstErr(cmds) 422} 423 424func (c *baseClient) _generalProcessPipeline( 425 ctx context.Context, cmds []Cmder, p pipelineProcessor, 426) error { 427 var lastErr error 428 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { 429 if attempt > 0 { 430 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { 431 return err 432 } 433 } 434 435 var canRetry bool 436 lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { 437 var err error 438 canRetry, err = p(ctx, cn, cmds) 439 return err 440 }) 441 if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { 442 return lastErr 443 } 444 } 445 return lastErr 446} 447 448func (c *baseClient) pipelineProcessCmds( 449 ctx context.Context, cn *pool.Conn, cmds []Cmder, 450) (bool, error) { 451 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 452 return writeCmds(wr, cmds) 453 }) 454 if err != nil { 455 return true, err 456 } 457 458 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 459 return pipelineReadCmds(rd, cmds) 460 }) 461 return true, err 462} 463 464func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { 465 for _, cmd := range cmds { 466 err := cmd.readReply(rd) 467 cmd.SetErr(err) 468 if err != nil && !isRedisError(err) { 469 return err 470 } 471 } 472 return nil 473} 474 475func (c *baseClient) txPipelineProcessCmds( 476 ctx context.Context, cn *pool.Conn, cmds []Cmder, 477) (bool, error) { 478 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { 479 return writeCmds(wr, cmds) 480 }) 481 if err != nil { 482 return true, err 483 } 484 485 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { 486 statusCmd := cmds[0].(*StatusCmd) 487 // Trim multi and exec. 488 cmds = cmds[1 : len(cmds)-1] 489 490 err := txPipelineReadQueued(rd, statusCmd, cmds) 491 if err != nil { 492 return err 493 } 494 495 return pipelineReadCmds(rd, cmds) 496 }) 497 return false, err 498} 499 500func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { 501 if len(cmds) == 0 { 502 panic("not reached") 503 } 504 cmdCopy := make([]Cmder, len(cmds)+2) 505 cmdCopy[0] = NewStatusCmd(ctx, "multi") 506 copy(cmdCopy[1:], cmds) 507 cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") 508 return cmdCopy 509} 510 511func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { 512 // Parse queued replies. 513 if err := statusCmd.readReply(rd); err != nil { 514 return err 515 } 516 517 for range cmds { 518 if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { 519 return err 520 } 521 } 522 523 // Parse number of replies. 524 line, err := rd.ReadLine() 525 if err != nil { 526 if err == Nil { 527 err = TxFailedErr 528 } 529 return err 530 } 531 532 switch line[0] { 533 case proto.ErrorReply: 534 return proto.ParseErrorReply(line) 535 case proto.ArrayReply: 536 // ok 537 default: 538 err := fmt.Errorf("redis: expected '*', but got line %q", line) 539 return err 540 } 541 542 return nil 543} 544 545//------------------------------------------------------------------------------ 546 547// Client is a Redis client representing a pool of zero or more 548// underlying connections. It's safe for concurrent use by multiple 549// goroutines. 550type Client struct { 551 *baseClient 552 cmdable 553 hooks 554 ctx context.Context 555} 556 557// NewClient returns a client to the Redis Server specified by Options. 558func NewClient(opt *Options) *Client { 559 opt.init() 560 561 c := Client{ 562 baseClient: newBaseClient(opt, newConnPool(opt)), 563 ctx: context.Background(), 564 } 565 c.cmdable = c.Process 566 567 return &c 568} 569 570func (c *Client) clone() *Client { 571 clone := *c 572 clone.cmdable = clone.Process 573 clone.hooks.lock() 574 return &clone 575} 576 577func (c *Client) WithTimeout(timeout time.Duration) *Client { 578 clone := c.clone() 579 clone.baseClient = c.baseClient.withTimeout(timeout) 580 return clone 581} 582 583func (c *Client) Context() context.Context { 584 return c.ctx 585} 586 587func (c *Client) WithContext(ctx context.Context) *Client { 588 if ctx == nil { 589 panic("nil context") 590 } 591 clone := c.clone() 592 clone.ctx = ctx 593 return clone 594} 595 596func (c *Client) Conn(ctx context.Context) *Conn { 597 return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool)) 598} 599 600// Do creates a Cmd from the args and processes the cmd. 601func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { 602 cmd := NewCmd(ctx, args...) 603 _ = c.Process(ctx, cmd) 604 return cmd 605} 606 607func (c *Client) Process(ctx context.Context, cmd Cmder) error { 608 return c.hooks.process(ctx, cmd, c.baseClient.process) 609} 610 611func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { 612 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) 613} 614 615func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { 616 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) 617} 618 619// Options returns read-only Options that were used to create the client. 620func (c *Client) Options() *Options { 621 return c.opt 622} 623 624type PoolStats pool.Stats 625 626// PoolStats returns connection pool stats. 627func (c *Client) PoolStats() *PoolStats { 628 stats := c.connPool.Stats() 629 return (*PoolStats)(stats) 630} 631 632func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 633 return c.Pipeline().Pipelined(ctx, fn) 634} 635 636func (c *Client) Pipeline() Pipeliner { 637 pipe := Pipeline{ 638 ctx: c.ctx, 639 exec: c.processPipeline, 640 } 641 pipe.init() 642 return &pipe 643} 644 645func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 646 return c.TxPipeline().Pipelined(ctx, fn) 647} 648 649// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 650func (c *Client) TxPipeline() Pipeliner { 651 pipe := Pipeline{ 652 ctx: c.ctx, 653 exec: c.processTxPipeline, 654 } 655 pipe.init() 656 return &pipe 657} 658 659func (c *Client) pubSub() *PubSub { 660 pubsub := &PubSub{ 661 opt: c.opt, 662 663 newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { 664 return c.newConn(ctx) 665 }, 666 closeConn: c.connPool.CloseConn, 667 } 668 pubsub.init() 669 return pubsub 670} 671 672// Subscribe subscribes the client to the specified channels. 673// Channels can be omitted to create empty subscription. 674// Note that this method does not wait on a response from Redis, so the 675// subscription may not be active immediately. To force the connection to wait, 676// you may call the Receive() method on the returned *PubSub like so: 677// 678// sub := client.Subscribe(queryResp) 679// iface, err := sub.Receive() 680// if err != nil { 681// // handle error 682// } 683// 684// // Should be *Subscription, but others are possible if other actions have been 685// // taken on sub since it was created. 686// switch iface.(type) { 687// case *Subscription: 688// // subscribe succeeded 689// case *Message: 690// // received first message 691// case *Pong: 692// // pong received 693// default: 694// // handle error 695// } 696// 697// ch := sub.Channel() 698func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub { 699 pubsub := c.pubSub() 700 if len(channels) > 0 { 701 _ = pubsub.Subscribe(ctx, channels...) 702 } 703 return pubsub 704} 705 706// PSubscribe subscribes the client to the given patterns. 707// Patterns can be omitted to create empty subscription. 708func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { 709 pubsub := c.pubSub() 710 if len(channels) > 0 { 711 _ = pubsub.PSubscribe(ctx, channels...) 712 } 713 return pubsub 714} 715 716//------------------------------------------------------------------------------ 717 718type conn struct { 719 baseClient 720 cmdable 721 statefulCmdable 722} 723 724// Conn is like Client, but its pool contains single connection. 725type Conn struct { 726 *conn 727 ctx context.Context 728} 729 730func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { 731 c := Conn{ 732 conn: &conn{ 733 baseClient: baseClient{ 734 opt: opt, 735 connPool: connPool, 736 }, 737 }, 738 ctx: ctx, 739 } 740 c.cmdable = c.Process 741 c.statefulCmdable = c.Process 742 return &c 743} 744 745func (c *Conn) Process(ctx context.Context, cmd Cmder) error { 746 return c.baseClient.process(ctx, cmd) 747} 748 749func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 750 return c.Pipeline().Pipelined(ctx, fn) 751} 752 753func (c *Conn) Pipeline() Pipeliner { 754 pipe := Pipeline{ 755 ctx: c.ctx, 756 exec: c.processPipeline, 757 } 758 pipe.init() 759 return &pipe 760} 761 762func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { 763 return c.TxPipeline().Pipelined(ctx, fn) 764} 765 766// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. 767func (c *Conn) TxPipeline() Pipeliner { 768 pipe := Pipeline{ 769 ctx: c.ctx, 770 exec: c.processTxPipeline, 771 } 772 pipe.init() 773 return &pipe 774} 775