1package gocql
2
3import (
4	"context"
5	crand "crypto/rand"
6	"errors"
7	"fmt"
8	"math/rand"
9	"net"
10	"os"
11	"regexp"
12	"strconv"
13	"sync"
14	"sync/atomic"
15	"time"
16)
17
18var (
19	randr    *rand.Rand
20	mutRandr sync.Mutex
21)
22
23func init() {
24	b := make([]byte, 4)
25	if _, err := crand.Read(b); err != nil {
26		panic(fmt.Sprintf("unable to seed random number generator: %v", err))
27	}
28
29	randr = rand.New(rand.NewSource(int64(readInt(b))))
30}
31
// controlConn owns the session's control connection. The current connection
// and the host it points at are published through conn, heartBeat probes it
// periodically, and reconnect replaces it when it fails.
//
// started and reconnecting are manipulated with sync/atomic in this file.
// NOTE(review): an earlier comment here asked for 64-bit alignment of "the
// atomic variable", but both fields are int32 and only need 4-byte alignment
// for atomic operations; confirm whether their placement still matters.
type controlConn struct {
	// started: 0 before heartBeat runs, 1 once heartBeat has claimed the loop,
	// and close moves it to -1. reconnecting is 1 while reconnect is running,
	// so concurrent reconnect calls return immediately.
	started      int32
	reconnecting int32

	// session is the session this control connection serves. conn holds a
	// *connHost: a nil *connHost until setupConn stores a live one.
	session *Session
	conn    atomic.Value

	// retry decides whether query re-runs a failed statement.
	retry RetryPolicy

	// quit is signalled by close to stop the heartBeat loop.
	quit chan struct{}
}
45
46func createControlConn(session *Session) *controlConn {
47	control := &controlConn{
48		session: session,
49		quit:    make(chan struct{}),
50		retry:   &SimpleRetryPolicy{NumRetries: 3},
51	}
52
53	control.conn.Store((*connHost)(nil))
54
55	return control
56}
57
58func (c *controlConn) heartBeat() {
59	if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
60		return
61	}
62
63	sleepTime := 1 * time.Second
64	timer := time.NewTimer(sleepTime)
65	defer timer.Stop()
66
67	for {
68		timer.Reset(sleepTime)
69
70		select {
71		case <-c.quit:
72			return
73		case <-timer.C:
74		}
75
76		resp, err := c.writeFrame(&writeOptionsFrame{})
77		if err != nil {
78			goto reconn
79		}
80
81		switch resp.(type) {
82		case *supportedFrame:
83			// Everything ok
84			sleepTime = 5 * time.Second
85			continue
86		case error:
87			goto reconn
88		default:
89			panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
90		}
91
92	reconn:
93		// try to connect a bit faster
94		sleepTime = 1 * time.Second
95		c.reconnect(true)
96		continue
97	}
98}
99
100var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
101
102func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
103	var port int
104	host, portStr, err := net.SplitHostPort(addr)
105	if err != nil {
106		host = addr
107		port = defaultPort
108	} else {
109		port, err = strconv.Atoi(portStr)
110		if err != nil {
111			return nil, err
112		}
113	}
114
115	var hosts []*HostInfo
116
117	// Check if host is a literal IP address
118	if ip := net.ParseIP(host); ip != nil {
119		hosts = append(hosts, &HostInfo{connectAddress: ip, port: port})
120		return hosts, nil
121	}
122
123	// Look up host in DNS
124	ips, err := net.LookupIP(host)
125	if err != nil {
126		return nil, err
127	} else if len(ips) == 0 {
128		return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
129	}
130
131	// Filter to v4 addresses if any present
132	if hostLookupPreferV4 {
133		var preferredIPs []net.IP
134		for _, v := range ips {
135			if v4 := v.To4(); v4 != nil {
136				preferredIPs = append(preferredIPs, v4)
137			}
138		}
139		if len(preferredIPs) != 0 {
140			ips = preferredIPs
141		}
142	}
143
144	for _, ip := range ips {
145		hosts = append(hosts, &HostInfo{connectAddress: ip, port: port})
146	}
147
148	return hosts, nil
149}
150
151func shuffleHosts(hosts []*HostInfo) []*HostInfo {
152	mutRandr.Lock()
153	perm := randr.Perm(len(hosts))
154	mutRandr.Unlock()
155	shuffled := make([]*HostInfo, len(hosts))
156
157	for i, host := range hosts {
158		shuffled[perm[i]] = host
159	}
160
161	return shuffled
162}
163
164func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
165	// shuffle endpoints so not all drivers will connect to the same initial
166	// node.
167	shuffled := shuffleHosts(endpoints)
168
169	var err error
170	for _, host := range shuffled {
171		var conn *Conn
172		conn, err = c.session.connect(host, c)
173		if err == nil {
174			return conn, nil
175		}
176
177		Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
178	}
179
180	return nil, err
181}
182
183// this is going to be version dependant and a nightmare to maintain :(
184var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
185
186func parseProtocolFromError(err error) int {
187	// I really wish this had the actual info in the error frame...
188	matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
189	if len(matches) != 1 || len(matches[0]) != 2 {
190		if verr, ok := err.(*protocolError); ok {
191			return int(verr.frame.Header().version.version())
192		}
193		return 0
194	}
195
196	max, err := strconv.Atoi(matches[0][1])
197	if err != nil {
198		return 0
199	}
200
201	return max
202}
203
204func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
205	hosts = shuffleHosts(hosts)
206
207	connCfg := *c.session.connCfg
208	connCfg.ProtoVersion = 4 // TODO: define maxProtocol
209
210	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
211		// we should never get here, but if we do it means we connected to a
212		// host successfully which means our attempted protocol version worked
213		if !closed {
214			c.Close()
215		}
216	})
217
218	var err error
219	for _, host := range hosts {
220		var conn *Conn
221		conn, err = c.session.dial(host, &connCfg, handler)
222		if conn != nil {
223			conn.Close()
224		}
225
226		if err == nil {
227			return connCfg.ProtoVersion, nil
228		}
229
230		if proto := parseProtocolFromError(err); proto > 0 {
231			return proto, nil
232		}
233	}
234
235	return 0, err
236}
237
238func (c *controlConn) connect(hosts []*HostInfo) error {
239	if len(hosts) == 0 {
240		return errors.New("control: no endpoints specified")
241	}
242
243	conn, err := c.shuffleDial(hosts)
244	if err != nil {
245		return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
246	}
247
248	if err := c.setupConn(conn); err != nil {
249		conn.Close()
250		return fmt.Errorf("control: unable to setup connection: %v", err)
251	}
252
253	// we could fetch the initial ring here and update initial host data. So that
254	// when we return from here we have a ring topology ready to go.
255
256	go c.heartBeat()
257
258	return nil
259}
260
// connHost pairs the connection currently installed as the control connection
// with the HostInfo that setupConn obtained for it from conn.localHostInfo().
type connHost struct {
	conn *Conn
	host *HostInfo
}
265
266func (c *controlConn) setupConn(conn *Conn) error {
267	if err := c.registerEvents(conn); err != nil {
268		conn.Close()
269		return err
270	}
271
272	// TODO(zariel): do we need to fetch host info everytime
273	// the control conn connects? Surely we have it cached?
274	host, err := conn.localHostInfo()
275	if err != nil {
276		return err
277	}
278
279	ch := &connHost{
280		conn: conn,
281		host: host,
282	}
283
284	c.conn.Store(ch)
285	c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
286
287	return nil
288}
289
290func (c *controlConn) registerEvents(conn *Conn) error {
291	var events []string
292
293	if !c.session.cfg.Events.DisableTopologyEvents {
294		events = append(events, "TOPOLOGY_CHANGE")
295	}
296	if !c.session.cfg.Events.DisableNodeStatusEvents {
297		events = append(events, "STATUS_CHANGE")
298	}
299	if !c.session.cfg.Events.DisableSchemaEvents {
300		events = append(events, "SCHEMA_CHANGE")
301	}
302
303	if len(events) == 0 {
304		return nil
305	}
306
307	framer, err := conn.exec(context.Background(),
308		&writeRegisterFrame{
309			events: events,
310		}, nil)
311	if err != nil {
312		return err
313	}
314
315	frame, err := framer.parseFrame()
316	if err != nil {
317		return err
318	} else if _, ok := frame.(*readyFrame); !ok {
319		return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
320	}
321
322	return nil
323}
324
325func (c *controlConn) reconnect(refreshring bool) {
326	if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
327		return
328	}
329	defer atomic.StoreInt32(&c.reconnecting, 0)
330	// TODO: simplify this function, use session.ring to get hosts instead of the
331	// connection pool
332
333	var host *HostInfo
334	ch := c.getConn()
335	if ch != nil {
336		host = ch.host
337		ch.conn.Close()
338	}
339
340	var newConn *Conn
341	if host != nil {
342		// try to connect to the old host
343		conn, err := c.session.connect(host, c)
344		if err != nil {
345			// host is dead
346			// TODO: this is replicated in a few places
347			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
348				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
349			}
350		} else {
351			newConn = conn
352		}
353	}
354
355	// TODO: should have our own round-robin for hosts so that we can try each
356	// in succession and guarantee that we get a different host each time.
357	if newConn == nil {
358		host := c.session.ring.rrHost()
359		if host == nil {
360			c.connect(c.session.ring.endpoints)
361			return
362		}
363
364		var err error
365		newConn, err = c.session.connect(host, c)
366		if err != nil {
367			// TODO: add log handler for things like this
368			return
369		}
370	}
371
372	if err := c.setupConn(newConn); err != nil {
373		newConn.Close()
374		Logger.Printf("gocql: control unable to register events: %v\n", err)
375		return
376	}
377
378	if refreshring {
379		c.session.hostSource.refreshRing()
380	}
381}
382
383func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
384	if !closed {
385		return
386	}
387
388	oldConn := c.getConn()
389	if oldConn.conn != conn {
390		return
391	}
392
393	c.reconnect(false)
394}
395
396func (c *controlConn) getConn() *connHost {
397	return c.conn.Load().(*connHost)
398}
399
400func (c *controlConn) writeFrame(w frameWriter) (frame, error) {
401	ch := c.getConn()
402	if ch == nil {
403		return nil, errNoControl
404	}
405
406	framer, err := ch.conn.exec(context.Background(), w, nil)
407	if err != nil {
408		return nil, err
409	}
410
411	return framer.parseFrame()
412}
413
414func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter {
415	const maxConnectAttempts = 5
416	connectAttempts := 0
417
418	for i := 0; i < maxConnectAttempts; i++ {
419		ch := c.getConn()
420		if ch == nil {
421			if connectAttempts > maxConnectAttempts {
422				break
423			}
424
425			connectAttempts++
426
427			c.reconnect(false)
428			continue
429		}
430
431		return fn(ch)
432	}
433
434	return &Iter{err: errNoControl}
435}
436
437func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter {
438	return c.withConnHost(func(ch *connHost) *Iter {
439		return fn(ch.conn)
440	})
441}
442
443// query will return nil if the connection is closed or nil
444func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
445	q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
446
447	for {
448		iter = c.withConn(func(conn *Conn) *Iter {
449			return conn.executeQuery(q)
450		})
451
452		if gocqlDebug && iter.err != nil {
453			Logger.Printf("control: error executing %q: %v\n", statement, iter.err)
454		}
455
456		q.attempts++
457		if iter.err == nil || !c.retry.Attempt(q) {
458			break
459		}
460	}
461
462	return
463}
464
465func (c *controlConn) awaitSchemaAgreement() error {
466	return c.withConn(func(conn *Conn) *Iter {
467		return &Iter{err: conn.awaitSchemaAgreement()}
468	}).err
469}
470
471func (c *controlConn) close() {
472	if atomic.CompareAndSwapInt32(&c.started, 1, -1) {
473		c.quit <- struct{}{}
474	}
475
476	ch := c.getConn()
477	if ch != nil {
478		ch.conn.Close()
479	}
480}
481
// errNoControl reports that no control connection is available. writeFrame
// returns it directly; withConnHost returns it inside an Iter.
var errNoControl = errors.New("gocql: no control connection available")
483