// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gocql

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// SetHosts is implemented by types that want to receive the current
// set of hosts in the cluster.
type SetHosts interface {
	SetHosts(hosts []*HostInfo)
}

// SetPartitioner is implemented by types that want to receive the
// cluster's partitioner value.
type SetPartitioner interface {
	SetPartitioner(partitioner string)
}
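
// As a hedged, hypothetical sketch (topologyObserver is not part of this
// package), a component opts in to cluster metadata updates by implementing
// these interfaces:
//
//	type topologyObserver struct{}
//
//	func (topologyObserver) SetHosts(hosts []*HostInfo) {
//		for _, h := range hosts {
//			Logger.Printf("known host: %v", h.ConnectAddress())
//		}
//	}
//
//	func (topologyObserver) SetPartitioner(p string) {
//		Logger.Printf("partitioner: %s", p)
//	}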

func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
	if sslOpts.Config == nil {
		sslOpts.Config = &tls.Config{}
	}

	// CA cert is optional
	if sslOpts.CaPath != "" {
		if sslOpts.RootCAs == nil {
			sslOpts.RootCAs = x509.NewCertPool()
		}

		pem, err := ioutil.ReadFile(sslOpts.CaPath)
		if err != nil {
			return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
		}

		if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
			return nil, errors.New("connectionpool: failed parsing CA certs")
		}
	}

	if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
		mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
		if err != nil {
			return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
		}
		sslOpts.Certificates = append(sslOpts.Certificates, mycert)
	}

	sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification

	// return a clone to avoid a race with later mutation of sslOpts.Config
	return sslOpts.Config.Clone(), nil
}
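
// A minimal, hedged sketch of how SslOptions is typically populated before
// it reaches setupTLSConfig via the cluster config; the address and paths
// are placeholders, not real values:
//
//	cluster := NewCluster("192.168.1.1")
//	cluster.SslOpts = &SslOptions{
//		CertPath:               "/path/to/cert.pem",
//		KeyPath:                "/path/to/key.pem",
//		CaPath:                 "/path/to/ca.pem",
//		EnableHostVerification: true,
//	}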

type policyConnPool struct {
	session *Session

	port     int
	numConns int
	keyspace string

	mu            sync.RWMutex
	hostConnPools map[string]*hostConnPool

	endpoints []string
}

func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
	var (
		err       error
		tlsConfig *tls.Config
	)

	// TODO(zariel): move tls config setup into session init.
	if cfg.SslOpts != nil {
		tlsConfig, err = setupTLSConfig(cfg.SslOpts)
		if err != nil {
			return nil, err
		}
	}

	return &ConnConfig{
		ProtoVersion:    cfg.ProtoVersion,
		CQLVersion:      cfg.CQLVersion,
		Timeout:         cfg.Timeout,
		ConnectTimeout:  cfg.ConnectTimeout,
		Compressor:      cfg.Compressor,
		Authenticator:   cfg.Authenticator,
		AuthProvider:    cfg.AuthProvider,
		Keepalive:       cfg.SocketKeepalive,
		tlsConfig:       tlsConfig,
		disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS.
	}, nil
}

func newPolicyConnPool(session *Session) *policyConnPool {
	// create the pool
	pool := &policyConnPool{
		session:       session,
		port:          session.cfg.Port,
		numConns:      session.cfg.NumConns,
		keyspace:      session.cfg.Keyspace,
		hostConnPools: map[string]*hostConnPool{},
	}

	pool.endpoints = make([]string, len(session.cfg.Hosts))
	copy(pool.endpoints, session.cfg.Hosts)

	return pool
}

func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
	p.mu.Lock()
	defer p.mu.Unlock()

	toRemove := make(map[string]struct{})
	for addr := range p.hostConnPools {
		toRemove[addr] = struct{}{}
	}

	pools := make(chan *hostConnPool)
	createCount := 0
	for _, host := range hosts {
		if !host.IsUp() {
			// don't create a connection pool for a down host
			continue
		}
		ip := host.ConnectAddress().String()
		if _, exists := p.hostConnPools[ip]; exists {
			// still have this host, so don't remove it
			delete(toRemove, ip)
			continue
		}

		createCount++
		go func(host *HostInfo) {
			// create a connection pool for the host
			pools <- newHostConnPool(
				p.session,
				host,
				p.port,
				p.numConns,
				p.keyspace,
			)
		}(host)
	}

	// add created pools
	for createCount > 0 {
		pool := <-pools
		createCount--
		if pool.Size() > 0 {
			// add the pool only if it has connections available; key by the
			// textual address so it matches the other map lookups (a raw
			// string(net.IP) conversion would never match them)
			p.hostConnPools[pool.host.ConnectAddress().String()] = pool
		}
	}

	for addr := range toRemove {
		pool := p.hostConnPools[addr]
		delete(p.hostConnPools, addr)
		go pool.Close()
	}
}
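
// SetHosts fans out one goroutine per new host and fans results back in over
// an unbuffered channel, counting sends rather than using a sync.WaitGroup
// so each pool can be inspected as it arrives. A minimal sketch of the same
// pattern, with hypothetical buildPool/newHosts names:
//
//	results := make(chan *hostConnPool)
//	pending := 0
//	for _, host := range newHosts {
//		pending++
//		go func(h *HostInfo) { results <- buildPool(h) }(host)
//	}
//	for pending > 0 {
//		pool := <-results
//		pending--
//		// inspect or keep pool here
//	}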

func (p *policyConnPool) Size() int {
	p.mu.RLock()
	count := 0
	for _, pool := range p.hostConnPools {
		count += pool.Size()
	}
	p.mu.RUnlock()

	return count
}

func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
	ip := host.ConnectAddress().String()
	p.mu.RLock()
	pool, ok = p.hostConnPools[ip]
	p.mu.RUnlock()
	return
}

func (p *policyConnPool) Close() {
	p.mu.Lock()
	defer p.mu.Unlock()

	// close the pools
	for addr, pool := range p.hostConnPools {
		delete(p.hostConnPools, addr)
		pool.Close()
	}
}

func (p *policyConnPool) addHost(host *HostInfo) {
	ip := host.ConnectAddress().String()
	p.mu.Lock()
	pool, ok := p.hostConnPools[ip]
	if !ok {
		pool = newHostConnPool(
			p.session,
			host,
			host.Port(), // TODO: if port == 0 use pool.port?
			p.numConns,
			p.keyspace,
		)

		p.hostConnPools[ip] = pool
	}
	p.mu.Unlock()

	pool.fill()
}

func (p *policyConnPool) removeHost(ip net.IP) {
	k := ip.String()
	p.mu.Lock()
	pool, ok := p.hostConnPools[k]
	if !ok {
		p.mu.Unlock()
		return
	}

	delete(p.hostConnPools, k)
	p.mu.Unlock()

	go pool.Close()
}

func (p *policyConnPool) hostUp(host *HostInfo) {
	// TODO(zariel): have a set of up hosts and down hosts, we can internally
	// detect down hosts, then try to reconnect to them.
	p.addHost(host)
}

func (p *policyConnPool) hostDown(ip net.IP) {
	// TODO(zariel): mark the host as down so we can try to connect to it
	// later; for now just treat it as removed.
	p.removeHost(ip)
}

// hostConnPool is a connection pool for a single host.
// Connections are selected by preferring the one with the most available
// streams (the least busy connection); see Pick.
type hostConnPool struct {
	session  *Session
	host     *HostInfo
	port     int
	addr     string
	size     int
	keyspace string
	// mu protects conns, closed, and filling
	mu      sync.RWMutex
	conns   []*Conn
	closed  bool
	filling bool

	pos uint32
}

func (h *hostConnPool) String() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
		h.filling, h.closed, len(h.conns), h.size, h.host)
}

func newHostConnPool(session *Session, host *HostInfo, port, size int,
	keyspace string) *hostConnPool {

	pool := &hostConnPool{
		session:  session,
		host:     host,
		port:     port,
		addr:     (&net.TCPAddr{IP: host.ConnectAddress(), Port: host.Port()}).String(),
		size:     size,
		keyspace: keyspace,
		conns:    make([]*Conn, 0, size),
		filling:  false,
		closed:   false,
	}

	// the pool is not filled or connected
	return pool
}

// Pick returns a connection from this pool, preferring the one with the most
// available streams; it may trigger an asynchronous fill if the pool is
// below its target size.
func (pool *hostConnPool) Pick() *Conn {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	if pool.closed {
		return nil
	}

	size := len(pool.conns)
	if size < pool.size {
		// try to fill the pool
		go pool.fill()

		if size == 0 {
			return nil
		}
	}

	pos := int(atomic.AddUint32(&pool.pos, 1) - 1)

	var (
		leastBusyConn    *Conn
		streamsAvailable int
	)

	// find the conn with the most available streams; the stream counts are
	// read without synchronization, so this is intentionally racy
	for i := 0; i < size; i++ {
		conn := pool.conns[(pos+i)%size]
		if streams := conn.AvailableStreams(); streams > streamsAvailable {
			leastBusyConn = conn
			streamsAvailable = streams
		}
	}

	return leastBusyConn
}
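
// For example (illustrative numbers only): with size=3 connections whose
// available-stream counts are [10, 2, 7] and the atomic counter yielding
// pos=1, Pick scans conns[1], conns[2], conns[0] and returns conns[0], since
// 10 is the highest count seen. The rotating start point spreads scan order
// across callers while the comparison still favors the least busy conn.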

// Size returns the number of connections currently active in the pool
func (pool *hostConnPool) Size() int {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	return len(pool.conns)
}

// Close the connection pool
func (pool *hostConnPool) Close() {
	pool.mu.Lock()

	if pool.closed {
		pool.mu.Unlock()
		return
	}
	pool.closed = true

	// ensure we don't try to reacquire the lock in handleError
	// TODO: improve this as the following can happen
	// 1) we have locked pool.mu write lock
	// 2) conn.Close calls conn.closeWithError(nil)
	// 3) conn.closeWithError calls conn.Close() which returns an error
	// 4) conn.closeWithError calls pool.HandleError with the error from conn.Close
	// 5) pool.HandleError tries to lock pool.mu
	// deadlock

	// empty the pool
	conns := pool.conns
	pool.conns = nil

	pool.mu.Unlock()

	// close the connections
	for _, conn := range conns {
		conn.Close()
	}
}

// Fill the connection pool
func (pool *hostConnPool) fill() {
	pool.mu.RLock()
	// avoid filling a closed pool, or concurrent filling
	if pool.closed || pool.filling {
		pool.mu.RUnlock()
		return
	}

	// determine the filling work to be done
	startCount := len(pool.conns)
	fillCount := pool.size - startCount

	// avoid filling a full (or overfull) pool
	if fillCount <= 0 {
		pool.mu.RUnlock()
		return
	}

	// switch from read to write lock
	pool.mu.RUnlock()
	pool.mu.Lock()

	// double check everything since the lock was released
	startCount = len(pool.conns)
	fillCount = pool.size - startCount
	if pool.closed || pool.filling || fillCount <= 0 {
		// looks like another goroutine already beat this
		// goroutine to the filling
		pool.mu.Unlock()
		return
	}

	// ok, fill the pool
	pool.filling = true

	// allow others to access the pool while filling
	pool.mu.Unlock()
	// only this goroutine should make calls to fill/empty the pool at this
	// point until after this goroutine or its subordinates call
	// fillingStopped

	// fill only the first connection synchronously
	if startCount == 0 {
		err := pool.connect()
		pool.logConnectErr(err)

		if err != nil {
			// probably unreachable host
			pool.fillingStopped(true)

			// this is called with the connection pool mutex held; this call
			// will then recursively try to lock it again. FIXME
			if pool.session.cfg.ConvictionPolicy.AddFailure(err, pool.host) {
				go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.port)
			}
			return
		}

		// filled one
		fillCount--
	}

	// fill the rest of the pool asynchronously
	go func() {
		err := pool.connectMany(fillCount)

		// mark the end of filling
		pool.fillingStopped(err != nil)
	}()
}
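
// fill takes a read-lock fast path, then drops the read lock and re-checks
// everything under the write lock, because Go's sync.RWMutex cannot be
// upgraded atomically. A minimal sketch of that pattern with hypothetical
// workNeeded/claimWork helpers:
//
//	mu.RLock()
//	if !workNeeded() {
//		mu.RUnlock()
//		return
//	}
//	mu.RUnlock()
//
//	mu.Lock()
//	if !workNeeded() { // re-check: state may have changed between locks
//		mu.Unlock()
//		return
//	}
//	claimWork()
//	mu.Unlock()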

func (pool *hostConnPool) logConnectErr(err error) {
	if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
		// connection refused
		// these are typical during a node outage so avoid log spam.
		if gocqlDebug {
			Logger.Printf("unable to dial %q: %v\n", pool.host.ConnectAddress(), err)
		}
	} else if err != nil {
		// unexpected error
		Logger.Printf("error: failed to connect to %s due to error: %v", pool.addr, err)
	}
}

// transition back to a not-filling state.
func (pool *hostConnPool) fillingStopped(hadError bool) {
	if hadError {
		// sleep for a random 31-130ms to avoid back-to-back filling; this
		// gives the host some time to recover between failed attempts
		time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
	}

	pool.mu.Lock()
	pool.filling = false
	pool.mu.Unlock()
}

// connectMany creates new connections concurrently.
func (pool *hostConnPool) connectMany(count int) error {
	if count == 0 {
		return nil
	}
	var (
		wg         sync.WaitGroup
		mu         sync.Mutex
		connectErr error
	)
	wg.Add(count)
	for i := 0; i < count; i++ {
		go func() {
			defer wg.Done()
			err := pool.connect()
			pool.logConnectErr(err)
			if err != nil {
				mu.Lock()
				connectErr = err
				mu.Unlock()
			}
		}()
	}
	// wait for all connection attempts to complete; connectErr holds the
	// last error observed, if any
	wg.Wait()

	return connectErr
}

// create a new connection to the host and add it to the pool
func (pool *hostConnPool) connect() (err error) {
	// TODO: provide a more robust connection retry mechanism, we should also
	// be able to detect hosts that come up by trying to connect to downed ones.
	// try to connect
	var conn *Conn
	reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
	for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
		conn, err = pool.session.connect(pool.host, pool)
		if err == nil {
			break
		}
		if opErr, isOpErr := err.(*net.OpError); isOpErr {
			// if the error is not temporary (ex: network unreachable),
			// don't retry
			if !opErr.Temporary() {
				break
			}
		}
		if gocqlDebug {
			Logger.Printf("connection failed %q: %v, reconnecting with %T\n",
				pool.host.ConnectAddress(), err, reconnectionPolicy)
		}
		time.Sleep(reconnectionPolicy.GetInterval(i))
	}

	if err != nil {
		return err
	}

	if pool.keyspace != "" {
		// set the keyspace
		if err = conn.UseKeyspace(pool.keyspace); err != nil {
			conn.Close()
			return err
		}
	}

	// add the Conn to the pool
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		conn.Close()
		return nil
	}

	pool.conns = append(pool.conns, conn)

	return nil
}
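
// The retry loop above is driven entirely by the configured
// ReconnectionPolicy (GetMaxRetries and GetInterval). A hedged sketch of
// wiring one up, assuming the ConstantReconnectionPolicy implementation
// provided elsewhere in this package; the address and numbers are
// placeholders:
//
//	cluster := NewCluster("192.168.1.1")
//	cluster.ReconnectionPolicy = &ConstantReconnectionPolicy{
//		MaxRetries: 3,
//		Interval:   2 * time.Second,
//	}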

// handle any error from a Conn
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		// still an open connection, so continue using it
		return
	}

	// TODO: track the number of errors per host and detect when a host is dead,
	// then also have something which can detect when a host comes back.
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		// pool closed
		return
	}

	// find the connection index
	for i, candidate := range pool.conns {
		if candidate == conn {
			// swap-remove the connection; order is not preserved
			pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1]

			// lost a connection, so fill the pool
			go pool.fill()
			break
		}
	}
}
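
// The removal above uses Go's swap-remove idiom: overwrite the element being
// removed with the last element, then shrink the slice by one. This is O(1)
// at the cost of ordering. The equivalent spelled-out form:
//
//	last := len(pool.conns) - 1
//	pool.conns[i] = pool.conns[last]
//	pool.conns = pool.conns[:last]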