1// Copyright (c) 2012 The gocql Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gocql
6
7import (
8	"bufio"
9	"context"
10	"crypto/tls"
11	"errors"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"net"
16	"strconv"
17	"strings"
18	"sync"
19	"sync/atomic"
20	"time"
21
22	"github.com/gocql/gocql/internal/lru"
23	"github.com/gocql/gocql/internal/streams"
24)
25
// approvedAuthenticators lists the server-side authenticator class names that
// PasswordAuthenticator agrees to answer; approve checks membership.
var approvedAuthenticators = [...]string{
	"org.apache.cassandra.auth.PasswordAuthenticator",
	"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
	"com.datastax.bdp.cassandra.auth.DseAuthenticator",
	"io.aiven.cassandra.auth.AivenAuthenticator",
}
34
35func approve(authenticator string) bool {
36	for _, s := range approvedAuthenticators {
37		if authenticator == s {
38			return true
39		}
40	}
41	return false
42}
43
// JoinHostPort returns an address string that can be used by gocql.Conn to
// form a connection with a host.
//
// Surrounding whitespace is trimmed from addr. If the trimmed value already
// parses as "host:port" it is returned unchanged; otherwise port is attached
// with net.JoinHostPort, which brackets IPv6 literals.
func JoinHostPort(addr string, port int) string {
	trimmed := strings.TrimSpace(addr)
	if _, _, splitErr := net.SplitHostPort(trimmed); splitErr == nil {
		return trimmed
	}
	return net.JoinHostPort(trimmed, strconv.Itoa(port))
}
53
// Authenticator drives the client side of the authentication exchange during
// connection startup.
//
// Challenge is first called on the connection's Authenticator with the
// authenticator class name announced by the server; the returned resp is
// sent back to the server. If the server replies with another challenge,
// Challenge is called with the challenge payload on the Authenticator
// returned by the previous call. When the server reports success, Success is
// called with the success payload on the last returned Authenticator, if it
// is non-nil.
type Authenticator interface {
	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
	Success(data []byte) error
}
58
// PasswordAuthenticator is an Authenticator that answers the server's
// authentication request with a plain username/password token. It only
// accepts the server authenticator classes listed in approvedAuthenticators.
type PasswordAuthenticator struct {
	Username string
	Password string
}
63
64func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
65	if !approve(string(req)) {
66		return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
67	}
68	resp := make([]byte, 2+len(p.Username)+len(p.Password))
69	resp[0] = 0
70	copy(resp[1:], p.Username)
71	resp[len(p.Username)+1] = 0
72	copy(resp[2+len(p.Username):], p.Password)
73	return resp, nil, nil
74}
75
// Success ignores the server's success payload; password authentication
// needs no further processing, so it always returns nil.
func (p PasswordAuthenticator) Success(data []byte) error {
	return nil
}
79
// SslOptions holds the TLS settings for connections to Cassandra nodes.
type SslOptions struct {
	// Embedded TLS configuration. NOTE(review): how this is merged with the
	// path fields below happens outside this file — confirm there.
	*tls.Config

	// CertPath and KeyPath are optional depending on server
	// config, but both fields must be omitted to avoid using a
	// client certificate
	CertPath string
	KeyPath  string
	CaPath   string //optional depending on server config
	// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
	// This option is basically the inverse of InsecureSkipVerify
	// See InsecureSkipVerify in https://pkg.go.dev/crypto/tls for more info
	EnableHostVerification bool
}
94
// ConnConfig holds the per-connection settings used when dialing a node.
type ConnConfig struct {
	// ProtoVersion is the native protocol version; it sets Conn.version and
	// the size of the stream ID space.
	ProtoVersion int
	// CQLVersion is sent as CQL_VERSION in the STARTUP request.
	CQLVersion string
	// Timeout is the read/request timeout and write deadline used after the
	// startup handshake.
	Timeout time.Duration
	// ConnectTimeout bounds the TCP/TLS dial and the startup handshake.
	ConnectTimeout time.Duration
	// Compressor is requested in STARTUP; it is dropped for the connection if
	// the server does not list it as supported.
	Compressor Compressor
	// Authenticator is used when AuthProvider is nil.
	Authenticator Authenticator
	// AuthProvider, if set, supplies the Authenticator for each host and takes
	// precedence over Authenticator.
	AuthProvider func(h *HostInfo) (Authenticator, error)
	// Keepalive, when > 0, sets the TCP keep-alive period on the dialer.
	Keepalive time.Duration

	// tlsConfig, when non-nil, makes the dial use TLS. It is shared across
	// connections and must not be modified after first use.
	tlsConfig *tls.Config
	// disableCoalesce prevents the write coalescer from being installed.
	disableCoalesce bool
}
108
// ConnErrorHandler is notified of errors on a Conn. When a connection shuts
// down, closeWithError calls HandleError with closed set to true.
type ConnErrorHandler interface {
	HandleError(conn *Conn, err error, closed bool)
}
112
// connErrorHandlerFn adapts a plain function to the ConnErrorHandler interface.
type connErrorHandlerFn func(conn *Conn, err error, closed bool)
114
// HandleError implements ConnErrorHandler by calling fn with its arguments.
func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
	fn(conn, err, closed)
}
118
// TimeoutLimit, if greater than zero, is how many request timeouts a
// connection tolerates: once more than TimeoutLimit requests on a connection
// have timed out, the connection is closed with ErrTooManyTimeouts and
// restarted. This is to prevent a single query timeout from killing a
// connection which may be serving more queries just fine.
// Default is 0 (no limit). It is read without synchronization, so it should
// not be changed concurrently with queries.
//
// Deprecated: marked as deprecated by the package authors.
var TimeoutLimit int64 = 0
126
127// Conn is a single connection to a Cassandra node. It can be used to execute
128// queries, but users are usually advised to use a more reliable, higher
129// level API.
130type Conn struct {
131	conn net.Conn
132	r    *bufio.Reader
133	w    io.Writer
134
135	timeout       time.Duration
136	cfg           *ConnConfig
137	frameObserver FrameHeaderObserver
138
139	headerBuf [maxFrameHeaderSize]byte
140
141	streams *streams.IDGenerator
142	mu      sync.Mutex
143	calls   map[int]*callReq
144
145	errorHandler ConnErrorHandler
146	compressor   Compressor
147	auth         Authenticator
148	addr         string
149
150	version         uint8
151	currentKeyspace string
152	host            *HostInfo
153
154	session *Session
155
156	closed int32
157	quit   chan struct{}
158
159	timeouts int64
160}
161
162// connect establishes a connection to a Cassandra node using session's connection config.
163func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) {
164	return s.dial(host, s.connCfg, errorHandler)
165}
166
167// dial establishes a connection to a Cassandra node and notifies the session's connectObserver.
168func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
169	var obs ObservedConnect
170	if s.connectObserver != nil {
171		obs.Host = host
172		obs.Start = time.Now()
173	}
174
175	conn, err := s.dialWithoutObserver(host, connConfig, errorHandler)
176
177	if s.connectObserver != nil {
178		obs.End = time.Now()
179		obs.Err = err
180		s.connectObserver.ObserveConnect(obs)
181	}
182
183	return conn, err
184}
185
186// dialWithoutObserver establishes connection to a Cassandra node.
187//
188// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead.
189func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
190	ip := host.ConnectAddress()
191	port := host.port
192
193	// TODO(zariel): remove these
194	if len(ip) == 0 || ip.IsUnspecified() {
195		panic(fmt.Sprintf("host missing connect ip address: %v", ip))
196	} else if port == 0 {
197		panic(fmt.Sprintf("host missing port: %v", port))
198	}
199
200	var (
201		err  error
202		conn net.Conn
203	)
204
205	dialer := &net.Dialer{
206		Timeout: cfg.ConnectTimeout,
207	}
208	if cfg.Keepalive > 0 {
209		dialer.KeepAlive = cfg.Keepalive
210	}
211
212	// TODO(zariel): handle ipv6 zone
213	addr := (&net.TCPAddr{IP: ip, Port: port}).String()
214
215	if cfg.tlsConfig != nil {
216		// the TLS config is safe to be reused by connections but it must not
217		// be modified after being used.
218		conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
219	} else {
220		conn, err = dialer.Dial("tcp", addr)
221	}
222
223	if err != nil {
224		return nil, err
225	}
226
227	c := &Conn{
228		conn:          conn,
229		r:             bufio.NewReader(conn),
230		cfg:           cfg,
231		calls:         make(map[int]*callReq),
232		version:       uint8(cfg.ProtoVersion),
233		addr:          conn.RemoteAddr().String(),
234		errorHandler:  errorHandler,
235		compressor:    cfg.Compressor,
236		quit:          make(chan struct{}),
237		session:       s,
238		streams:       streams.New(cfg.ProtoVersion),
239		host:          host,
240		frameObserver: s.frameObserver,
241		w: &deadlineWriter{
242			w:       conn,
243			timeout: cfg.Timeout,
244		},
245	}
246
247	if cfg.AuthProvider != nil {
248		c.auth, err = cfg.AuthProvider(host)
249		if err != nil {
250			return nil, err
251		}
252	} else {
253		c.auth = cfg.Authenticator
254	}
255
256	var (
257		ctx    context.Context
258		cancel func()
259	)
260	if cfg.ConnectTimeout > 0 {
261		ctx, cancel = context.WithTimeout(context.TODO(), cfg.ConnectTimeout)
262	} else {
263		ctx, cancel = context.WithCancel(context.TODO())
264	}
265	defer cancel()
266
267	startup := &startupCoordinator{
268		frameTicker: make(chan struct{}),
269		conn:        c,
270	}
271
272	c.timeout = cfg.ConnectTimeout
273	if err := startup.setupConn(ctx); err != nil {
274		c.close()
275		return nil, err
276	}
277
278	c.timeout = cfg.Timeout
279
280	// dont coalesce startup frames
281	if s.cfg.WriteCoalesceWaitTime > 0 && !cfg.disableCoalesce {
282		c.w = newWriteCoalescer(conn, c.timeout, s.cfg.WriteCoalesceWaitTime, c.quit)
283	}
284
285	go c.serve()
286	go c.heartBeat()
287
288	return c, nil
289}
290
291func (c *Conn) Write(p []byte) (n int, err error) {
292	return c.w.Write(p)
293}
294
295func (c *Conn) Read(p []byte) (n int, err error) {
296	const maxAttempts = 5
297
298	for i := 0; i < maxAttempts; i++ {
299		var nn int
300		if c.timeout > 0 {
301			c.conn.SetReadDeadline(time.Now().Add(c.timeout))
302		}
303
304		nn, err = io.ReadFull(c.r, p[n:])
305		n += nn
306		if err == nil {
307			break
308		}
309
310		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
311			break
312		}
313	}
314
315	return
316}
317
// startupCoordinator runs the connection handshake before the regular read
// loop is started.
type startupCoordinator struct {
	conn *Conn
	// frameTicker receives one tick per request written during the handshake,
	// so the reader goroutine in setupConn reads exactly one response per
	// request. It is closed when the handshake ends.
	frameTicker chan struct{}
}
322
// setupConn runs the OPTIONS / STARTUP / (optional) authentication handshake
// and waits until it finishes or ctx is done.
//
// Two goroutines cooperate. The handshake goroutine runs s.options. Each
// request it writes first sends a tick on s.frameTicker (see write). The
// reader goroutine calls s.conn.recv once per tick, because the regular serve
// loop is not running yet. When the handshake goroutine finishes it closes
// frameTicker, which ends the reader loop.
//
// Whichever result arrives first on startupErr is returned; nil means the
// handshake succeeded. If ctx is done first, a timeout error is returned.
// A goroutine whose result is not received stays blocked until ctx is done.
func (s *startupCoordinator) setupConn(ctx context.Context) error {
	startupErr := make(chan error)
	// Reader: one recv per request tick; reports the first read error.
	go func() {
		for range s.frameTicker {
			err := s.conn.recv()
			if err != nil {
				select {
				case startupErr <- err:
				case <-ctx.Done():
				}

				return
			}
		}
	}()

	// Handshake: OPTIONS -> STARTUP -> auth; reports its final result.
	go func() {
		defer close(s.frameTicker)
		err := s.options(ctx)
		select {
		case startupErr <- err:
		case <-ctx.Done():
		}
	}()

	select {
	case err := <-startupErr:
		if err != nil {
			return err
		}
	case <-ctx.Done():
		return errors.New("gocql: no response to connection startup within timeout")
	}

	return nil
}
359
360func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) {
361	select {
362	case s.frameTicker <- struct{}{}:
363	case <-ctx.Done():
364		return nil, ctx.Err()
365	}
366
367	framer, err := s.conn.exec(ctx, frame, nil)
368	if err != nil {
369		return nil, err
370	}
371
372	return framer.parseFrame()
373}
374
375func (s *startupCoordinator) options(ctx context.Context) error {
376	frame, err := s.write(ctx, &writeOptionsFrame{})
377	if err != nil {
378		return err
379	}
380
381	supported, ok := frame.(*supportedFrame)
382	if !ok {
383		return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
384	}
385
386	return s.startup(ctx, supported.supported)
387}
388
389func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error {
390	m := map[string]string{
391		"CQL_VERSION": s.conn.cfg.CQLVersion,
392	}
393
394	if s.conn.compressor != nil {
395		comp := supported["COMPRESSION"]
396		name := s.conn.compressor.Name()
397		for _, compressor := range comp {
398			if compressor == name {
399				m["COMPRESSION"] = compressor
400				break
401			}
402		}
403
404		if _, ok := m["COMPRESSION"]; !ok {
405			s.conn.compressor = nil
406		}
407	}
408
409	frame, err := s.write(ctx, &writeStartupFrame{opts: m})
410	if err != nil {
411		return err
412	}
413
414	switch v := frame.(type) {
415	case error:
416		return v
417	case *readyFrame:
418		return nil
419	case *authenticateFrame:
420		return s.authenticateHandshake(ctx, v)
421	default:
422		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
423	}
424}
425
426func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error {
427	if s.conn.auth == nil {
428		return fmt.Errorf("authentication required (using %q)", authFrame.class)
429	}
430
431	resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class))
432	if err != nil {
433		return err
434	}
435
436	req := &writeAuthResponseFrame{data: resp}
437	for {
438		frame, err := s.write(ctx, req)
439		if err != nil {
440			return err
441		}
442
443		switch v := frame.(type) {
444		case error:
445			return v
446		case *authSuccessFrame:
447			if challenger != nil {
448				return challenger.Success(v.data)
449			}
450			return nil
451		case *authChallengeFrame:
452			resp, challenger, err = challenger.Challenge(v.data)
453			if err != nil {
454				return err
455			}
456
457			req = &writeAuthResponseFrame{
458				data: resp,
459			}
460		default:
461			return fmt.Errorf("unknown frame response during authentication: %v", v)
462		}
463	}
464}
465
// closeWithError shuts the connection down exactly once. An atomic CAS on
// c.closed guards it, so later calls return immediately.
//
// If err is non-nil, it is delivered to every request still registered in
// c.calls, except requests that have already given up (their timeout channel
// is closed). This happens while c.mu is held. Then c.quit is closed and the
// socket is closed. Finally c.errorHandler is called with closed=true and err,
// or with the socket close error when err is nil and closing failed.
// c.errorHandler is called without a nil check in those cases.
func (c *Conn) closeWithError(err error) {
	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
		return
	}

	// we should attempt to deliver the error back to the caller if it
	// exists
	if err != nil {
		c.mu.Lock()
		for _, req := range c.calls {
			// we need to send the error to all waiting queries, put the state
			// of this conn into not active so that it can not execute any queries.
			select {
			case req.resp <- err:
			case <-req.timeout:
			}
		}
		c.mu.Unlock()
	}

	// close quit unconditionally: goroutines and waiters selecting on it
	// (recv, exec, heartBeat, the write coalescer) stop waiting.
	close(c.quit)
	cerr := c.close()

	if err != nil {
		c.errorHandler.HandleError(c, err, true)
	} else if cerr != nil {
		// TODO(zariel): is it a good idea to do this?
		c.errorHandler.HandleError(c, cerr, true)
	}
}
497
498func (c *Conn) close() error {
499	return c.conn.Close()
500}
501
// Close shuts the connection down without a cause error. It is
// closeWithError(nil): no error is pushed to waiting requests (they observe
// c.quit instead). Calling it on an already-closed connection is a no-op.
func (c *Conn) Close() {
	c.closeWithError(nil)
}
505
506// Serve starts the stream multiplexer for this connection, which is required
507// to execute any queries. This method runs as long as the connection is
508// open and is therefore usually called in a separate goroutine.
509func (c *Conn) serve() {
510	var err error
511	for err == nil {
512		err = c.recv()
513	}
514
515	c.closeWithError(err)
516}
517
518func (c *Conn) discardFrame(head frameHeader) error {
519	_, err := io.CopyN(ioutil.Discard, c, int64(head.length))
520	if err != nil {
521		return err
522	}
523	return nil
524}
525
// protocolError wraps a frame that arrived on a reserved stream (ID <= 0,
// other than the event stream -1). recv returns it to tear down the
// connection.
type protocolError struct {
	frame frame
}
529
530func (p *protocolError) Error() string {
531	if err, ok := p.frame.(error); ok {
532		return err.Error()
533	}
534	return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
535}
536
537func (c *Conn) heartBeat() {
538	sleepTime := 1 * time.Second
539	timer := time.NewTimer(sleepTime)
540	defer timer.Stop()
541
542	var failures int
543
544	for {
545		if failures > 5 {
546			c.closeWithError(fmt.Errorf("gocql: heartbeat failed"))
547			return
548		}
549
550		timer.Reset(sleepTime)
551
552		select {
553		case <-c.quit:
554			return
555		case <-timer.C:
556		}
557
558		framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil)
559		if err != nil {
560			failures++
561			continue
562		}
563
564		resp, err := framer.parseFrame()
565		if err != nil {
566			// invalid frame
567			failures++
568			continue
569		}
570
571		switch resp.(type) {
572		case *supportedFrame:
573			// Everything ok
574			sleepTime = 5 * time.Second
575			failures = 0
576		case error:
577			// TODO: should we do something here?
578		default:
579			panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
580		}
581	}
582}
583
// recv reads one frame header from the connection and dispatches the frame.
// It is not safe for concurrent use, because frames are read sequentially
// from c.r.
//
// Dispatch is by stream ID:
//   - greater than c.streams.NumStreams: returns an error;
//   - -1 (server event): reads the frame and hands it to
//     c.session.handleEvent on a new goroutine;
//   - any other value <= 0: reads and parses the frame and returns it wrapped
//     in a *protocolError;
//   - otherwise: removes the matching call from c.calls, reads the body into
//     that call's framer, and delivers the read result on call.resp. If no call
//     is registered, the body is discarded.
//
// A non-nil result means the connection is unusable; serve passes it to
// closeWithError. Body read errors that are not net.Error values are handed
// to the waiting caller instead of being returned.
func (c *Conn) recv() error {
	// not safe for concurrent reads

	// read a full header, ignore timeouts, as this is being ran in a loop
	// TODO: TCP level deadlines? or just query level deadlines?
	if c.timeout > 0 {
		c.conn.SetReadDeadline(time.Time{})
	}

	headStartTime := time.Now()
	// were just reading headers over and over and copy bodies
	head, err := readHeader(c.r, c.headerBuf[:])
	headEndTime := time.Now()
	if err != nil {
		return err
	}

	if c.frameObserver != nil {
		c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{
			Version: protoVersion(head.version),
			Flags:   head.flags,
			Stream:  int16(head.stream),
			Opcode:  frameOp(head.op),
			Length:  int32(head.length),
			Start:   headStartTime,
			End:     headEndTime,
		})
	}

	if head.stream > c.streams.NumStreams {
		return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream)
	} else if head.stream == -1 {
		// TODO: handle cassandra event frames, we shouldnt get any currently
		framer := newFramer(c, c, c.compressor, c.version)
		if err := framer.readFrame(&head); err != nil {
			return err
		}
		go c.session.handleEvent(framer)
		return nil
	} else if head.stream <= 0 {
		// reserved stream that we dont use, probably due to a protocol error
		// or a bug in Cassandra, this should be an error, parse it and return.
		framer := newFramer(c, c, c.compressor, c.version)
		if err := framer.readFrame(&head); err != nil {
			return err
		}

		frame, err := framer.parseFrame()
		if err != nil {
			return err
		}

		return &protocolError{
			frame: frame,
		}
	}

	// Take ownership of the pending call; once deleted, closeWithError no
	// longer sees it, so recv is the only one that may deliver its result.
	c.mu.Lock()
	call, ok := c.calls[head.stream]
	delete(c.calls, head.stream)
	c.mu.Unlock()
	if call == nil || call.framer == nil || !ok {
		Logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
		return c.discardFrame(head)
	} else if head.stream != call.streamID {
		panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream))
	}

	err = call.framer.readFrame(&head)
	if err != nil {
		// only net errors should cause the connection to be closed. Though
		// cassandra returning corrupt frames will be returned here as well.
		if _, ok := err.(net.Error); ok {
			return err
		}
	}

	// we either return a response to the caller, the caller timed out, or the
	// connection has closed. Either way we should never block indefinitely here
	select {
	case call.resp <- err:
	case <-call.timeout:
		// The caller gave up and did not release the stream; the late
		// response has now arrived, so the stream can be reused safely.
		c.releaseStream(call)
	case <-c.quit:
	}

	return nil
}
672
673func (c *Conn) releaseStream(call *callReq) {
674	if call.timer != nil {
675		call.timer.Stop()
676	}
677
678	c.streams.Clear(call.streamID)
679}
680
681func (c *Conn) handleTimeout() {
682	if TimeoutLimit > 0 && atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit {
683		c.closeWithError(ErrTooManyTimeouts)
684	}
685}
686
// callReq is one in-flight request, registered in Conn.calls by exec under
// its stream ID.
type callReq struct {
	// could use a waitgroup but this allows us to do timeouts on the read/send
	// resp receives the result of reading the response (nil on success) from
	// recv, or the connection's fatal error from closeWithError.
	resp chan error
	// framer receives the response frame body.
	framer *framer
	// timeout is closed by exec when it stops waiting; senders then stop
	// trying to deliver on resp.
	timeout  chan struct{} // indicates to recv() that a call has timedout
	streamID int           // current stream in use

	// timer enforces Conn.timeout; it is stopped by releaseStream.
	timer *time.Timer
}
696
// deadlineWriter wraps a writer that supports write deadlines and applies a
// fresh deadline of timeout (when > 0) before every Write.
type deadlineWriter struct {
	w interface {
		SetWriteDeadline(time.Time) error
		io.Writer
	}
	timeout time.Duration
}
704
705func (c *deadlineWriter) Write(p []byte) (int, error) {
706	if c.timeout > 0 {
707		c.w.SetWriteDeadline(time.Now().Add(c.timeout))
708	}
709	return c.w.Write(p)
710}
711
712func newWriteCoalescer(conn net.Conn, timeout time.Duration, d time.Duration, quit <-chan struct{}) *writeCoalescer {
713	wc := &writeCoalescer{
714		writeCh: make(chan struct{}), // TODO: could this be sync?
715		cond:    sync.NewCond(&sync.Mutex{}),
716		c:       conn,
717		quit:    quit,
718		timeout: timeout,
719	}
720	go wc.writeFlusher(d)
721	return wc
722}
723
// writeCoalescer batches concurrent writes to a connection and flushes them
// together after a short wait.
type writeCoalescer struct {
	c net.Conn

	// quit stops the flusher and makes new writes fail.
	quit <-chan struct{}
	// writeCh wakes the flusher when a batch starts; it is set to nil by stop.
	writeCh chan struct{}
	// running is true while a batch is waiting to be flushed.
	running bool

	// cond waits for the buffer to be flushed; cond.L guards running,
	// buffers, err and writeCh.
	cond    *sync.Cond
	buffers net.Buffers
	timeout time.Duration

	// result of the write
	err error
}
739
// flushLocked writes all queued buffers to the connection and wakes every
// waiting Write. The caller must hold w.cond.L.
//
// It clears w.running so the next Write starts a new batch. Waiters exit once
// len(w.buffers) == 0. On success this relies on net.Buffers.WriteTo
// consuming the slice. On error, the remaining buffers are dropped explicitly
// and waiters read w.err.
func (w *writeCoalescer) flushLocked() {
	w.running = false
	if len(w.buffers) == 0 {
		return
	}

	if w.timeout > 0 {
		w.c.SetWriteDeadline(time.Now().Add(w.timeout))
	}

	// Given we are going to do a fanout n is useless and according to
	// the docs WriteTo should return 0 and err or bytes written and
	// no error.
	_, w.err = w.buffers.WriteTo(w.c)
	if w.err != nil {
		w.buffers = nil
	}
	w.cond.Broadcast()
}
759
760func (w *writeCoalescer) flush() {
761	w.cond.L.Lock()
762	w.flushLocked()
763	w.cond.L.Unlock()
764}
765
// stop performs a final flush of anything still queued and disables further
// flusher wake-ups. It runs when writeFlusher exits.
func (w *writeCoalescer) stop() {
	w.cond.L.Lock()
	defer w.cond.L.Unlock()

	w.flushLocked()
	// nil the channel out: sends on a nil channel block forever (Write then
	// falls through to its quit case), whereas a send on a closed channel
	// would panic.
	w.writeCh = nil
}
776
// Write queues p for the next coalesced flush and blocks until that flush
// completes, returning len(p) on success or the flush error.
//
// The first writer of a batch (running == false) wakes writeFlusher through
// writeCh. If quit is closed instead, Write returns io.EOF without queueing p.
// p is referenced, not copied, until the flush has happened.
func (w *writeCoalescer) Write(p []byte) (int, error) {
	w.cond.L.Lock()

	if !w.running {
		select {
		case w.writeCh <- struct{}{}:
			w.running = true
		case <-w.quit:
			w.cond.L.Unlock()
			return 0, io.EOF // TODO: better error here?
		}
	}

	w.buffers = append(w.buffers, p)
	// Wait releases cond.L, so other writers can append to this batch.
	for len(w.buffers) != 0 {
		w.cond.Wait()
	}

	err := w.err
	w.cond.L.Unlock()

	if err != nil {
		return 0, err
	}
	return len(p), nil
}
803
// writeFlusher is the coalescer's background goroutine. It waits for a Write
// to start a batch (signal on writeCh), sleeps for interval so more writes can
// join the batch, then flushes. It returns when w.quit is closed; the
// deferred stop then flushes any remaining data and disables writeCh.
func (w *writeCoalescer) writeFlusher(interval time.Duration) {
	timer := time.NewTimer(interval)
	defer timer.Stop()
	defer w.stop()

	// Start with a stopped, drained timer so each Reset below begins cleanly.
	if !timer.Stop() {
		<-timer.C
	}

	for {
		// wait for a write to start the flush loop
		select {
		case <-w.writeCh:
		case <-w.quit:
			return
		}

		timer.Reset(interval)

		select {
		case <-w.quit:
			return
		case <-timer.C:
		}

		w.flush()
	}
}
832
// exec sends req on a newly acquired stream and waits for its response.
//
// The request is registered in c.calls under its stream ID so recv can route
// the response back. exec then waits for the first of these: the response;
// c.timeout elapsing (when > 0); ctx being done (a nil ctx never fires); or
// the connection shutting down. It returns:
//   - ErrNoStreams if no stream ID is free;
//   - the write error if writing the frame fails (the connection is then closed);
//   - the error recv or closeWithError delivered for this call;
//   - ErrTimeoutNoResponse on timeout (also counted via handleTimeout);
//   - ctx.Err() when ctx is done;
//   - ErrConnectionClosed when c.quit is closed;
//   - a protocol error if the response's protocol version differs from c.version;
//
// Otherwise it returns the framer holding the unparsed response. tracer is
// only checked for nil here, to request tracing on the outgoing frame.
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
	// TODO: move tracer onto conn
	stream, ok := c.streams.GetStream()
	if !ok {
		return nil, ErrNoStreams
	}

	// resp is basically a waiting semaphore protecting the framer
	framer := newFramer(c, c, c.compressor, c.version)

	call := &callReq{
		framer:   framer,
		timeout:  make(chan struct{}),
		streamID: stream,
		resp:     make(chan error),
	}

	c.mu.Lock()
	existingCall := c.calls[stream]
	if existingCall == nil {
		c.calls[stream] = call
	}
	c.mu.Unlock()

	if existingCall != nil {
		return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, existingCall.streamID)
	}

	if tracer != nil {
		framer.trace()
	}

	err := req.writeFrame(framer, stream)
	if err != nil {
		// closeWithError will block waiting for this stream to either receive a response
		// or for us to timeout, close the timeout chan here. Im not entirely sure
		// but we should not get a response after an error on the write side.
		close(call.timeout)
		// I think this is the correct thing to do, im not entirely sure. It is not
		// ideal as readers might still get some data, but they probably wont.
		// Here we need to be careful as the stream is not available and if all
		// writes just timeout or fail then the pool might use this connection to
		// send a frame on, with all the streams used up and not returned.
		c.closeWithError(err)
		return nil, err
	}

	var timeoutCh <-chan time.Time
	if c.timeout > 0 {
		// call was allocated above, so call.timer is always nil here and the
		// else branch is currently unreachable.
		if call.timer == nil {
			call.timer = time.NewTimer(0)
			<-call.timer.C
		} else {
			if !call.timer.Stop() {
				select {
				case <-call.timer.C:
				default:
				}
			}
		}

		call.timer.Reset(c.timeout)
		timeoutCh = call.timer.C
	}

	var ctxDone <-chan struct{}
	if ctx != nil {
		ctxDone = ctx.Done()
	}

	select {
	case err := <-call.resp:
		close(call.timeout)
		if err != nil {
			if !c.Closed() {
				// if the connection is closed then we cant release the stream,
				// this is because the request is still outstanding and we have
				// been handed another error from another stream which caused the
				// connection to close.
				c.releaseStream(call)
			}
			return nil, err
		}
	case <-timeoutCh:
		close(call.timeout)
		c.handleTimeout()
		return nil, ErrTimeoutNoResponse
	case <-ctxDone:
		close(call.timeout)
		return nil, ctx.Err()
	case <-c.quit:
		return nil, ErrConnectionClosed
	}

	// dont release the stream if detect a timeout as another request can reuse
	// that stream and get a response for the old request, which we have no
	// easy way of detecting.
	//
	// Ensure that the stream is not released if there are potentially outstanding
	// requests on the stream to prevent nil pointer dereferences in recv().
	defer c.releaseStream(call)

	if v := framer.header.version.version(); v != c.version {
		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
	}

	return framer, nil
}
941
// preparedStatment is the server's answer to a PREPARE request.
type preparedStatment struct {
	// id is the prepared statement ID, copied out of the response buffer.
	id []byte
	// request describes the statement's bind parameters.
	request preparedMetadata
	// response describes the result columns.
	response resultMetadata
}
947
// inflightPrepare is a statement-cache entry shared by every caller preparing
// the same statement. wg is released once preparedStatment or err has been
// filled in.
type inflightPrepare struct {
	wg  sync.WaitGroup
	err error

	preparedStatment *preparedStatment
}
954
// prepareStatement returns the prepared form of stmt.
//
// Entries in c.session.stmtsLRU are keyed by (c.addr, c.currentKeyspace,
// stmt). When execIfMissing reports an existing entry (ok == true), this call
// waits for that in-flight prepare and returns its result. Otherwise the
// callback has just inserted a new entry, and this call sends the PREPARE
// itself. Failed prepares are removed from the cache so a later call can
// retry. For protocol versions above v4, the current keyspace is sent with
// the PREPARE.
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
	stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
	flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
		flight := new(inflightPrepare)
		flight.wg.Add(1)
		lru.Add(stmtCacheKey, flight)
		return flight
	})

	if ok {
		flight.wg.Wait()
		return flight.preparedStatment, flight.err
	}

	prep := &writePrepareFrame{
		statement: stmt,
	}
	if c.version > protoVersion4 {
		prep.keyspace = c.currentKeyspace
	}

	framer, err := c.exec(ctx, prep, tracer)
	if err != nil {
		flight.err = err
		flight.wg.Done()
		c.session.stmtsLRU.remove(stmtCacheKey)
		return nil, err
	}

	frame, err := framer.parseFrame()
	if err != nil {
		flight.err = err
		flight.wg.Done()
		c.session.stmtsLRU.remove(stmtCacheKey)
		return nil, err
	}

	// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
	// everytime we need to parse a frame.
	if len(framer.traceID) > 0 && tracer != nil {
		tracer.Trace(framer.traceID)
	}

	switch x := frame.(type) {
	case *resultPreparedFrame:
		flight.preparedStatment = &preparedStatment{
			// defensively copy as we will recycle the underlying buffer after we
			// return.
			id: copyBytes(x.preparedID),
			// the type info's should _not_ have a reference to the framers read buffer,
			// therefore we can just copy them directly.
			request:  x.reqMeta,
			response: x.respMeta,
		}
	case error:
		flight.err = x
	default:
		flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
	}
	// Publish the result to any waiters before deciding whether to evict it.
	flight.wg.Done()

	if flight.err != nil {
		c.session.stmtsLRU.remove(stmtCacheKey)
	}

	return flight.preparedStatment, flight.err
}
1022
1023func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
1024	if named, ok := value.(*namedValue); ok {
1025		dst.name = named.name
1026		value = named.value
1027	}
1028
1029	if _, ok := value.(unsetColumn); !ok {
1030		val, err := Marshal(typ, value)
1031		if err != nil {
1032			return err
1033		}
1034
1035		dst.value = val
1036	} else {
1037		dst.isUnset = true
1038	}
1039
1040	return nil
1041}
1042
// executeQuery runs qry on this connection. Every outcome, including errors,
// is reported through the returned *Iter.
//
// When qry.shouldPrepare() is true, the statement is prepared first (see
// prepareStatement). Its values come from qry.values or qry.binding, are
// checked against the prepared parameter count, marshalled with the prepared
// column types, and sent as EXECUTE. Otherwise the statement is sent as a
// plain QUERY. Responses are handled as follows:
//   - rows: an Iter over the rows; with skipped metadata, the cached prepared
//     result metadata is used; auto-paging installs iter.next;
//   - void / keyspace results: an empty Iter;
//   - schema changes: waits for schema agreement and only logs a failure;
//   - UNPREPARED: evicts the cached statement and re-executes when the
//     eviction succeeded;
//   - any other error or unknown frame: returned in Iter.err.
func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
	params := queryParams{
		consistency: qry.cons,
	}

	// frame checks that it is not 0
	params.serialConsistency = qry.serialCons
	params.defaultTimestamp = qry.defaultTimestamp
	params.defaultTimestampValue = qry.defaultTimestampValue

	if len(qry.pageState) > 0 {
		params.pagingState = qry.pageState
	}
	if qry.pageSize > 0 {
		params.pageSize = qry.pageSize
	}
	if c.version > protoVersion4 {
		params.keyspace = c.currentKeyspace
	}

	var (
		frame frameWriter
		info  *preparedStatment
	)

	if qry.shouldPrepare() {
		// Prepare all DML queries. Other queries can not be prepared.
		var err error
		info, err = c.prepareStatement(ctx, qry.stmt, qry.trace)
		if err != nil {
			return &Iter{err: err}
		}

		var values []interface{}

		if qry.binding == nil {
			values = qry.values
		} else {
			values, err = qry.binding(&QueryInfo{
				Id:          info.id,
				Args:        info.request.columns,
				Rval:        info.response.columns,
				PKeyColumns: info.request.pkeyColumns,
			})

			if err != nil {
				return &Iter{err: err}
			}
		}

		if len(values) != info.request.actualColCount {
			return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))}
		}

		params.values = make([]queryValues, len(values))
		for i := 0; i < len(values); i++ {
			v := &params.values[i]
			value := values[i]
			typ := info.request.columns[i].TypeInfo
			if err := marshalQueryValue(typ, value, v); err != nil {
				return &Iter{err: err}
			}
		}

		// Ask the server to omit result metadata unless disabled globally or
		// per query; the prepared statement's metadata is used instead.
		params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata)

		frame = &writeExecuteFrame{
			preparedID:    info.id,
			params:        params,
			customPayload: qry.customPayload,
		}
	} else {
		frame = &writeQueryFrame{
			statement:     qry.stmt,
			params:        params,
			customPayload: qry.customPayload,
		}
	}

	framer, err := c.exec(ctx, frame, qry.trace)
	if err != nil {
		return &Iter{err: err}
	}

	resp, err := framer.parseFrame()
	if err != nil {
		return &Iter{err: err}
	}

	if len(framer.traceID) > 0 && qry.trace != nil {
		qry.trace.Trace(framer.traceID)
	}

	switch x := resp.(type) {
	case *resultVoidFrame:
		return &Iter{framer: framer}
	case *resultRowsFrame:
		iter := &Iter{
			meta:    x.meta,
			framer:  framer,
			numRows: x.numRows,
		}

		if params.skipMeta {
			if info != nil {
				iter.meta = info.response
				iter.meta.pagingState = copyBytes(x.meta.pagingState)
			} else {
				return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
			}
		} else {
			iter.meta = x.meta
		}

		if x.meta.morePages() && !qry.disableAutoPage {
			iter.next = &nextIter{
				qry: qry,
				pos: int((1 - qry.prefetch) * float64(x.numRows)),
			}

			// NOTE(review): iter.next.qry is the caller's qry itself, so this
			// overwrites qry.pageState on the caller's Query.
			iter.next.qry.pageState = copyBytes(x.meta.pagingState)
			if iter.next.pos < 1 {
				iter.next.pos = 1
			}
		}

		return iter
	case *resultKeyspaceFrame:
		return &Iter{framer: framer}
	case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
		iter := &Iter{framer: framer}
		if err := c.awaitSchemaAgreement(ctx); err != nil {
			// TODO: should have this behind a flag
			Logger.Println(err)
		}
		// dont return an error from this, might be a good idea to give a warning
		// though. The impact of this returning an error would be that the cluster
		// is not consistent with regards to its schema.
		return iter
	case *RequestErrUnprepared:
		// The server no longer knows the statement: evict it so the retry
		// prepares it again.
		stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
		if c.session.stmtsLRU.remove(stmtCacheKey) {
			return c.executeQuery(ctx, qry)
		}

		return &Iter{err: x, framer: framer}
	case error:
		return &Iter{err: x, framer: framer}
	default:
		return &Iter{
			err:    NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x),
			framer: framer,
		}
	}
}
1198
1199func (c *Conn) Pick(qry *Query) *Conn {
1200	if c.Closed() {
1201		return nil
1202	}
1203	return c
1204}
1205
1206func (c *Conn) Closed() bool {
1207	return atomic.LoadInt32(&c.closed) == 1
1208}
1209
// Address returns the remote address of the connection, as captured from the
// socket when it was dialed.
func (c *Conn) Address() string {
	return c.addr
}
1213
// AvailableStreams returns the number of stream IDs currently free on this
// connection, as reported by its stream ID generator.
func (c *Conn) AvailableStreams() int {
	return c.streams.Available()
}
1217
1218func (c *Conn) UseKeyspace(keyspace string) error {
1219	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
1220	q.params.consistency = Any
1221
1222	framer, err := c.exec(context.Background(), q, nil)
1223	if err != nil {
1224		return err
1225	}
1226
1227	resp, err := framer.parseFrame()
1228	if err != nil {
1229		return err
1230	}
1231
1232	switch x := resp.(type) {
1233	case *resultKeyspaceFrame:
1234	case error:
1235		return x
1236	default:
1237		return NewErrProtocol("unknown frame in response to USE: %v", x)
1238	}
1239
1240	c.currentKeyspace = keyspace
1241
1242	return nil
1243}
1244
1245func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
1246	if c.version == protoVersion1 {
1247		return &Iter{err: ErrUnsupported}
1248	}
1249
1250	n := len(batch.Entries)
1251	req := &writeBatchFrame{
1252		typ:                   batch.Type,
1253		statements:            make([]batchStatment, n),
1254		consistency:           batch.Cons,
1255		serialConsistency:     batch.serialCons,
1256		defaultTimestamp:      batch.defaultTimestamp,
1257		defaultTimestampValue: batch.defaultTimestampValue,
1258		customPayload:         batch.CustomPayload,
1259	}
1260
1261	stmts := make(map[string]string, len(batch.Entries))
1262
1263	for i := 0; i < n; i++ {
1264		entry := &batch.Entries[i]
1265		b := &req.statements[i]
1266
1267		if len(entry.Args) > 0 || entry.binding != nil {
1268			info, err := c.prepareStatement(batch.Context(), entry.Stmt, nil)
1269			if err != nil {
1270				return &Iter{err: err}
1271			}
1272
1273			var values []interface{}
1274			if entry.binding == nil {
1275				values = entry.Args
1276			} else {
1277				values, err = entry.binding(&QueryInfo{
1278					Id:          info.id,
1279					Args:        info.request.columns,
1280					Rval:        info.response.columns,
1281					PKeyColumns: info.request.pkeyColumns,
1282				})
1283				if err != nil {
1284					return &Iter{err: err}
1285				}
1286			}
1287
1288			if len(values) != info.request.actualColCount {
1289				return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))}
1290			}
1291
1292			b.preparedID = info.id
1293			stmts[string(info.id)] = entry.Stmt
1294
1295			b.values = make([]queryValues, info.request.actualColCount)
1296
1297			for j := 0; j < info.request.actualColCount; j++ {
1298				v := &b.values[j]
1299				value := values[j]
1300				typ := info.request.columns[j].TypeInfo
1301				if err := marshalQueryValue(typ, value, v); err != nil {
1302					return &Iter{err: err}
1303				}
1304			}
1305		} else {
1306			b.statement = entry.Stmt
1307		}
1308	}
1309
1310	// TODO: should batch support tracing?
1311	framer, err := c.exec(batch.Context(), req, nil)
1312	if err != nil {
1313		return &Iter{err: err}
1314	}
1315
1316	resp, err := framer.parseFrame()
1317	if err != nil {
1318		return &Iter{err: err, framer: framer}
1319	}
1320
1321	switch x := resp.(type) {
1322	case *resultVoidFrame:
1323		return &Iter{}
1324	case *RequestErrUnprepared:
1325		stmt, found := stmts[string(x.StatementId)]
1326		if found {
1327			key := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
1328			c.session.stmtsLRU.remove(key)
1329		}
1330
1331		if found {
1332			return c.executeBatch(ctx, batch)
1333		} else {
1334			return &Iter{err: x, framer: framer}
1335		}
1336	case *resultRowsFrame:
1337		iter := &Iter{
1338			meta:    x.meta,
1339			framer:  framer,
1340			numRows: x.numRows,
1341		}
1342
1343		return iter
1344	case error:
1345		return &Iter{err: x, framer: framer}
1346	default:
1347		return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer}
1348	}
1349}
1350
1351func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
1352	q := c.session.Query(statement, values...).Consistency(One)
1353	q.trace = nil
1354	return c.executeQuery(ctx, q)
1355}
1356
1357func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
1358	const (
1359		peerSchemas  = "SELECT schema_version, peer FROM system.peers"
1360		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
1361	)
1362
1363	var versions map[string]struct{}
1364
1365	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
1366	for time.Now().Before(endDeadline) {
1367		iter := c.query(ctx, peerSchemas)
1368
1369		versions = make(map[string]struct{})
1370
1371		var schemaVersion string
1372		var peer string
1373		for iter.Scan(&schemaVersion, &peer) {
1374			if schemaVersion == "" {
1375				Logger.Printf("skipping peer entry with empty schema_version: peer=%q", peer)
1376				continue
1377			}
1378
1379			versions[schemaVersion] = struct{}{}
1380			schemaVersion = ""
1381		}
1382
1383		if err = iter.Close(); err != nil {
1384			goto cont
1385		}
1386
1387		iter = c.query(ctx, localSchemas)
1388		for iter.Scan(&schemaVersion) {
1389			versions[schemaVersion] = struct{}{}
1390			schemaVersion = ""
1391		}
1392
1393		if err = iter.Close(); err != nil {
1394			goto cont
1395		}
1396
1397		if len(versions) <= 1 {
1398			return nil
1399		}
1400
1401	cont:
1402		select {
1403		case <-ctx.Done():
1404			return ctx.Err()
1405		case <-time.After(200 * time.Millisecond):
1406		}
1407	}
1408
1409	if err != nil {
1410		return err
1411	}
1412
1413	schemas := make([]string, 0, len(versions))
1414	for schema := range versions {
1415		schemas = append(schemas, schema)
1416	}
1417
1418	// not exported
1419	return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
1420}
1421
1422func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
1423	row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap()
1424	if err != nil {
1425		return nil, err
1426	}
1427
1428	port := c.conn.RemoteAddr().(*net.TCPAddr).Port
1429
1430	// TODO(zariel): avoid doing this here
1431	host, err := c.session.hostInfoFromMap(row, port)
1432	if err != nil {
1433		return nil, err
1434	}
1435
1436	return c.session.ring.addOrUpdate(host), nil
1437}
1438
var (
	// ErrQueryArgLength reports a mismatch in the number of query arguments.
	ErrQueryArgLength = errors.New("gocql: query argument length mismatch")

	// ErrTimeoutNoResponse reports that no response was received from
	// Cassandra within the timeout period.
	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")

	// ErrTooManyTimeouts reports that too many queries have timed out on the
	// connection.
	ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")

	// ErrConnectionClosed reports that the connection was closed while
	// waiting for a response.
	ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")

	// ErrNoStreams reports that no streams are available on the connection.
	ErrNoStreams = errors.New("gocql: no streams available on connection")
)
1446