1package gocql
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net"
8	"strconv"
9	"strings"
10	"sync"
11	"time"
12)
13
// nodeState is the liveness state the driver records for a host.
type nodeState int32

// String returns "UP" for NodeUp, "DOWN" for NodeDown and "UNKNOWN_<n>" for
// any other value.
func (n nodeState) String() string {
	switch n {
	case NodeUp:
		return "UP"
	case NodeDown:
		return "DOWN"
	default:
		return fmt.Sprintf("UNKNOWN_%d", n)
	}
}

// Host states. NodeUp is the zero value, so a HostInfo whose state was never
// set reports NodeUp.
const (
	NodeUp nodeState = iota
	NodeDown
)
29
// cassVersion is a parsed Cassandra release version, major.minor.patch.
type cassVersion struct {
	Major int
	Minor int
	Patch int
}
33
34func (c *cassVersion) Set(v string) error {
35	if v == "" {
36		return nil
37	}
38
39	return c.UnmarshalCQL(nil, []byte(v))
40}
41
42func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
43	return c.unmarshal(data)
44}
45
46func (c *cassVersion) unmarshal(data []byte) error {
47	version := strings.TrimSuffix(string(data), "-SNAPSHOT")
48	version = strings.TrimPrefix(version, "v")
49	v := strings.Split(version, ".")
50
51	if len(v) < 2 {
52		return fmt.Errorf("invalid version string: %s", data)
53	}
54
55	var err error
56	c.Major, err = strconv.Atoi(v[0])
57	if err != nil {
58		return fmt.Errorf("invalid major version %v: %v", v[0], err)
59	}
60
61	c.Minor, err = strconv.Atoi(v[1])
62	if err != nil {
63		return fmt.Errorf("invalid minor version %v: %v", v[1], err)
64	}
65
66	if len(v) > 2 {
67		c.Patch, err = strconv.Atoi(v[2])
68		if err != nil {
69			return fmt.Errorf("invalid patch version %v: %v", v[2], err)
70		}
71	}
72
73	return nil
74}
75
76func (c cassVersion) Before(major, minor, patch int) bool {
77	// We're comparing us (cassVersion) with the provided version (major, minor, patch)
78	// We return true if our version is lower (comes before) than the provided one.
79	if c.Major < major {
80		return true
81	} else if c.Major == major {
82		if c.Minor < minor {
83			return true
84		} else if c.Minor == minor && c.Patch < patch {
85			return true
86		}
87
88	}
89	return false
90}
91
92func (c cassVersion) AtLeast(major, minor, patch int) bool {
93	return !c.Before(major, minor, patch)
94}
95
96func (c cassVersion) String() string {
97	return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
98}
99
100func (c cassVersion) nodeUpDelay() time.Duration {
101	if c.Major >= 2 && c.Minor >= 2 {
102		// CASSANDRA-8236
103		return 0
104	}
105
106	return 10 * time.Second
107}
108
109type HostInfo struct {
110	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
111	// that we are thread safe use a mutex to access all fields.
112	mu               sync.RWMutex
113	peer             net.IP
114	broadcastAddress net.IP
115	listenAddress    net.IP
116	rpcAddress       net.IP
117	preferredIP      net.IP
118	connectAddress   net.IP
119	port             int
120	dataCenter       string
121	rack             string
122	hostId           string
123	workload         string
124	graph            bool
125	dseVersion       string
126	partitioner      string
127	clusterName      string
128	version          cassVersion
129	state            nodeState
130	tokens           []string
131}
132
133func (h *HostInfo) Equal(host *HostInfo) bool {
134	if h == host {
135		// prevent rlock reentry
136		return true
137	}
138
139	return h.ConnectAddress().Equal(host.ConnectAddress())
140}
141
142func (h *HostInfo) Peer() net.IP {
143	h.mu.RLock()
144	defer h.mu.RUnlock()
145	return h.peer
146}
147
148func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
149	h.mu.Lock()
150	defer h.mu.Unlock()
151	h.peer = peer
152	return h
153}
154
155func (h *HostInfo) invalidConnectAddr() bool {
156	h.mu.RLock()
157	defer h.mu.RUnlock()
158	addr, _ := h.connectAddressLocked()
159	return !validIpAddr(addr)
160}
161
// validIpAddr reports whether addr can be used as a host address. It must be
// a well-formed IPv4 or IPv6 address (4 or 16 bytes long) and must not be the
// unspecified address (0.0.0.0 or ::).
func validIpAddr(addr net.IP) bool {
	// To16 returns nil for any slice that is not 4 or 16 bytes long. That
	// rejects nil, empty (non-nil) and truncated values alike; a plain nil
	// check let the last two through.
	return addr.To16() != nil && !addr.IsUnspecified()
}
165
166func (h *HostInfo) connectAddressLocked() (net.IP, string) {
167	if validIpAddr(h.connectAddress) {
168		return h.connectAddress, "connect_address"
169	} else if validIpAddr(h.rpcAddress) {
170		return h.rpcAddress, "rpc_adress"
171	} else if validIpAddr(h.preferredIP) {
172		// where does perferred_ip get set?
173		return h.preferredIP, "preferred_ip"
174	} else if validIpAddr(h.broadcastAddress) {
175		return h.broadcastAddress, "broadcast_address"
176	} else if validIpAddr(h.peer) {
177		return h.peer, "peer"
178	}
179	return net.IPv4zero, "invalid"
180}
181
182// Returns the address that should be used to connect to the host.
183// If you wish to override this, use an AddressTranslator or
184// use a HostFilter to SetConnectAddress()
185func (h *HostInfo) ConnectAddress() net.IP {
186	h.mu.RLock()
187	defer h.mu.RUnlock()
188
189	if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
190		return addr
191	}
192	panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
193}
194
195func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
196	// TODO(zariel): should this not be exported?
197	h.mu.Lock()
198	defer h.mu.Unlock()
199	h.connectAddress = address
200	return h
201}
202
203func (h *HostInfo) BroadcastAddress() net.IP {
204	h.mu.RLock()
205	defer h.mu.RUnlock()
206	return h.broadcastAddress
207}
208
209func (h *HostInfo) ListenAddress() net.IP {
210	h.mu.RLock()
211	defer h.mu.RUnlock()
212	return h.listenAddress
213}
214
215func (h *HostInfo) RPCAddress() net.IP {
216	h.mu.RLock()
217	defer h.mu.RUnlock()
218	return h.rpcAddress
219}
220
221func (h *HostInfo) PreferredIP() net.IP {
222	h.mu.RLock()
223	defer h.mu.RUnlock()
224	return h.preferredIP
225}
226
227func (h *HostInfo) DataCenter() string {
228	h.mu.RLock()
229	defer h.mu.RUnlock()
230	return h.dataCenter
231}
232
233func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
234	h.mu.Lock()
235	defer h.mu.Unlock()
236	h.dataCenter = dataCenter
237	return h
238}
239
240func (h *HostInfo) Rack() string {
241	h.mu.RLock()
242	defer h.mu.RUnlock()
243	return h.rack
244}
245
246func (h *HostInfo) setRack(rack string) *HostInfo {
247	h.mu.Lock()
248	defer h.mu.Unlock()
249	h.rack = rack
250	return h
251}
252
253func (h *HostInfo) HostID() string {
254	h.mu.RLock()
255	defer h.mu.RUnlock()
256	return h.hostId
257}
258
259func (h *HostInfo) setHostID(hostID string) *HostInfo {
260	h.mu.Lock()
261	defer h.mu.Unlock()
262	h.hostId = hostID
263	return h
264}
265
266func (h *HostInfo) WorkLoad() string {
267	h.mu.RLock()
268	defer h.mu.RUnlock()
269	return h.workload
270}
271
272func (h *HostInfo) Graph() bool {
273	h.mu.RLock()
274	defer h.mu.RUnlock()
275	return h.graph
276}
277
278func (h *HostInfo) DSEVersion() string {
279	h.mu.RLock()
280	defer h.mu.RUnlock()
281	return h.dseVersion
282}
283
284func (h *HostInfo) Partitioner() string {
285	h.mu.RLock()
286	defer h.mu.RUnlock()
287	return h.partitioner
288}
289
290func (h *HostInfo) ClusterName() string {
291	h.mu.RLock()
292	defer h.mu.RUnlock()
293	return h.clusterName
294}
295
296func (h *HostInfo) Version() cassVersion {
297	h.mu.RLock()
298	defer h.mu.RUnlock()
299	return h.version
300}
301
302func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
303	h.mu.Lock()
304	defer h.mu.Unlock()
305	h.version = cassVersion{major, minor, patch}
306	return h
307}
308
309func (h *HostInfo) State() nodeState {
310	h.mu.RLock()
311	defer h.mu.RUnlock()
312	return h.state
313}
314
315func (h *HostInfo) setState(state nodeState) *HostInfo {
316	h.mu.Lock()
317	defer h.mu.Unlock()
318	h.state = state
319	return h
320}
321
322func (h *HostInfo) Tokens() []string {
323	h.mu.RLock()
324	defer h.mu.RUnlock()
325	return h.tokens
326}
327
328func (h *HostInfo) setTokens(tokens []string) *HostInfo {
329	h.mu.Lock()
330	defer h.mu.Unlock()
331	h.tokens = tokens
332	return h
333}
334
335func (h *HostInfo) Port() int {
336	h.mu.RLock()
337	defer h.mu.RUnlock()
338	return h.port
339}
340
341func (h *HostInfo) setPort(port int) *HostInfo {
342	h.mu.Lock()
343	defer h.mu.Unlock()
344	h.port = port
345	return h
346}
347
348func (h *HostInfo) update(from *HostInfo) {
349	if h == from {
350		return
351	}
352
353	h.mu.Lock()
354	defer h.mu.Unlock()
355
356	from.mu.RLock()
357	defer from.mu.RUnlock()
358
359	// autogenerated do not update
360	if h.peer == nil {
361		h.peer = from.peer
362	}
363	if h.broadcastAddress == nil {
364		h.broadcastAddress = from.broadcastAddress
365	}
366	if h.listenAddress == nil {
367		h.listenAddress = from.listenAddress
368	}
369	if h.rpcAddress == nil {
370		h.rpcAddress = from.rpcAddress
371	}
372	if h.preferredIP == nil {
373		h.preferredIP = from.preferredIP
374	}
375	if h.connectAddress == nil {
376		h.connectAddress = from.connectAddress
377	}
378	if h.port == 0 {
379		h.port = from.port
380	}
381	if h.dataCenter == "" {
382		h.dataCenter = from.dataCenter
383	}
384	if h.rack == "" {
385		h.rack = from.rack
386	}
387	if h.hostId == "" {
388		h.hostId = from.hostId
389	}
390	if h.workload == "" {
391		h.workload = from.workload
392	}
393	if h.dseVersion == "" {
394		h.dseVersion = from.dseVersion
395	}
396	if h.partitioner == "" {
397		h.partitioner = from.partitioner
398	}
399	if h.clusterName == "" {
400		h.clusterName = from.clusterName
401	}
402	if h.version == (cassVersion{}) {
403		h.version = from.version
404	}
405	if h.tokens == nil {
406		h.tokens = from.tokens
407	}
408}
409
410func (h *HostInfo) IsUp() bool {
411	return h != nil && h.State() == NodeUp
412}
413
414func (h *HostInfo) String() string {
415	h.mu.RLock()
416	defer h.mu.RUnlock()
417
418	connectAddr, source := h.connectAddressLocked()
419	return fmt.Sprintf("[HostInfo connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
420		"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
421		"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
422		h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
423		connectAddr, source,
424		h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
425}
426
427// Polls system.peers at a specific interval to find new hosts
428type ringDescriber struct {
429	session         *Session
430	mu              sync.Mutex
431	prevHosts       []*HostInfo
432	prevPartitioner string
433}
434
435// Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces
436func checkSystemSchema(control *controlConn) (bool, error) {
437	iter := control.query("SELECT * FROM system_schema.keyspaces")
438	if err := iter.err; err != nil {
439		if errf, ok := err.(*errorFrame); ok {
440			if errf.code == errSyntax {
441				return false, nil
442			}
443		}
444
445		return false, err
446	}
447
448	return true, nil
449}
450
451// Given a map that represents a row from either system.local or system.peers
452// return as much information as we can in *HostInfo
453func (s *Session) hostInfoFromMap(row map[string]interface{}, port int) (*HostInfo, error) {
454	const assertErrorMsg = "Assertion failed for %s"
455	var ok bool
456
457	// Default to our connected port if the cluster doesn't have port information
458	host := HostInfo{
459		port: port,
460	}
461
462	for key, value := range row {
463		switch key {
464		case "data_center":
465			host.dataCenter, ok = value.(string)
466			if !ok {
467				return nil, fmt.Errorf(assertErrorMsg, "data_center")
468			}
469		case "rack":
470			host.rack, ok = value.(string)
471			if !ok {
472				return nil, fmt.Errorf(assertErrorMsg, "rack")
473			}
474		case "host_id":
475			hostId, ok := value.(UUID)
476			if !ok {
477				return nil, fmt.Errorf(assertErrorMsg, "host_id")
478			}
479			host.hostId = hostId.String()
480		case "release_version":
481			version, ok := value.(string)
482			if !ok {
483				return nil, fmt.Errorf(assertErrorMsg, "release_version")
484			}
485			host.version.Set(version)
486		case "peer":
487			ip, ok := value.(string)
488			if !ok {
489				return nil, fmt.Errorf(assertErrorMsg, "peer")
490			}
491			host.peer = net.ParseIP(ip)
492		case "cluster_name":
493			host.clusterName, ok = value.(string)
494			if !ok {
495				return nil, fmt.Errorf(assertErrorMsg, "cluster_name")
496			}
497		case "partitioner":
498			host.partitioner, ok = value.(string)
499			if !ok {
500				return nil, fmt.Errorf(assertErrorMsg, "partitioner")
501			}
502		case "broadcast_address":
503			ip, ok := value.(string)
504			if !ok {
505				return nil, fmt.Errorf(assertErrorMsg, "broadcast_address")
506			}
507			host.broadcastAddress = net.ParseIP(ip)
508		case "preferred_ip":
509			ip, ok := value.(string)
510			if !ok {
511				return nil, fmt.Errorf(assertErrorMsg, "preferred_ip")
512			}
513			host.preferredIP = net.ParseIP(ip)
514		case "rpc_address":
515			ip, ok := value.(string)
516			if !ok {
517				return nil, fmt.Errorf(assertErrorMsg, "rpc_address")
518			}
519			host.rpcAddress = net.ParseIP(ip)
520		case "listen_address":
521			ip, ok := value.(string)
522			if !ok {
523				return nil, fmt.Errorf(assertErrorMsg, "listen_address")
524			}
525			host.listenAddress = net.ParseIP(ip)
526		case "workload":
527			host.workload, ok = value.(string)
528			if !ok {
529				return nil, fmt.Errorf(assertErrorMsg, "workload")
530			}
531		case "graph":
532			host.graph, ok = value.(bool)
533			if !ok {
534				return nil, fmt.Errorf(assertErrorMsg, "graph")
535			}
536		case "tokens":
537			host.tokens, ok = value.([]string)
538			if !ok {
539				return nil, fmt.Errorf(assertErrorMsg, "tokens")
540			}
541		case "dse_version":
542			host.dseVersion, ok = value.(string)
543			if !ok {
544				return nil, fmt.Errorf(assertErrorMsg, "dse_version")
545			}
546		}
547		// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
548		// Not sure what the port field will be called until the JIRA issue is complete
549	}
550
551	ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
552	host.connectAddress = ip
553	host.port = port
554
555	return &host, nil
556}
557
558// Ask the control node for host info on all it's known peers
559func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) {
560	var hosts []*HostInfo
561	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
562		hosts = append(hosts, ch.host)
563		return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
564	})
565
566	if iter == nil {
567		return nil, errNoControl
568	}
569
570	rows, err := iter.SliceMap()
571	if err != nil {
572		// TODO(zariel): make typed error
573		return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
574	}
575
576	for _, row := range rows {
577		// extract all available info about the peer
578		host, err := r.session.hostInfoFromMap(row, r.session.cfg.Port)
579		if err != nil {
580			return nil, err
581		} else if !isValidPeer(host) {
582			// If it's not a valid peer
583			Logger.Printf("Found invalid peer '%s' "+
584				"Likely due to a gossip or snitch issue, this host will be ignored", host)
585			continue
586		}
587
588		hosts = append(hosts, host)
589	}
590
591	return hosts, nil
592}
593
594// Return true if the host is a valid peer
595func isValidPeer(host *HostInfo) bool {
596	return !(len(host.RPCAddress()) == 0 ||
597		host.hostId == "" ||
598		host.dataCenter == "" ||
599		host.rack == "" ||
600		len(host.tokens) == 0)
601}
602
603// Return a list of hosts the cluster knows about
604func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) {
605	r.mu.Lock()
606	defer r.mu.Unlock()
607
608	hosts, err := r.getClusterPeerInfo()
609	if err != nil {
610		return r.prevHosts, r.prevPartitioner, err
611	}
612
613	var partitioner string
614	if len(hosts) > 0 {
615		partitioner = hosts[0].Partitioner()
616	}
617
618	return hosts, partitioner, nil
619}
620
621// Given an ip/port return HostInfo for the specified ip/port
622func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) {
623	var host *HostInfo
624	iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
625		if ch.host.ConnectAddress().Equal(ip) {
626			host = ch.host
627			return nil
628		}
629
630		return ch.conn.query(context.TODO(), "SELECT * FROM system.peers")
631	})
632
633	if iter != nil {
634		rows, err := iter.SliceMap()
635		if err != nil {
636			return nil, err
637		}
638
639		for _, row := range rows {
640			h, err := r.session.hostInfoFromMap(row, port)
641			if err != nil {
642				return nil, err
643			}
644
645			if h.ConnectAddress().Equal(ip) {
646				host = h
647				break
648			}
649		}
650
651		if host == nil {
652			return nil, errors.New("host not found in peers table")
653		}
654	}
655
656	if host == nil {
657		return nil, errors.New("unable to fetch host info: invalid control connection")
658	} else if host.invalidConnectAddr() {
659		return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host)
660	}
661
662	return host, nil
663}
664
665func (r *ringDescriber) refreshRing() error {
666	// if we have 0 hosts this will return the previous list of hosts to
667	// attempt to reconnect to the cluster otherwise we would never find
668	// downed hosts again, could possibly have an optimisation to only
669	// try to add new hosts if GetHosts didnt error and the hosts didnt change.
670	hosts, partitioner, err := r.GetHosts()
671	if err != nil {
672		return err
673	}
674
675	prevHosts := r.session.ring.currentHosts()
676
677	// TODO: move this to session
678	for _, h := range hosts {
679		if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
680			continue
681		}
682
683		if host, ok := r.session.ring.addHostIfMissing(h); !ok {
684			r.session.pool.addHost(h)
685			r.session.policy.AddHost(h)
686		} else {
687			host.update(h)
688		}
689		delete(prevHosts, h.ConnectAddress().String())
690	}
691
692	// TODO(zariel): it may be worth having a mutex covering the overall ring state
693	// in a session so that everything sees a consistent state. Becuase as is today
694	// events can come in and due to ordering an UP host could be removed from the cluster
695	for _, host := range prevHosts {
696		r.session.removeHost(host)
697	}
698
699	r.session.metadata.setPartitioner(partitioner)
700	r.session.policy.SetPartitioner(partitioner)
701	return nil
702}
703