package redis

import (
	"context"
	"crypto/tls"
	"fmt"
	"math"
	"math/rand"
	"net"
	"runtime"
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-redis/redis/v7/internal"
	"github.com/go-redis/redis/v7/internal/hashtag"
	"github.com/go-redis/redis/v7/internal/pool"
	"github.com/go-redis/redis/v7/internal/proto"
)

var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes")

// ClusterOptions are used to configure a cluster client and should be
// passed to NewClusterClient.
type ClusterOptions struct {
	// A seed list of host:port addresses of cluster nodes.
	Addrs []string

	// The maximum number of retries before giving up. Commands are retried
	// on network errors and MOVED/ASK redirects.
	// Default is 8 retries.
	MaxRedirects int

	// Enables read-only commands on slave nodes.
	ReadOnly bool
	// Allows routing read-only commands to the closest master or slave node.
	// It automatically enables ReadOnly.
	RouteByLatency bool
	// Allows routing read-only commands to a random master or slave node.
	// It automatically enables ReadOnly.
	RouteRandomly bool

	// Optional function that returns cluster slots information.
	// It is useful for manually creating a cluster of standalone Redis servers
	// and load-balancing read/write operations between masters and slaves.
	// It can use a service like ZooKeeper to maintain configuration information
	// and ClusterClient.ReloadState to manually trigger state reloading.
	// See the example sketch after this struct definition.
	ClusterSlots func() ([]ClusterSlot, error)

	// Optional hook that is called when a new node is created.
	OnNewNode func(*Client)

	// The following options are copied from the Options struct.

	Dialer func(ctx context.Context, network, addr string) (net.Conn, error)

	OnConnect func(*Conn) error

	Password string

	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	DialTimeout  time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration

	// PoolSize applies per cluster node and not for the whole cluster.
	PoolSize           int
	MinIdleConns       int
	MaxConnAge         time.Duration
	PoolTimeout        time.Duration
	IdleTimeout        time.Duration
	IdleCheckFrequency time.Duration

	TLSConfig *tls.Config
}
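
// When ClusterSlots is set, the returned slot ranges drive routing instead of
// the CLUSTER SLOTS command. A minimal sketch for a manually managed two-shard
// setup (the addresses are illustrative); each range lists the master first,
// followed by its slaves:
//
//	clusterSlots := func() ([]ClusterSlot, error) {
//		return []ClusterSlot{{
//			Start: 0,
//			End:   8191,
//			Nodes: []ClusterNode{
//				{Addr: ":7000"}, // master
//				{Addr: ":8000"}, // slave
//			},
//		}, {
//			Start: 8192,
//			End:   16383,
//			Nodes: []ClusterNode{
//				{Addr: ":7001"}, // master
//				{Addr: ":8001"}, // slave
//			},
//		}}, nil
//	}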

func (opt *ClusterOptions) init() {
	if opt.MaxRedirects == -1 {
		opt.MaxRedirects = 0
	} else if opt.MaxRedirects == 0 {
		opt.MaxRedirects = 8
	}

	if (opt.RouteByLatency || opt.RouteRandomly) && opt.ClusterSlots == nil {
		opt.ReadOnly = true
	}

	if opt.PoolSize == 0 {
		opt.PoolSize = 5 * runtime.NumCPU()
	}

	switch opt.ReadTimeout {
	case -1:
		opt.ReadTimeout = 0
	case 0:
		opt.ReadTimeout = 3 * time.Second
	}
	switch opt.WriteTimeout {
	case -1:
		opt.WriteTimeout = 0
	case 0:
		opt.WriteTimeout = opt.ReadTimeout
	}

	switch opt.MinRetryBackoff {
	case -1:
		opt.MinRetryBackoff = 0
	case 0:
		opt.MinRetryBackoff = 8 * time.Millisecond
	}
	switch opt.MaxRetryBackoff {
	case -1:
		opt.MaxRetryBackoff = 0
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}
}

func (opt *ClusterOptions) clientOptions() *Options {
	const disableIdleCheck = -1

	return &Options{
		Dialer:    opt.Dialer,
		OnConnect: opt.OnConnect,

		MaxRetries:      opt.MaxRetries,
		MinRetryBackoff: opt.MinRetryBackoff,
		MaxRetryBackoff: opt.MaxRetryBackoff,
		Password:        opt.Password,
		readOnly:        opt.ReadOnly,

		DialTimeout:  opt.DialTimeout,
		ReadTimeout:  opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,

		PoolSize:           opt.PoolSize,
		MinIdleConns:       opt.MinIdleConns,
		MaxConnAge:         opt.MaxConnAge,
		PoolTimeout:        opt.PoolTimeout,
		IdleTimeout:        opt.IdleTimeout,
		IdleCheckFrequency: disableIdleCheck,

		TLSConfig: opt.TLSConfig,
	}
}

//------------------------------------------------------------------------------

type clusterNode struct {
	Client *Client

	latency    uint32 // atomic
	generation uint32 // atomic
	failing    uint32 // atomic
}

func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode {
	opt := clOpt.clientOptions()
	opt.Addr = addr
	node := clusterNode{
		Client: NewClient(opt),
	}

	node.latency = math.MaxUint32
	if clOpt.RouteByLatency {
		go node.updateLatency()
	}

	if clOpt.OnNewNode != nil {
		clOpt.OnNewNode(node.Client)
	}

	return &node
}

func (n *clusterNode) String() string {
	return n.Client.String()
}

func (n *clusterNode) Close() error {
	return n.Client.Close()
}

func (n *clusterNode) updateLatency() {
	const probes = 10

	var latency uint32
	for i := 0; i < probes; i++ {
		start := time.Now()
		n.Client.Ping()
		probe := uint32(time.Since(start) / time.Microsecond)
		latency = (latency + probe) / 2
	}
	atomic.StoreUint32(&n.latency, latency)
}

func (n *clusterNode) Latency() time.Duration {
	latency := atomic.LoadUint32(&n.latency)
	return time.Duration(latency) * time.Microsecond
}

func (n *clusterNode) MarkAsFailing() {
	atomic.StoreUint32(&n.failing, uint32(time.Now().Unix()))
}

func (n *clusterNode) Failing() bool {
	const timeout = 15 // 15 seconds

	failing := atomic.LoadUint32(&n.failing)
	if failing == 0 {
		return false
	}
	if time.Now().Unix()-int64(failing) < timeout {
		return true
	}
	atomic.StoreUint32(&n.failing, 0)
	return false
}

func (n *clusterNode) Generation() uint32 {
	return atomic.LoadUint32(&n.generation)
}

func (n *clusterNode) SetGeneration(gen uint32) {
	for {
		v := atomic.LoadUint32(&n.generation)
		if gen < v || atomic.CompareAndSwapUint32(&n.generation, v, gen) {
			break
		}
	}
}

//------------------------------------------------------------------------------

type clusterNodes struct {
	opt *ClusterOptions

	mu           sync.RWMutex
	allAddrs     []string
	allNodes     map[string]*clusterNode
	clusterAddrs []string
	closed       bool

	_generation uint32 // atomic
}

func newClusterNodes(opt *ClusterOptions) *clusterNodes {
	return &clusterNodes{
		opt: opt,

		allAddrs: opt.Addrs,
		allNodes: make(map[string]*clusterNode),
	}
}

func (c *clusterNodes) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}
	c.closed = true

	var firstErr error
	for _, node := range c.allNodes {
		if err := node.Client.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	c.allNodes = nil
	c.clusterAddrs = nil

	return firstErr
}

func (c *clusterNodes) Addrs() ([]string, error) {
	var addrs []string
	c.mu.RLock()
	closed := c.closed
	if !closed {
		if len(c.clusterAddrs) > 0 {
			addrs = c.clusterAddrs
		} else {
			addrs = c.allAddrs
		}
	}
	c.mu.RUnlock()

	if closed {
		return nil, pool.ErrClosed
	}
	if len(addrs) == 0 {
		return nil, errClusterNoNodes
	}
	return addrs, nil
}

func (c *clusterNodes) NextGeneration() uint32 {
	return atomic.AddUint32(&c._generation, 1)
}

// GC removes unused nodes.
func (c *clusterNodes) GC(generation uint32) {
	//nolint:prealloc
	var collected []*clusterNode
	c.mu.Lock()
	for addr, node := range c.allNodes {
		if node.Generation() >= generation {
			continue
		}

		c.clusterAddrs = remove(c.clusterAddrs, addr)
		delete(c.allNodes, addr)
		collected = append(collected, node)
	}
	c.mu.Unlock()

	for _, node := range collected {
		_ = node.Client.Close()
	}
}

func (c *clusterNodes) Get(addr string) (*clusterNode, error) {
	node, err := c.get(addr)
	if err != nil {
		return nil, err
	}
	if node != nil {
		return node, nil
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil, pool.ErrClosed
	}

	node, ok := c.allNodes[addr]
	if ok {
		return node, err
	}

	node = newClusterNode(c.opt, addr)

	c.allAddrs = appendIfNotExists(c.allAddrs, addr)
	c.clusterAddrs = append(c.clusterAddrs, addr)
	c.allNodes[addr] = node

	return node, err
}

func (c *clusterNodes) get(addr string) (*clusterNode, error) {
	var node *clusterNode
	var err error
	c.mu.RLock()
	if c.closed {
		err = pool.ErrClosed
	} else {
		node = c.allNodes[addr]
	}
	c.mu.RUnlock()
	return node, err
}

func (c *clusterNodes) All() ([]*clusterNode, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.closed {
		return nil, pool.ErrClosed
	}

	cp := make([]*clusterNode, 0, len(c.allNodes))
	for _, node := range c.allNodes {
		cp = append(cp, node)
	}
	return cp, nil
}

func (c *clusterNodes) Random() (*clusterNode, error) {
	addrs, err := c.Addrs()
	if err != nil {
		return nil, err
	}

	n := rand.Intn(len(addrs))
	return c.Get(addrs[n])
}

//------------------------------------------------------------------------------

type clusterSlot struct {
	start, end int
	nodes      []*clusterNode
}

type clusterSlotSlice []*clusterSlot

func (p clusterSlotSlice) Len() int {
	return len(p)
}

func (p clusterSlotSlice) Less(i, j int) bool {
	return p[i].start < p[j].start
}

func (p clusterSlotSlice) Swap(i, j int) {
	p[i], p[j] = p[j], p[i]
}

type clusterState struct {
	nodes   *clusterNodes
	Masters []*clusterNode
	Slaves  []*clusterNode

	slots []*clusterSlot

	generation uint32
	createdAt  time.Time
}

func newClusterState(
	nodes *clusterNodes, slots []ClusterSlot, origin string,
) (*clusterState, error) {
	c := clusterState{
		nodes: nodes,

		slots: make([]*clusterSlot, 0, len(slots)),

		generation: nodes.NextGeneration(),
		createdAt:  time.Now(),
	}

	originHost, _, _ := net.SplitHostPort(origin)
	isLoopbackOrigin := isLoopback(originHost)

	for _, slot := range slots {
		var nodes []*clusterNode
		for i, slotNode := range slot.Nodes {
			addr := slotNode.Addr
			if !isLoopbackOrigin {
				addr = replaceLoopbackHost(addr, originHost)
			}

			node, err := c.nodes.Get(addr)
			if err != nil {
				return nil, err
			}

			node.SetGeneration(c.generation)
			nodes = append(nodes, node)

			if i == 0 {
				c.Masters = appendUniqueNode(c.Masters, node)
			} else {
				c.Slaves = appendUniqueNode(c.Slaves, node)
			}
		}

		c.slots = append(c.slots, &clusterSlot{
			start: slot.Start,
			end:   slot.End,
			nodes: nodes,
		})
	}

	sort.Sort(clusterSlotSlice(c.slots))

	time.AfterFunc(time.Minute, func() {
		nodes.GC(c.generation)
	})

	return &c, nil
}

func replaceLoopbackHost(nodeAddr, originHost string) string {
	nodeHost, nodePort, err := net.SplitHostPort(nodeAddr)
	if err != nil {
		return nodeAddr
	}

	nodeIP := net.ParseIP(nodeHost)
	if nodeIP == nil {
		return nodeAddr
	}

	if !nodeIP.IsLoopback() {
		return nodeAddr
	}

	// Use the origin host, which is not a loopback address, and the node port.
	return net.JoinHostPort(originHost, nodePort)
}

func isLoopback(host string) bool {
	ip := net.ParseIP(host)
	if ip == nil {
		return true
	}
	return ip.IsLoopback()
}

func (c *clusterState) slotMasterNode(slot int) (*clusterNode, error) {
	nodes := c.slotNodes(slot)
	if len(nodes) > 0 {
		return nodes[0], nil
	}
	return c.nodes.Random()
}

func (c *clusterState) slotSlaveNode(slot int) (*clusterNode, error) {
	nodes := c.slotNodes(slot)
	switch len(nodes) {
	case 0:
		return c.nodes.Random()
	case 1:
		return nodes[0], nil
	case 2:
		if slave := nodes[1]; !slave.Failing() {
			return slave, nil
		}
		return nodes[0], nil
	default:
		var slave *clusterNode
		for i := 0; i < 10; i++ {
			n := rand.Intn(len(nodes)-1) + 1
			slave = nodes[n]
			if !slave.Failing() {
				return slave, nil
			}
		}

		// All slaves are loading, so use the master.
		return nodes[0], nil
	}
}

func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) {
	const threshold = time.Millisecond

	nodes := c.slotNodes(slot)
	if len(nodes) == 0 {
		return c.nodes.Random()
	}

	var node *clusterNode
	for _, n := range nodes {
		if n.Failing() {
			continue
		}
		if node == nil || node.Latency()-n.Latency() > threshold {
			node = n
		}
	}
	return node, nil
}

func (c *clusterState) slotRandomNode(slot int) (*clusterNode, error) {
	nodes := c.slotNodes(slot)
	if len(nodes) == 0 {
		return c.nodes.Random()
	}
	n := rand.Intn(len(nodes))
	return nodes[n], nil
}

func (c *clusterState) slotNodes(slot int) []*clusterNode {
	i := sort.Search(len(c.slots), func(i int) bool {
		return c.slots[i].end >= slot
	})
	if i >= len(c.slots) {
		return nil
	}
	x := c.slots[i]
	if slot >= x.start && slot <= x.end {
		return x.nodes
	}
	return nil
}

//------------------------------------------------------------------------------

type clusterStateHolder struct {
	load func() (*clusterState, error)

	state     atomic.Value
	reloading uint32 // atomic
}

func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder {
	return &clusterStateHolder{
		load: fn,
	}
}

func (c *clusterStateHolder) Reload() (*clusterState, error) {
	state, err := c.load()
	if err != nil {
		return nil, err
	}
	c.state.Store(state)
	return state, nil
}

func (c *clusterStateHolder) LazyReload() {
	if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
		return
	}
	go func() {
		defer atomic.StoreUint32(&c.reloading, 0)

		_, err := c.Reload()
		if err != nil {
			return
		}
		time.Sleep(100 * time.Millisecond)
	}()
}

func (c *clusterStateHolder) Get() (*clusterState, error) {
	v := c.state.Load()
	if v != nil {
		state := v.(*clusterState)
		if time.Since(state.createdAt) > time.Minute {
			c.LazyReload()
		}
		return state, nil
	}
	return c.Reload()
}

func (c *clusterStateHolder) ReloadOrGet() (*clusterState, error) {
	state, err := c.Reload()
	if err == nil {
		return state, nil
	}
	return c.Get()
}

//------------------------------------------------------------------------------

type clusterClient struct {
	opt           *ClusterOptions
	nodes         *clusterNodes
	state         *clusterStateHolder //nolint:structcheck
	cmdsInfoCache *cmdsInfoCache      //nolint:structcheck
}

// ClusterClient is a Redis Cluster client representing a pool of zero
// or more underlying connections. It's safe for concurrent use by
// multiple goroutines.
type ClusterClient struct {
	*clusterClient
	cmdable
	hooks
	ctx context.Context
}

// NewClusterClient returns a Redis Cluster client as described in
// http://redis.io/topics/cluster-spec.
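//
// A minimal usage sketch; the addresses and key below are illustrative:
//
//	c := NewClusterClient(&ClusterOptions{
//		Addrs: []string{":7000", ":7001", ":7002"},
//	})
//	if err := c.Ping().Err(); err != nil {
//		// handle the error
//	}
//	if err := c.Set("key", "value", 0).Err(); err != nil {
//		// handle the error
//	}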
func NewClusterClient(opt *ClusterOptions) *ClusterClient {
	opt.init()

	c := &ClusterClient{
		clusterClient: &clusterClient{
			opt:   opt,
			nodes: newClusterNodes(opt),
		},
		ctx: context.Background(),
	}
	c.state = newClusterStateHolder(c.loadState)
	c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
	c.cmdable = c.Process

	if opt.IdleCheckFrequency > 0 {
		go c.reaper(opt.IdleCheckFrequency)
	}

	return c
}

func (c *ClusterClient) Context() context.Context {
	return c.ctx
}

func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.cmdable = clone.Process
	clone.hooks.lock()
	clone.ctx = ctx
	return &clone
}

// Options returns read-only Options that were used to create the client.
func (c *ClusterClient) Options() *ClusterOptions {
	return c.opt
}

// ReloadState reloads the cluster state. If the ClusterSlots option is set,
// it is called to get the cluster slots information.
func (c *ClusterClient) ReloadState() error {
	_, err := c.state.Reload()
	return err
}

// Close closes the cluster client, releasing any open resources.
//
// It is rare to Close a ClusterClient, as the ClusterClient is meant
// to be long-lived and shared between many goroutines.
func (c *ClusterClient) Close() error {
	return c.nodes.Close()
}

// Do creates a Cmd from the args and processes the cmd.
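//
// A small sketch, using an arbitrary command name and key for illustration:
//
//	n, err := c.Do("incrby", "counter", 42).Int64()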
func (c *ClusterClient) Do(args ...interface{}) *Cmd {
	return c.DoContext(c.ctx, args...)
}

func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(args...)
	_ = c.ProcessContext(ctx, cmd)
	return cmd
}

func (c *ClusterClient) Process(cmd Cmder) error {
	return c.ProcessContext(c.ctx, cmd)
}

func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.process)
}

func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
	err := c._process(ctx, cmd)
	if err != nil {
		cmd.SetErr(err)
		return err
	}
	return nil
}

func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
	cmdInfo := c.cmdInfo(cmd.Name())
	slot := c.cmdSlot(cmd)

	var node *clusterNode
	var ask bool
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		if node == nil {
			var err error
			node, err = c.cmdNode(cmdInfo, slot)
			if err != nil {
				return err
			}
		}

		if ask {
			pipe := node.Client.Pipeline()
			_ = pipe.Process(NewCmd("asking"))
			_ = pipe.Process(cmd)
			_, lastErr = pipe.ExecContext(ctx)
			_ = pipe.Close()
			ask = false
		} else {
			lastErr = node.Client.ProcessContext(ctx, cmd)
		}

		// If there is no error, we are done.
		if lastErr == nil {
			return nil
		}
		if lastErr != Nil {
			c.state.LazyReload()
		}
		if lastErr == pool.ErrClosed || isReadOnlyError(lastErr) {
			node = nil
			continue
		}

		// If a slave is loading, pick another node.
		if c.opt.ReadOnly && isLoadingError(lastErr) {
			node.MarkAsFailing()
			node = nil
			continue
		}

		var moved bool
		var addr string
		moved, ask, addr = isMovedError(lastErr)
		if moved || ask {
			var err error
			node, err = c.nodes.Get(addr)
			if err != nil {
				return err
			}
			continue
		}

		if isRetryableError(lastErr, cmd.readTimeout() == nil) {
			// First, retry the same node.
			if attempt == 0 {
				continue
			}

			// Second, try another node.
			node.MarkAsFailing()
			node = nil
			continue
		}

		return lastErr
	}
	return lastErr
}

// ForEachMaster concurrently calls fn on each master node in the cluster.
// It returns the first error, if any.
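//
// A short sketch; pinging every master is shown purely as an example:
//
//	err := c.ForEachMaster(func(master *Client) error {
//		return master.Ping().Err()
//	})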
func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
	state, err := c.state.ReloadOrGet()
	if err != nil {
		return err
	}

	var wg sync.WaitGroup
	errCh := make(chan error, 1)

	for _, master := range state.Masters {
		wg.Add(1)
		go func(node *clusterNode) {
			defer wg.Done()
			err := fn(node.Client)
			if err != nil {
				select {
				case errCh <- err:
				default:
				}
			}
		}(master)
	}

	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

// ForEachSlave concurrently calls fn on each slave node in the cluster.
// It returns the first error, if any.
func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
	state, err := c.state.ReloadOrGet()
	if err != nil {
		return err
	}

	var wg sync.WaitGroup
	errCh := make(chan error, 1)

	for _, slave := range state.Slaves {
		wg.Add(1)
		go func(node *clusterNode) {
			defer wg.Done()
			err := fn(node.Client)
			if err != nil {
				select {
				case errCh <- err:
				default:
				}
			}
		}(slave)
	}

	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

// ForEachNode concurrently calls fn on each known node in the cluster.
// It returns the first error, if any.
func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
	state, err := c.state.ReloadOrGet()
	if err != nil {
		return err
	}

	var wg sync.WaitGroup
	errCh := make(chan error, 1)

	worker := func(node *clusterNode) {
		defer wg.Done()
		err := fn(node.Client)
		if err != nil {
			select {
			case errCh <- err:
			default:
			}
		}
	}

	for _, node := range state.Masters {
		wg.Add(1)
		go worker(node)
	}
	for _, node := range state.Slaves {
		wg.Add(1)
		go worker(node)
	}

	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

// PoolStats returns accumulated connection pool stats.
func (c *ClusterClient) PoolStats() *PoolStats {
	var acc PoolStats

	state, _ := c.state.Get()
	if state == nil {
		return &acc
	}

	for _, node := range state.Masters {
		s := node.Client.connPool.Stats()
		acc.Hits += s.Hits
		acc.Misses += s.Misses
		acc.Timeouts += s.Timeouts

		acc.TotalConns += s.TotalConns
		acc.IdleConns += s.IdleConns
		acc.StaleConns += s.StaleConns
	}

	for _, node := range state.Slaves {
		s := node.Client.connPool.Stats()
		acc.Hits += s.Hits
		acc.Misses += s.Misses
		acc.Timeouts += s.Timeouts

		acc.TotalConns += s.TotalConns
		acc.IdleConns += s.IdleConns
		acc.StaleConns += s.StaleConns
	}

	return &acc
}

func (c *ClusterClient) loadState() (*clusterState, error) {
	if c.opt.ClusterSlots != nil {
		slots, err := c.opt.ClusterSlots()
		if err != nil {
			return nil, err
		}
		return newClusterState(c.nodes, slots, "")
	}

	addrs, err := c.nodes.Addrs()
	if err != nil {
		return nil, err
	}

	var firstErr error
	for _, addr := range addrs {
		node, err := c.nodes.Get(addr)
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue
		}

		slots, err := node.Client.ClusterSlots().Result()
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue
		}

		return newClusterState(c.nodes, slots, node.Client.opt.Addr)
	}

	return nil, firstErr
}

// reaper closes idle connections to the cluster.
func (c *ClusterClient) reaper(idleCheckFrequency time.Duration) {
	ticker := time.NewTicker(idleCheckFrequency)
	defer ticker.Stop()

	for range ticker.C {
		nodes, err := c.nodes.All()
		if err != nil {
			break
		}

		for _, node := range nodes {
			_, err := node.Client.connPool.(*pool.ConnPool).ReapStaleConns()
			if err != nil {
				internal.Logger.Printf("ReapStaleConns failed: %s", err)
			}
		}
	}
}

func (c *ClusterClient) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(fn)
}

func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
}

func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
	cmdsMap := newCmdsMap()
	err := c.mapCmdsByNode(cmdsMap, cmds)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}

	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				setCmdsErr(cmds, err)
				return err
			}
		}

		failedCmds := newCmdsMap()
		var wg sync.WaitGroup

		for node, cmds := range cmdsMap.m {
			wg.Add(1)
			go func(node *clusterNode, cmds []Cmder) {
				defer wg.Done()

				err := c._processPipelineNode(ctx, node, cmds, failedCmds)
				if err == nil {
					return
				}
				if attempt < c.opt.MaxRedirects {
					if err := c.mapCmdsByNode(failedCmds, cmds); err != nil {
						setCmdsErr(cmds, err)
					}
				} else {
					setCmdsErr(cmds, err)
				}
			}(node, cmds)
		}

		wg.Wait()
		if len(failedCmds.m) == 0 {
			break
		}
		cmdsMap = failedCmds
	}

	return cmdsFirstErr(cmds)
}

func (c *ClusterClient) mapCmdsByNode(cmdsMap *cmdsMap, cmds []Cmder) error {
	state, err := c.state.Get()
	if err != nil {
		return err
	}

	if c.opt.ReadOnly && c.cmdsAreReadOnly(cmds) {
		for _, cmd := range cmds {
			slot := c.cmdSlot(cmd)
			node, err := c.slotReadOnlyNode(state, slot)
			if err != nil {
				return err
			}
			cmdsMap.Add(node, cmd)
		}
		return nil
	}

	for _, cmd := range cmds {
		slot := c.cmdSlot(cmd)
		node, err := state.slotMasterNode(slot)
		if err != nil {
			return err
		}
		cmdsMap.Add(node, cmd)
	}
	return nil
}

func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(cmd.Name())
		if cmdInfo == nil || !cmdInfo.ReadOnly {
			return false
		}
	}
	return true
}

func (c *ClusterClient) _processPipelineNode(
	ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) error {
	return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
			err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
				return writeCmds(wr, cmds)
			})
			if err != nil {
				return err
			}

			return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
				return c.pipelineReadCmds(node, rd, cmds, failedCmds)
			})
		})
	})
}

func (c *ClusterClient) pipelineReadCmds(
	node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
) error {
	for _, cmd := range cmds {
		err := cmd.readReply(rd)
		if err == nil {
			continue
		}
		if c.checkMovedErr(cmd, err, failedCmds) {
			continue
		}

		if c.opt.ReadOnly && isLoadingError(err) {
			node.MarkAsFailing()
			return err
		}
		if isRedisError(err) {
			continue
		}
		return err
	}
	return nil
}

func (c *ClusterClient) checkMovedErr(
	cmd Cmder, err error, failedCmds *cmdsMap,
) bool {
	moved, ask, addr := isMovedError(err)
	if !moved && !ask {
		return false
	}

	node, err := c.nodes.Get(addr)
	if err != nil {
		return false
	}

	if moved {
		c.state.LazyReload()
		failedCmds.Add(node, cmd)
		return true
	}

	if ask {
		failedCmds.Add(node, NewCmd("asking"), cmd)
		return true
	}

	panic("not reached")
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *ClusterClient) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(fn)
}

func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c._processTxPipeline)
}

func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
	state, err := c.state.Get()
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}

	cmdsMap := c.mapCmdsBySlot(cmds)
	for slot, cmds := range cmdsMap {
		node, err := state.slotMasterNode(slot)
		if err != nil {
			setCmdsErr(cmds, err)
			continue
		}

		cmdsMap := map[*clusterNode][]Cmder{node: cmds}
		for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
			if attempt > 0 {
				if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
					setCmdsErr(cmds, err)
					return err
				}
			}

			failedCmds := newCmdsMap()
			var wg sync.WaitGroup

			for node, cmds := range cmdsMap {
				wg.Add(1)
				go func(node *clusterNode, cmds []Cmder) {
					defer wg.Done()

					err := c._processTxPipelineNode(ctx, node, cmds, failedCmds)
					if err == nil {
						return
					}
					if attempt < c.opt.MaxRedirects {
						if err := c.mapCmdsByNode(failedCmds, cmds); err != nil {
							setCmdsErr(cmds, err)
						}
					} else {
						setCmdsErr(cmds, err)
					}
				}(node, cmds)
			}

			wg.Wait()
			if len(failedCmds.m) == 0 {
				break
			}
			cmdsMap = failedCmds.m
		}
	}

	return cmdsFirstErr(cmds)
}

func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
	cmdsMap := make(map[int][]Cmder)
	for _, cmd := range cmds {
		slot := c.cmdSlot(cmd)
		cmdsMap[slot] = append(cmdsMap[slot], cmd)
	}
	return cmdsMap
}

func (c *ClusterClient) _processTxPipelineNode(
	ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) error {
	return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
			err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
				return writeCmds(wr, cmds)
			})
			if err != nil {
				return err
			}

			return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
				statusCmd := cmds[0].(*StatusCmd)
				// Trim multi and exec.
				cmds = cmds[1 : len(cmds)-1]

				err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds)
				if err != nil {
					moved, ask, addr := isMovedError(err)
					if moved || ask {
						return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
					}
					return err
				}

				return pipelineReadCmds(rd, cmds)
			})
		})
	})
}

func (c *ClusterClient) txPipelineReadQueued(
	rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap,
) error {
	// Parse queued replies.
	if err := statusCmd.readReply(rd); err != nil {
		return err
	}

	for _, cmd := range cmds {
		err := statusCmd.readReply(rd)
		if err == nil || c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) {
			continue
		}
		return err
	}

	// Parse number of replies.
	line, err := rd.ReadLine()
	if err != nil {
		if err == Nil {
			err = TxFailedErr
		}
		return err
	}

	switch line[0] {
	case proto.ErrorReply:
		return proto.ParseErrorReply(line)
	case proto.ArrayReply:
		// ok
	default:
		return fmt.Errorf("redis: expected '*', but got line %q", line)
	}

	return nil
}

func (c *ClusterClient) cmdsMoved(
	cmds []Cmder, moved, ask bool, addr string, failedCmds *cmdsMap,
) error {
	node, err := c.nodes.Get(addr)
	if err != nil {
		return err
	}

	if moved {
		c.state.LazyReload()
		for _, cmd := range cmds {
			failedCmds.Add(node, cmd)
		}
		return nil
	}

	if ask {
		for _, cmd := range cmds {
			failedCmds.Add(node, NewCmd("asking"), cmd)
		}
		return nil
	}

	return nil
}

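// Watch prepares a transaction with optimistic locking on the given keys and
// runs fn with a *Tx bound to the master that owns their hash slot. All keys
// must map to the same slot, which can be arranged with hash tags.
//
// A minimal sketch, assuming the keys share the {user1} hash tag; the key name
// and the increment logic are illustrative:
//
//	err := c.Watch(func(tx *Tx) error {
//		n, err := tx.Get("{user1}:counter").Int()
//		if err != nil && err != Nil {
//			return err
//		}
//		_, err = tx.TxPipelined(func(pipe Pipeliner) error {
//			pipe.Set("{user1}:counter", n+1, 0)
//			return nil
//		})
//		return err
//	}, "{user1}:counter")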
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
	return c.WatchContext(c.ctx, fn, keys...)
}

func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	if len(keys) == 0 {
		return fmt.Errorf("redis: Watch requires at least one key")
	}

	slot := hashtag.Slot(keys[0])
	for _, key := range keys[1:] {
		if hashtag.Slot(key) != slot {
			err := fmt.Errorf("redis: Watch requires all keys to be in the same slot")
			return err
		}
	}

	node, err := c.slotMasterNode(slot)
	if err != nil {
		return err
	}

	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		err = node.Client.WatchContext(ctx, fn, keys...)
		if err == nil {
			break
		}
		if err != Nil {
			c.state.LazyReload()
		}

		moved, ask, addr := isMovedError(err)
		if moved || ask {
			node, err = c.nodes.Get(addr)
			if err != nil {
				return err
			}
			continue
		}

		if err == pool.ErrClosed || isReadOnlyError(err) {
			node, err = c.slotMasterNode(slot)
			if err != nil {
				return err
			}
			continue
		}

		if isRetryableError(err, true) {
			continue
		}

		return err
	}

	return err
}

func (c *ClusterClient) pubSub() *PubSub {
	var node *clusterNode
	pubsub := &PubSub{
		opt: c.opt.clientOptions(),

		newConn: func(channels []string) (*pool.Conn, error) {
			if node != nil {
				panic("node != nil")
			}

			var err error
			if len(channels) > 0 {
				slot := hashtag.Slot(channels[0])
				node, err = c.slotMasterNode(slot)
			} else {
				node, err = c.nodes.Random()
			}
			if err != nil {
				return nil, err
			}

			cn, err := node.Client.newConn(context.TODO())
			if err != nil {
				node = nil

				return nil, err
			}

			return cn, nil
		},
		closeConn: func(cn *pool.Conn) error {
			err := node.Client.connPool.CloseConn(cn)
			node = nil
			return err
		},
	}
	pubsub.init()

	return pubsub
}

// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create an empty subscription.
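//
// A minimal consumption sketch; the channel name is illustrative:
//
//	pubsub := c.Subscribe("mychannel")
//	defer pubsub.Close()
//	for msg := range pubsub.Channel() {
//		_ = msg.Payload // handle the message
//	}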
func (c *ClusterClient) Subscribe(channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.Subscribe(channels...)
	}
	return pubsub
}

// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create an empty subscription.
func (c *ClusterClient) PSubscribe(channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.PSubscribe(channels...)
	}
	return pubsub
}

func (c *ClusterClient) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

func (c *ClusterClient) cmdsInfo() (map[string]*CommandInfo, error) {
	addrs, err := c.nodes.Addrs()
	if err != nil {
		return nil, err
	}

	var firstErr error
	for _, addr := range addrs {
		node, err := c.nodes.Get(addr)
		if err != nil {
			return nil, err
		}
		if node == nil {
			continue
		}

		info, err := node.Client.Command().Result()
		if err == nil {
			return info, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return nil, firstErr
}

func (c *ClusterClient) cmdInfo(name string) *CommandInfo {
	cmdsInfo, err := c.cmdsInfoCache.Get()
	if err != nil {
		return nil
	}

	info := cmdsInfo[name]
	if info == nil {
		internal.Logger.Printf("info for cmd=%s not found", name)
	}
	return info
}

func (c *ClusterClient) cmdSlot(cmd Cmder) int {
	args := cmd.Args()
	if args[0] == "cluster" && args[1] == "getkeysinslot" {
		return args[2].(int)
	}

	cmdInfo := c.cmdInfo(cmd.Name())
	return cmdSlot(cmd, cmdFirstKeyPos(cmd, cmdInfo))
}

func cmdSlot(cmd Cmder, pos int) int {
	if pos == 0 {
		return hashtag.RandomSlot()
	}
	firstKey := cmd.stringArg(pos)
	return hashtag.Slot(firstKey)
}

func (c *ClusterClient) cmdNode(cmdInfo *CommandInfo, slot int) (*clusterNode, error) {
	state, err := c.state.Get()
	if err != nil {
		return nil, err
	}

	if c.opt.ReadOnly && cmdInfo != nil && cmdInfo.ReadOnly {
		return c.slotReadOnlyNode(state, slot)
	}
	return state.slotMasterNode(slot)
}

func (c *clusterClient) slotReadOnlyNode(state *clusterState, slot int) (*clusterNode, error) {
	if c.opt.RouteByLatency {
		return state.slotClosestNode(slot)
	}
	if c.opt.RouteRandomly {
		return state.slotRandomNode(slot)
	}
	return state.slotSlaveNode(slot)
}

func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
	state, err := c.state.Get()
	if err != nil {
		return nil, err
	}
	return state.slotMasterNode(slot)
}

func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode {
	for _, n := range nodes {
		if n == node {
			return nodes
		}
	}
	return append(nodes, node)
}

func appendIfNotExists(ss []string, es ...string) []string {
loop:
	for _, e := range es {
		for _, s := range ss {
			if s == e {
				continue loop
			}
		}
		ss = append(ss, e)
	}
	return ss
}

func remove(ss []string, es ...string) []string {
	if len(es) == 0 {
		return ss[:0]
	}
	for _, e := range es {
		for i, s := range ss {
			if s == e {
				ss = append(ss[:i], ss[i+1:]...)
				break
			}
		}
	}
	return ss
}

//------------------------------------------------------------------------------

type cmdsMap struct {
	mu sync.Mutex
	m  map[*clusterNode][]Cmder
}

func newCmdsMap() *cmdsMap {
	return &cmdsMap{
		m: make(map[*clusterNode][]Cmder),
	}
}

func (m *cmdsMap) Add(node *clusterNode, cmds ...Cmder) {
	m.mu.Lock()
	m.m[node] = append(m.m[node], cmds...)
	m.mu.Unlock()
}
