1package gocql
2
3import (
4	"context"
5	crand "crypto/rand"
6	"errors"
7	"fmt"
8	"math/rand"
9	"net"
10	"os"
11	"regexp"
12	"strconv"
13	"sync"
14	"sync/atomic"
15	"time"
16)
17
// randr is the package-level pseudo-random source; init seeds it once from
// crypto/rand. shuffleHosts holds mutRandr while drawing from randr, since a
// *rand.Rand created with rand.New is not safe for concurrent use.
var (
	randr    *rand.Rand
	mutRandr sync.Mutex
)
22
23func init() {
24	b := make([]byte, 4)
25	if _, err := crand.Read(b); err != nil {
26		panic(fmt.Sprintf("unable to seed random number generator: %v", err))
27	}
28
29	randr = rand.New(rand.NewSource(int64(readInt(b))))
30}
31
// controlConn tracks the session's control connection: the connection that
// heartBeat probes with OPTIONS frames, that registerEvents subscribes to
// server events, and that query uses to run statements.
//
// started and reconnecting are only touched through 32-bit sync/atomic
// operations. They are kept at the top of the struct; the original note here
// asked for 64-bit alignment so atomics work on 32-bit architectures.
// NOTE(review): that requirement applies to 64-bit atomics; with int32 fields
// it looks like a leftover — confirm before reordering fields.
type controlConn struct {
	// started goes 0 -> 1 when heartBeat begins running and 1 -> -1 in close.
	// reconnecting is 1 while a reconnect call is in progress, 0 otherwise.
	started      int32
	reconnecting int32

	// session is the owning session. conn always holds a *connHost: a nil one
	// from createControlConn until setupConn stores a live connection.
	session *Session
	conn    atomic.Value

	// retry decides whether query re-runs a failed statement.
	retry RetryPolicy

	// quit is sent to by close to stop the heartBeat goroutine.
	quit chan struct{}
}
45
46func createControlConn(session *Session) *controlConn {
47	control := &controlConn{
48		session: session,
49		quit:    make(chan struct{}),
50		retry:   &SimpleRetryPolicy{NumRetries: 3},
51	}
52
53	control.conn.Store((*connHost)(nil))
54
55	return control
56}
57
58func (c *controlConn) heartBeat() {
59	if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
60		return
61	}
62
63	sleepTime := 1 * time.Second
64	timer := time.NewTimer(sleepTime)
65	defer timer.Stop()
66
67	for {
68		timer.Reset(sleepTime)
69
70		select {
71		case <-c.quit:
72			return
73		case <-timer.C:
74		}
75
76		resp, err := c.writeFrame(&writeOptionsFrame{})
77		if err != nil {
78			goto reconn
79		}
80
81		switch resp.(type) {
82		case *supportedFrame:
83			// Everything ok
84			sleepTime = 5 * time.Second
85			continue
86		case error:
87			goto reconn
88		default:
89			panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
90		}
91
92	reconn:
93		// try to connect a bit faster
94		sleepTime = 1 * time.Second
95		c.reconnect(true)
96		continue
97	}
98}
99
// hostLookupPreferV4 is read once at package initialization from the
// GOCQL_HOST_LOOKUP_PREFER_V4 environment variable and is true only when that
// variable is exactly "true". When set, hostInfo keeps only the IPv4 results
// of a DNS lookup, provided at least one IPv4 address was returned.
var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
101
102func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
103	var port int
104	host, portStr, err := net.SplitHostPort(addr)
105	if err != nil {
106		host = addr
107		port = defaultPort
108	} else {
109		port, err = strconv.Atoi(portStr)
110		if err != nil {
111			return nil, err
112		}
113	}
114
115	var hosts []*HostInfo
116
117	// Check if host is a literal IP address
118	if ip := net.ParseIP(host); ip != nil {
119		hosts = append(hosts, &HostInfo{connectAddress: ip, port: port})
120		return hosts, nil
121	}
122
123	// Look up host in DNS
124	ips, err := LookupIP(host)
125	if err != nil {
126		return nil, err
127	} else if len(ips) == 0 {
128		return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
129	}
130
131	// Filter to v4 addresses if any present
132	if hostLookupPreferV4 {
133		var preferredIPs []net.IP
134		for _, v := range ips {
135			if v4 := v.To4(); v4 != nil {
136				preferredIPs = append(preferredIPs, v4)
137			}
138		}
139		if len(preferredIPs) != 0 {
140			ips = preferredIPs
141		}
142	}
143
144	for _, ip := range ips {
145		hosts = append(hosts, &HostInfo{connectAddress: ip, port: port})
146	}
147
148	return hosts, nil
149}
150
151func shuffleHosts(hosts []*HostInfo) []*HostInfo {
152	mutRandr.Lock()
153	perm := randr.Perm(len(hosts))
154	mutRandr.Unlock()
155	shuffled := make([]*HostInfo, len(hosts))
156
157	for i, host := range hosts {
158		shuffled[perm[i]] = host
159	}
160
161	return shuffled
162}
163
164func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
165	// shuffle endpoints so not all drivers will connect to the same initial
166	// node.
167	shuffled := shuffleHosts(endpoints)
168
169	cfg := *c.session.connCfg
170	cfg.disableCoalesce = true
171
172	var err error
173	for _, host := range shuffled {
174		var conn *Conn
175		conn, err = c.session.dial(host, &cfg, c)
176		if err == nil {
177			return conn, nil
178		}
179
180		Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
181	}
182
183	return nil, err
184}
185
// protocolSupportRe extracts the greatest supported protocol version from the
// tail of a server error message; parseProtocolFromError applies it. The
// wording it matches is version dependent and a nightmare to maintain :(
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
188
189func parseProtocolFromError(err error) int {
190	// I really wish this had the actual info in the error frame...
191	matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
192	if len(matches) != 1 || len(matches[0]) != 2 {
193		if verr, ok := err.(*protocolError); ok {
194			return int(verr.frame.Header().version.version())
195		}
196		return 0
197	}
198
199	max, err := strconv.Atoi(matches[0][1])
200	if err != nil {
201		return 0
202	}
203
204	return max
205}
206
207func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
208	hosts = shuffleHosts(hosts)
209
210	connCfg := *c.session.connCfg
211	connCfg.ProtoVersion = 4 // TODO: define maxProtocol
212
213	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
214		// we should never get here, but if we do it means we connected to a
215		// host successfully which means our attempted protocol version worked
216		if !closed {
217			c.Close()
218		}
219	})
220
221	var err error
222	for _, host := range hosts {
223		var conn *Conn
224		conn, err = c.session.dial(host, &connCfg, handler)
225		if conn != nil {
226			conn.Close()
227		}
228
229		if err == nil {
230			return connCfg.ProtoVersion, nil
231		}
232
233		if proto := parseProtocolFromError(err); proto > 0 {
234			return proto, nil
235		}
236	}
237
238	return 0, err
239}
240
241func (c *controlConn) connect(hosts []*HostInfo) error {
242	if len(hosts) == 0 {
243		return errors.New("control: no endpoints specified")
244	}
245
246	conn, err := c.shuffleDial(hosts)
247	if err != nil {
248		return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
249	}
250
251	if err := c.setupConn(conn); err != nil {
252		conn.Close()
253		return fmt.Errorf("control: unable to setup connection: %v", err)
254	}
255
256	// we could fetch the initial ring here and update initial host data. So that
257	// when we return from here we have a ring topology ready to go.
258
259	go c.heartBeat()
260
261	return nil
262}
263
// connHost pairs the live control connection with the host information that
// setupConn fetched over it. It is the value stored in controlConn.conn.
type connHost struct {
	conn *Conn
	host *HostInfo
}
268
269func (c *controlConn) setupConn(conn *Conn) error {
270	if err := c.registerEvents(conn); err != nil {
271		conn.Close()
272		return err
273	}
274
275	// TODO(zariel): do we need to fetch host info everytime
276	// the control conn connects? Surely we have it cached?
277	host, err := conn.localHostInfo(context.TODO())
278	if err != nil {
279		return err
280	}
281
282	ch := &connHost{
283		conn: conn,
284		host: host,
285	}
286
287	c.conn.Store(ch)
288	c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
289
290	return nil
291}
292
293func (c *controlConn) registerEvents(conn *Conn) error {
294	var events []string
295
296	if !c.session.cfg.Events.DisableTopologyEvents {
297		events = append(events, "TOPOLOGY_CHANGE")
298	}
299	if !c.session.cfg.Events.DisableNodeStatusEvents {
300		events = append(events, "STATUS_CHANGE")
301	}
302	if !c.session.cfg.Events.DisableSchemaEvents {
303		events = append(events, "SCHEMA_CHANGE")
304	}
305
306	if len(events) == 0 {
307		return nil
308	}
309
310	framer, err := conn.exec(context.Background(),
311		&writeRegisterFrame{
312			events: events,
313		}, nil)
314	if err != nil {
315		return err
316	}
317
318	frame, err := framer.parseFrame()
319	if err != nil {
320		return err
321	} else if _, ok := frame.(*readyFrame); !ok {
322		return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
323	}
324
325	return nil
326}
327
328func (c *controlConn) reconnect(refreshring bool) {
329	if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
330		return
331	}
332	defer atomic.StoreInt32(&c.reconnecting, 0)
333	// TODO: simplify this function, use session.ring to get hosts instead of the
334	// connection pool
335
336	var host *HostInfo
337	ch := c.getConn()
338	if ch != nil {
339		host = ch.host
340		ch.conn.Close()
341	}
342
343	var newConn *Conn
344	if host != nil {
345		// try to connect to the old host
346		conn, err := c.session.connect(host, c)
347		if err != nil {
348			// host is dead
349			// TODO: this is replicated in a few places
350			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
351				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
352			}
353		} else {
354			newConn = conn
355		}
356	}
357
358	// TODO: should have our own round-robin for hosts so that we can try each
359	// in succession and guarantee that we get a different host each time.
360	if newConn == nil {
361		host := c.session.ring.rrHost()
362		if host == nil {
363			c.connect(c.session.ring.endpoints)
364			return
365		}
366
367		var err error
368		newConn, err = c.session.connect(host, c)
369		if err != nil {
370			// TODO: add log handler for things like this
371			return
372		}
373	}
374
375	if err := c.setupConn(newConn); err != nil {
376		newConn.Close()
377		Logger.Printf("gocql: control unable to register events: %v\n", err)
378		return
379	}
380
381	if refreshring {
382		c.session.hostSource.refreshRing()
383	}
384}
385
386func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
387	if !closed {
388		return
389	}
390
391	oldConn := c.getConn()
392	if oldConn.conn != conn {
393		return
394	}
395
396	c.reconnect(false)
397}
398
399func (c *controlConn) getConn() *connHost {
400	return c.conn.Load().(*connHost)
401}
402
403func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
404	ch := c.getConn()
405	if ch == nil {
406		return nil, errNoControl
407	}
408
409	framer, err := ch.conn.exec(context.Background(), w, nil)
410	if err != nil {
411		return nil, err
412	}
413
414	return framer.parseFrame()
415}
416
417func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
418	const maxConnectAttempts = 5
419	connectAttempts := 0
420
421	for i := 0; i < maxConnectAttempts; i++ {
422		ch := c.getConn()
423		if ch == nil {
424			if connectAttempts > maxConnectAttempts {
425				break
426			}
427
428			connectAttempts++
429
430			c.reconnect(false)
431			continue
432		}
433
434		return fn(ch)
435	}
436
437	return &Iter{err: errNoControl}
438}
439
440func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
441	return c.withConnHost(func(ch *connHost) *Iter {
442		return fn(ch.conn)
443	})
444}
445
446// query will return nil if the connection is closed or nil
447func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
448	q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
449
450	for {
451		iter = c.withConn(func(conn *Conn) *Iter {
452			return conn.executeQuery(context.TODO(), q)
453		})
454
455		if gocqlDebug && iter.err != nil {
456			Logger.Printf("control: error executing %q: %v\n", statement, iter.err)
457		}
458
459		q.AddAttempts(1, c.getConn().host)
460		if iter.err == nil || !c.retry.Attempt(q) {
461			break
462		}
463	}
464
465	return
466}
467
468func (c *controlConn) awaitSchemaAgreement() error {
469	return c.withConn(func(conn *Conn) *Iter {
470		return &Iter{err: conn.awaitSchemaAgreement(context.TODO())}
471	}).err
472}
473
474func (c *controlConn) close() {
475	if atomic.CompareAndSwapInt32(&c.started, 1, -1) {
476		c.quit <- struct{}{}
477	}
478
479	ch := c.getConn()
480	if ch != nil {
481		ch.conn.Close()
482	}
483}
484
// errNoControl is returned by writeFrame, and carried in the Iter returned by
// withConnHost, when no control connection is available.
var errNoControl = errors.New("gocql: no control connection available")
486