1// Copyright (c) 2012 The gocql Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package gocql 6 7import ( 8 "bufio" 9 "context" 10 "crypto/tls" 11 "errors" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "net" 16 "strconv" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "time" 21 22 "github.com/gocql/gocql/internal/lru" 23 "github.com/gocql/gocql/internal/streams" 24) 25 26var ( 27 approvedAuthenticators = [...]string{ 28 "org.apache.cassandra.auth.PasswordAuthenticator", 29 "com.instaclustr.cassandra.auth.SharedSecretAuthenticator", 30 "com.datastax.bdp.cassandra.auth.DseAuthenticator", 31 "io.aiven.cassandra.auth.AivenAuthenticator", 32 } 33) 34 35func approve(authenticator string) bool { 36 for _, s := range approvedAuthenticators { 37 if authenticator == s { 38 return true 39 } 40 } 41 return false 42} 43 44//JoinHostPort is a utility to return a address string that can be used 45//gocql.Conn to form a connection with a host. 46func JoinHostPort(addr string, port int) string { 47 addr = strings.TrimSpace(addr) 48 if _, _, err := net.SplitHostPort(addr); err != nil { 49 addr = net.JoinHostPort(addr, strconv.Itoa(port)) 50 } 51 return addr 52} 53 54type Authenticator interface { 55 Challenge(req []byte) (resp []byte, auth Authenticator, err error) 56 Success(data []byte) error 57} 58 59type PasswordAuthenticator struct { 60 Username string 61 Password string 62} 63 64func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) { 65 if !approve(string(req)) { 66 return nil, nil, fmt.Errorf("unexpected authenticator %q", req) 67 } 68 resp := make([]byte, 2+len(p.Username)+len(p.Password)) 69 resp[0] = 0 70 copy(resp[1:], p.Username) 71 resp[len(p.Username)+1] = 0 72 copy(resp[2+len(p.Username):], p.Password) 73 return resp, nil, nil 74} 75 76func (p PasswordAuthenticator) Success(data []byte) error { 77 return nil 78} 79 80type SslOptions struct { 81 *tls.Config 82 83 // CertPath and KeyPath are optional depending on server 84 // config, but both fields must be omitted to avoid using a 85 // client certificate 86 CertPath string 87 KeyPath string 88 CaPath string //optional depending on server config 89 // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on 90 // This option is basically the inverse of InSecureSkipVerify 91 // See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info 92 EnableHostVerification bool 93} 94 95type ConnConfig struct { 96 ProtoVersion int 97 CQLVersion string 98 Timeout time.Duration 99 ConnectTimeout time.Duration 100 Compressor Compressor 101 Authenticator Authenticator 102 AuthProvider func(h *HostInfo) (Authenticator, error) 103 Keepalive time.Duration 104 105 tlsConfig *tls.Config 106 disableCoalesce bool 107} 108 109type ConnErrorHandler interface { 110 HandleError(conn *Conn, err error, closed bool) 111} 112 113type connErrorHandlerFn func(conn *Conn, err error, closed bool) 114 115func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) { 116 fn(conn, err, closed) 117} 118 119// If not zero, how many timeouts we will allow to occur before the connection is closed 120// and restarted. This is to prevent a single query timeout from killing a connection 121// which may be serving more queries just fine. 122// Default is 0, should not be changed concurrently with queries. 123// 124// depreciated 125var TimeoutLimit int64 = 0 126 127// Conn is a single connection to a Cassandra node. It can be used to execute 128// queries, but users are usually advised to use a more reliable, higher 129// level API. 130type Conn struct { 131 conn net.Conn 132 r *bufio.Reader 133 w io.Writer 134 135 timeout time.Duration 136 cfg *ConnConfig 137 frameObserver FrameHeaderObserver 138 139 headerBuf [maxFrameHeaderSize]byte 140 141 streams *streams.IDGenerator 142 mu sync.Mutex 143 calls map[int]*callReq 144 145 errorHandler ConnErrorHandler 146 compressor Compressor 147 auth Authenticator 148 addr string 149 150 version uint8 151 currentKeyspace string 152 host *HostInfo 153 154 session *Session 155 156 closed int32 157 quit chan struct{} 158 159 timeouts int64 160} 161 162// connect establishes a connection to a Cassandra node using session's connection config. 163func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) { 164 return s.dial(host, s.connCfg, errorHandler) 165} 166 167// dial establishes a connection to a Cassandra node and notifies the session's connectObserver. 168func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { 169 var obs ObservedConnect 170 if s.connectObserver != nil { 171 obs.Host = host 172 obs.Start = time.Now() 173 } 174 175 conn, err := s.dialWithoutObserver(host, connConfig, errorHandler) 176 177 if s.connectObserver != nil { 178 obs.End = time.Now() 179 obs.Err = err 180 s.connectObserver.ObserveConnect(obs) 181 } 182 183 return conn, err 184} 185 186// dialWithoutObserver establishes connection to a Cassandra node. 187// 188// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead. 189func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { 190 ip := host.ConnectAddress() 191 port := host.port 192 193 // TODO(zariel): remove these 194 if len(ip) == 0 || ip.IsUnspecified() { 195 panic(fmt.Sprintf("host missing connect ip address: %v", ip)) 196 } else if port == 0 { 197 panic(fmt.Sprintf("host missing port: %v", port)) 198 } 199 200 var ( 201 err error 202 conn net.Conn 203 ) 204 205 dialer := &net.Dialer{ 206 Timeout: cfg.ConnectTimeout, 207 } 208 if cfg.Keepalive > 0 { 209 dialer.KeepAlive = cfg.Keepalive 210 } 211 212 // TODO(zariel): handle ipv6 zone 213 addr := (&net.TCPAddr{IP: ip, Port: port}).String() 214 215 if cfg.tlsConfig != nil { 216 // the TLS config is safe to be reused by connections but it must not 217 // be modified after being used. 218 conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig) 219 } else { 220 conn, err = dialer.Dial("tcp", addr) 221 } 222 223 if err != nil { 224 return nil, err 225 } 226 227 c := &Conn{ 228 conn: conn, 229 r: bufio.NewReader(conn), 230 cfg: cfg, 231 calls: make(map[int]*callReq), 232 version: uint8(cfg.ProtoVersion), 233 addr: conn.RemoteAddr().String(), 234 errorHandler: errorHandler, 235 compressor: cfg.Compressor, 236 quit: make(chan struct{}), 237 session: s, 238 streams: streams.New(cfg.ProtoVersion), 239 host: host, 240 frameObserver: s.frameObserver, 241 w: &deadlineWriter{ 242 w: conn, 243 timeout: cfg.Timeout, 244 }, 245 } 246 247 if cfg.AuthProvider != nil { 248 c.auth, err = cfg.AuthProvider(host) 249 if err != nil { 250 return nil, err 251 } 252 } else { 253 c.auth = cfg.Authenticator 254 } 255 256 var ( 257 ctx context.Context 258 cancel func() 259 ) 260 if cfg.ConnectTimeout > 0 { 261 ctx, cancel = context.WithTimeout(context.TODO(), cfg.ConnectTimeout) 262 } else { 263 ctx, cancel = context.WithCancel(context.TODO()) 264 } 265 defer cancel() 266 267 startup := &startupCoordinator{ 268 frameTicker: make(chan struct{}), 269 conn: c, 270 } 271 272 c.timeout = cfg.ConnectTimeout 273 if err := startup.setupConn(ctx); err != nil { 274 c.close() 275 return nil, err 276 } 277 278 c.timeout = cfg.Timeout 279 280 // dont coalesce startup frames 281 if s.cfg.WriteCoalesceWaitTime > 0 && !cfg.disableCoalesce { 282 c.w = newWriteCoalescer(conn, c.timeout, s.cfg.WriteCoalesceWaitTime, c.quit) 283 } 284 285 go c.serve() 286 go c.heartBeat() 287 288 return c, nil 289} 290 291func (c *Conn) Write(p []byte) (n int, err error) { 292 return c.w.Write(p) 293} 294 295func (c *Conn) Read(p []byte) (n int, err error) { 296 const maxAttempts = 5 297 298 for i := 0; i < maxAttempts; i++ { 299 var nn int 300 if c.timeout > 0 { 301 c.conn.SetReadDeadline(time.Now().Add(c.timeout)) 302 } 303 304 nn, err = io.ReadFull(c.r, p[n:]) 305 n += nn 306 if err == nil { 307 break 308 } 309 310 if verr, ok := err.(net.Error); !ok || !verr.Temporary() { 311 break 312 } 313 } 314 315 return 316} 317 318type startupCoordinator struct { 319 conn *Conn 320 frameTicker chan struct{} 321} 322 323func (s *startupCoordinator) setupConn(ctx context.Context) error { 324 startupErr := make(chan error) 325 go func() { 326 for range s.frameTicker { 327 err := s.conn.recv() 328 if err != nil { 329 select { 330 case startupErr <- err: 331 case <-ctx.Done(): 332 } 333 334 return 335 } 336 } 337 }() 338 339 go func() { 340 defer close(s.frameTicker) 341 err := s.options(ctx) 342 select { 343 case startupErr <- err: 344 case <-ctx.Done(): 345 } 346 }() 347 348 select { 349 case err := <-startupErr: 350 if err != nil { 351 return err 352 } 353 case <-ctx.Done(): 354 return errors.New("gocql: no response to connection startup within timeout") 355 } 356 357 return nil 358} 359 360func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) { 361 select { 362 case s.frameTicker <- struct{}{}: 363 case <-ctx.Done(): 364 return nil, ctx.Err() 365 } 366 367 framer, err := s.conn.exec(ctx, frame, nil) 368 if err != nil { 369 return nil, err 370 } 371 372 return framer.parseFrame() 373} 374 375func (s *startupCoordinator) options(ctx context.Context) error { 376 frame, err := s.write(ctx, &writeOptionsFrame{}) 377 if err != nil { 378 return err 379 } 380 381 supported, ok := frame.(*supportedFrame) 382 if !ok { 383 return NewErrProtocol("Unknown type of response to startup frame: %T", frame) 384 } 385 386 return s.startup(ctx, supported.supported) 387} 388 389func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { 390 m := map[string]string{ 391 "CQL_VERSION": s.conn.cfg.CQLVersion, 392 } 393 394 if s.conn.compressor != nil { 395 comp := supported["COMPRESSION"] 396 name := s.conn.compressor.Name() 397 for _, compressor := range comp { 398 if compressor == name { 399 m["COMPRESSION"] = compressor 400 break 401 } 402 } 403 404 if _, ok := m["COMPRESSION"]; !ok { 405 s.conn.compressor = nil 406 } 407 } 408 409 frame, err := s.write(ctx, &writeStartupFrame{opts: m}) 410 if err != nil { 411 return err 412 } 413 414 switch v := frame.(type) { 415 case error: 416 return v 417 case *readyFrame: 418 return nil 419 case *authenticateFrame: 420 return s.authenticateHandshake(ctx, v) 421 default: 422 return NewErrProtocol("Unknown type of response to startup frame: %s", v) 423 } 424} 425 426func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { 427 if s.conn.auth == nil { 428 return fmt.Errorf("authentication required (using %q)", authFrame.class) 429 } 430 431 resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) 432 if err != nil { 433 return err 434 } 435 436 req := &writeAuthResponseFrame{data: resp} 437 for { 438 frame, err := s.write(ctx, req) 439 if err != nil { 440 return err 441 } 442 443 switch v := frame.(type) { 444 case error: 445 return v 446 case *authSuccessFrame: 447 if challenger != nil { 448 return challenger.Success(v.data) 449 } 450 return nil 451 case *authChallengeFrame: 452 resp, challenger, err = challenger.Challenge(v.data) 453 if err != nil { 454 return err 455 } 456 457 req = &writeAuthResponseFrame{ 458 data: resp, 459 } 460 default: 461 return fmt.Errorf("unknown frame response during authentication: %v", v) 462 } 463 } 464} 465 466func (c *Conn) closeWithError(err error) { 467 if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) { 468 return 469 } 470 471 // we should attempt to deliver the error back to the caller if it 472 // exists 473 if err != nil { 474 c.mu.Lock() 475 for _, req := range c.calls { 476 // we need to send the error to all waiting queries, put the state 477 // of this conn into not active so that it can not execute any queries. 478 select { 479 case req.resp <- err: 480 case <-req.timeout: 481 } 482 } 483 c.mu.Unlock() 484 } 485 486 // if error was nil then unblock the quit channel 487 close(c.quit) 488 cerr := c.close() 489 490 if err != nil { 491 c.errorHandler.HandleError(c, err, true) 492 } else if cerr != nil { 493 // TODO(zariel): is it a good idea to do this? 494 c.errorHandler.HandleError(c, cerr, true) 495 } 496} 497 498func (c *Conn) close() error { 499 return c.conn.Close() 500} 501 502func (c *Conn) Close() { 503 c.closeWithError(nil) 504} 505 506// Serve starts the stream multiplexer for this connection, which is required 507// to execute any queries. This method runs as long as the connection is 508// open and is therefore usually called in a separate goroutine. 509func (c *Conn) serve() { 510 var err error 511 for err == nil { 512 err = c.recv() 513 } 514 515 c.closeWithError(err) 516} 517 518func (c *Conn) discardFrame(head frameHeader) error { 519 _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) 520 if err != nil { 521 return err 522 } 523 return nil 524} 525 526type protocolError struct { 527 frame frame 528} 529 530func (p *protocolError) Error() string { 531 if err, ok := p.frame.(error); ok { 532 return err.Error() 533 } 534 return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame) 535} 536 537func (c *Conn) heartBeat() { 538 sleepTime := 1 * time.Second 539 timer := time.NewTimer(sleepTime) 540 defer timer.Stop() 541 542 var failures int 543 544 for { 545 if failures > 5 { 546 c.closeWithError(fmt.Errorf("gocql: heartbeat failed")) 547 return 548 } 549 550 timer.Reset(sleepTime) 551 552 select { 553 case <-c.quit: 554 return 555 case <-timer.C: 556 } 557 558 framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil) 559 if err != nil { 560 failures++ 561 continue 562 } 563 564 resp, err := framer.parseFrame() 565 if err != nil { 566 // invalid frame 567 failures++ 568 continue 569 } 570 571 switch resp.(type) { 572 case *supportedFrame: 573 // Everything ok 574 sleepTime = 5 * time.Second 575 failures = 0 576 case error: 577 // TODO: should we do something here? 578 default: 579 panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp)) 580 } 581 } 582} 583 584func (c *Conn) recv() error { 585 // not safe for concurrent reads 586 587 // read a full header, ignore timeouts, as this is being ran in a loop 588 // TODO: TCP level deadlines? or just query level deadlines? 589 if c.timeout > 0 { 590 c.conn.SetReadDeadline(time.Time{}) 591 } 592 593 headStartTime := time.Now() 594 // were just reading headers over and over and copy bodies 595 head, err := readHeader(c.r, c.headerBuf[:]) 596 headEndTime := time.Now() 597 if err != nil { 598 return err 599 } 600 601 if c.frameObserver != nil { 602 c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{ 603 Version: protoVersion(head.version), 604 Flags: head.flags, 605 Stream: int16(head.stream), 606 Opcode: frameOp(head.op), 607 Length: int32(head.length), 608 Start: headStartTime, 609 End: headEndTime, 610 }) 611 } 612 613 if head.stream > c.streams.NumStreams { 614 return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) 615 } else if head.stream == -1 { 616 // TODO: handle cassandra event frames, we shouldnt get any currently 617 framer := newFramer(c, c, c.compressor, c.version) 618 if err := framer.readFrame(&head); err != nil { 619 return err 620 } 621 go c.session.handleEvent(framer) 622 return nil 623 } else if head.stream <= 0 { 624 // reserved stream that we dont use, probably due to a protocol error 625 // or a bug in Cassandra, this should be an error, parse it and return. 626 framer := newFramer(c, c, c.compressor, c.version) 627 if err := framer.readFrame(&head); err != nil { 628 return err 629 } 630 631 frame, err := framer.parseFrame() 632 if err != nil { 633 return err 634 } 635 636 return &protocolError{ 637 frame: frame, 638 } 639 } 640 641 c.mu.Lock() 642 call, ok := c.calls[head.stream] 643 delete(c.calls, head.stream) 644 c.mu.Unlock() 645 if call == nil || call.framer == nil || !ok { 646 Logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) 647 return c.discardFrame(head) 648 } else if head.stream != call.streamID { 649 panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) 650 } 651 652 err = call.framer.readFrame(&head) 653 if err != nil { 654 // only net errors should cause the connection to be closed. Though 655 // cassandra returning corrupt frames will be returned here as well. 656 if _, ok := err.(net.Error); ok { 657 return err 658 } 659 } 660 661 // we either, return a response to the caller, the caller timedout, or the 662 // connection has closed. Either way we should never block indefinatly here 663 select { 664 case call.resp <- err: 665 case <-call.timeout: 666 c.releaseStream(call) 667 case <-c.quit: 668 } 669 670 return nil 671} 672 673func (c *Conn) releaseStream(call *callReq) { 674 if call.timer != nil { 675 call.timer.Stop() 676 } 677 678 c.streams.Clear(call.streamID) 679} 680 681func (c *Conn) handleTimeout() { 682 if TimeoutLimit > 0 && atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit { 683 c.closeWithError(ErrTooManyTimeouts) 684 } 685} 686 687type callReq struct { 688 // could use a waitgroup but this allows us to do timeouts on the read/send 689 resp chan error 690 framer *framer 691 timeout chan struct{} // indicates to recv() that a call has timedout 692 streamID int // current stream in use 693 694 timer *time.Timer 695} 696 697type deadlineWriter struct { 698 w interface { 699 SetWriteDeadline(time.Time) error 700 io.Writer 701 } 702 timeout time.Duration 703} 704 705func (c *deadlineWriter) Write(p []byte) (int, error) { 706 if c.timeout > 0 { 707 c.w.SetWriteDeadline(time.Now().Add(c.timeout)) 708 } 709 return c.w.Write(p) 710} 711 712func newWriteCoalescer(conn net.Conn, timeout time.Duration, d time.Duration, quit <-chan struct{}) *writeCoalescer { 713 wc := &writeCoalescer{ 714 writeCh: make(chan struct{}), // TODO: could this be sync? 715 cond: sync.NewCond(&sync.Mutex{}), 716 c: conn, 717 quit: quit, 718 timeout: timeout, 719 } 720 go wc.writeFlusher(d) 721 return wc 722} 723 724type writeCoalescer struct { 725 c net.Conn 726 727 quit <-chan struct{} 728 writeCh chan struct{} 729 running bool 730 731 // cond waits for the buffer to be flushed 732 cond *sync.Cond 733 buffers net.Buffers 734 timeout time.Duration 735 736 // result of the write 737 err error 738} 739 740func (w *writeCoalescer) flushLocked() { 741 w.running = false 742 if len(w.buffers) == 0 { 743 return 744 } 745 746 if w.timeout > 0 { 747 w.c.SetWriteDeadline(time.Now().Add(w.timeout)) 748 } 749 750 // Given we are going to do a fanout n is useless and according to 751 // the docs WriteTo should return 0 and err or bytes written and 752 // no error. 753 _, w.err = w.buffers.WriteTo(w.c) 754 if w.err != nil { 755 w.buffers = nil 756 } 757 w.cond.Broadcast() 758} 759 760func (w *writeCoalescer) flush() { 761 w.cond.L.Lock() 762 w.flushLocked() 763 w.cond.L.Unlock() 764} 765 766func (w *writeCoalescer) stop() { 767 w.cond.L.Lock() 768 defer w.cond.L.Unlock() 769 770 w.flushLocked() 771 // nil the channel out sends block forever on it 772 // instead of closing which causes a send on closed channel 773 // panic. 774 w.writeCh = nil 775} 776 777func (w *writeCoalescer) Write(p []byte) (int, error) { 778 w.cond.L.Lock() 779 780 if !w.running { 781 select { 782 case w.writeCh <- struct{}{}: 783 w.running = true 784 case <-w.quit: 785 w.cond.L.Unlock() 786 return 0, io.EOF // TODO: better error here? 787 } 788 } 789 790 w.buffers = append(w.buffers, p) 791 for len(w.buffers) != 0 { 792 w.cond.Wait() 793 } 794 795 err := w.err 796 w.cond.L.Unlock() 797 798 if err != nil { 799 return 0, err 800 } 801 return len(p), nil 802} 803 804func (w *writeCoalescer) writeFlusher(interval time.Duration) { 805 timer := time.NewTimer(interval) 806 defer timer.Stop() 807 defer w.stop() 808 809 if !timer.Stop() { 810 <-timer.C 811 } 812 813 for { 814 // wait for a write to start the flush loop 815 select { 816 case <-w.writeCh: 817 case <-w.quit: 818 return 819 } 820 821 timer.Reset(interval) 822 823 select { 824 case <-w.quit: 825 return 826 case <-timer.C: 827 } 828 829 w.flush() 830 } 831} 832 833func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) { 834 // TODO: move tracer onto conn 835 stream, ok := c.streams.GetStream() 836 if !ok { 837 return nil, ErrNoStreams 838 } 839 840 // resp is basically a waiting semaphore protecting the framer 841 framer := newFramer(c, c, c.compressor, c.version) 842 843 call := &callReq{ 844 framer: framer, 845 timeout: make(chan struct{}), 846 streamID: stream, 847 resp: make(chan error), 848 } 849 850 c.mu.Lock() 851 existingCall := c.calls[stream] 852 if existingCall == nil { 853 c.calls[stream] = call 854 } 855 c.mu.Unlock() 856 857 if existingCall != nil { 858 return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, existingCall.streamID) 859 } 860 861 if tracer != nil { 862 framer.trace() 863 } 864 865 err := req.writeFrame(framer, stream) 866 if err != nil { 867 // closeWithError will block waiting for this stream to either receive a response 868 // or for us to timeout, close the timeout chan here. Im not entirely sure 869 // but we should not get a response after an error on the write side. 870 close(call.timeout) 871 // I think this is the correct thing to do, im not entirely sure. It is not 872 // ideal as readers might still get some data, but they probably wont. 873 // Here we need to be careful as the stream is not available and if all 874 // writes just timeout or fail then the pool might use this connection to 875 // send a frame on, with all the streams used up and not returned. 876 c.closeWithError(err) 877 return nil, err 878 } 879 880 var timeoutCh <-chan time.Time 881 if c.timeout > 0 { 882 if call.timer == nil { 883 call.timer = time.NewTimer(0) 884 <-call.timer.C 885 } else { 886 if !call.timer.Stop() { 887 select { 888 case <-call.timer.C: 889 default: 890 } 891 } 892 } 893 894 call.timer.Reset(c.timeout) 895 timeoutCh = call.timer.C 896 } 897 898 var ctxDone <-chan struct{} 899 if ctx != nil { 900 ctxDone = ctx.Done() 901 } 902 903 select { 904 case err := <-call.resp: 905 close(call.timeout) 906 if err != nil { 907 if !c.Closed() { 908 // if the connection is closed then we cant release the stream, 909 // this is because the request is still outstanding and we have 910 // been handed another error from another stream which caused the 911 // connection to close. 912 c.releaseStream(call) 913 } 914 return nil, err 915 } 916 case <-timeoutCh: 917 close(call.timeout) 918 c.handleTimeout() 919 return nil, ErrTimeoutNoResponse 920 case <-ctxDone: 921 close(call.timeout) 922 return nil, ctx.Err() 923 case <-c.quit: 924 return nil, ErrConnectionClosed 925 } 926 927 // dont release the stream if detect a timeout as another request can reuse 928 // that stream and get a response for the old request, which we have no 929 // easy way of detecting. 930 // 931 // Ensure that the stream is not released if there are potentially outstanding 932 // requests on the stream to prevent nil pointer dereferences in recv(). 933 defer c.releaseStream(call) 934 935 if v := framer.header.version.version(); v != c.version { 936 return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version) 937 } 938 939 return framer, nil 940} 941 942type preparedStatment struct { 943 id []byte 944 request preparedMetadata 945 response resultMetadata 946} 947 948type inflightPrepare struct { 949 wg sync.WaitGroup 950 err error 951 952 preparedStatment *preparedStatment 953} 954 955func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { 956 stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt) 957 flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { 958 flight := new(inflightPrepare) 959 flight.wg.Add(1) 960 lru.Add(stmtCacheKey, flight) 961 return flight 962 }) 963 964 if ok { 965 flight.wg.Wait() 966 return flight.preparedStatment, flight.err 967 } 968 969 prep := &writePrepareFrame{ 970 statement: stmt, 971 } 972 if c.version > protoVersion4 { 973 prep.keyspace = c.currentKeyspace 974 } 975 976 framer, err := c.exec(ctx, prep, tracer) 977 if err != nil { 978 flight.err = err 979 flight.wg.Done() 980 c.session.stmtsLRU.remove(stmtCacheKey) 981 return nil, err 982 } 983 984 frame, err := framer.parseFrame() 985 if err != nil { 986 flight.err = err 987 flight.wg.Done() 988 c.session.stmtsLRU.remove(stmtCacheKey) 989 return nil, err 990 } 991 992 // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated 993 // everytime we need to parse a frame. 994 if len(framer.traceID) > 0 && tracer != nil { 995 tracer.Trace(framer.traceID) 996 } 997 998 switch x := frame.(type) { 999 case *resultPreparedFrame: 1000 flight.preparedStatment = &preparedStatment{ 1001 // defensively copy as we will recycle the underlying buffer after we 1002 // return. 1003 id: copyBytes(x.preparedID), 1004 // the type info's should _not_ have a reference to the framers read buffer, 1005 // therefore we can just copy them directly. 1006 request: x.reqMeta, 1007 response: x.respMeta, 1008 } 1009 case error: 1010 flight.err = x 1011 default: 1012 flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x) 1013 } 1014 flight.wg.Done() 1015 1016 if flight.err != nil { 1017 c.session.stmtsLRU.remove(stmtCacheKey) 1018 } 1019 1020 return flight.preparedStatment, flight.err 1021} 1022 1023func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { 1024 if named, ok := value.(*namedValue); ok { 1025 dst.name = named.name 1026 value = named.value 1027 } 1028 1029 if _, ok := value.(unsetColumn); !ok { 1030 val, err := Marshal(typ, value) 1031 if err != nil { 1032 return err 1033 } 1034 1035 dst.value = val 1036 } else { 1037 dst.isUnset = true 1038 } 1039 1040 return nil 1041} 1042 1043func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { 1044 params := queryParams{ 1045 consistency: qry.cons, 1046 } 1047 1048 // frame checks that it is not 0 1049 params.serialConsistency = qry.serialCons 1050 params.defaultTimestamp = qry.defaultTimestamp 1051 params.defaultTimestampValue = qry.defaultTimestampValue 1052 1053 if len(qry.pageState) > 0 { 1054 params.pagingState = qry.pageState 1055 } 1056 if qry.pageSize > 0 { 1057 params.pageSize = qry.pageSize 1058 } 1059 if c.version > protoVersion4 { 1060 params.keyspace = c.currentKeyspace 1061 } 1062 1063 var ( 1064 frame frameWriter 1065 info *preparedStatment 1066 ) 1067 1068 if qry.shouldPrepare() { 1069 // Prepare all DML queries. Other queries can not be prepared. 1070 var err error 1071 info, err = c.prepareStatement(ctx, qry.stmt, qry.trace) 1072 if err != nil { 1073 return &Iter{err: err} 1074 } 1075 1076 var values []interface{} 1077 1078 if qry.binding == nil { 1079 values = qry.values 1080 } else { 1081 values, err = qry.binding(&QueryInfo{ 1082 Id: info.id, 1083 Args: info.request.columns, 1084 Rval: info.response.columns, 1085 PKeyColumns: info.request.pkeyColumns, 1086 }) 1087 1088 if err != nil { 1089 return &Iter{err: err} 1090 } 1091 } 1092 1093 if len(values) != info.request.actualColCount { 1094 return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} 1095 } 1096 1097 params.values = make([]queryValues, len(values)) 1098 for i := 0; i < len(values); i++ { 1099 v := ¶ms.values[i] 1100 value := values[i] 1101 typ := info.request.columns[i].TypeInfo 1102 if err := marshalQueryValue(typ, value, v); err != nil { 1103 return &Iter{err: err} 1104 } 1105 } 1106 1107 params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) 1108 1109 frame = &writeExecuteFrame{ 1110 preparedID: info.id, 1111 params: params, 1112 customPayload: qry.customPayload, 1113 } 1114 } else { 1115 frame = &writeQueryFrame{ 1116 statement: qry.stmt, 1117 params: params, 1118 customPayload: qry.customPayload, 1119 } 1120 } 1121 1122 framer, err := c.exec(ctx, frame, qry.trace) 1123 if err != nil { 1124 return &Iter{err: err} 1125 } 1126 1127 resp, err := framer.parseFrame() 1128 if err != nil { 1129 return &Iter{err: err} 1130 } 1131 1132 if len(framer.traceID) > 0 && qry.trace != nil { 1133 qry.trace.Trace(framer.traceID) 1134 } 1135 1136 switch x := resp.(type) { 1137 case *resultVoidFrame: 1138 return &Iter{framer: framer} 1139 case *resultRowsFrame: 1140 iter := &Iter{ 1141 meta: x.meta, 1142 framer: framer, 1143 numRows: x.numRows, 1144 } 1145 1146 if params.skipMeta { 1147 if info != nil { 1148 iter.meta = info.response 1149 iter.meta.pagingState = copyBytes(x.meta.pagingState) 1150 } else { 1151 return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} 1152 } 1153 } else { 1154 iter.meta = x.meta 1155 } 1156 1157 if x.meta.morePages() && !qry.disableAutoPage { 1158 iter.next = &nextIter{ 1159 qry: qry, 1160 pos: int((1 - qry.prefetch) * float64(x.numRows)), 1161 } 1162 1163 iter.next.qry.pageState = copyBytes(x.meta.pagingState) 1164 if iter.next.pos < 1 { 1165 iter.next.pos = 1 1166 } 1167 } 1168 1169 return iter 1170 case *resultKeyspaceFrame: 1171 return &Iter{framer: framer} 1172 case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: 1173 iter := &Iter{framer: framer} 1174 if err := c.awaitSchemaAgreement(ctx); err != nil { 1175 // TODO: should have this behind a flag 1176 Logger.Println(err) 1177 } 1178 // dont return an error from this, might be a good idea to give a warning 1179 // though. The impact of this returning an error would be that the cluster 1180 // is not consistent with regards to its schema. 1181 return iter 1182 case *RequestErrUnprepared: 1183 stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt) 1184 if c.session.stmtsLRU.remove(stmtCacheKey) { 1185 return c.executeQuery(ctx, qry) 1186 } 1187 1188 return &Iter{err: x, framer: framer} 1189 case error: 1190 return &Iter{err: x, framer: framer} 1191 default: 1192 return &Iter{ 1193 err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), 1194 framer: framer, 1195 } 1196 } 1197} 1198 1199func (c *Conn) Pick(qry *Query) *Conn { 1200 if c.Closed() { 1201 return nil 1202 } 1203 return c 1204} 1205 1206func (c *Conn) Closed() bool { 1207 return atomic.LoadInt32(&c.closed) == 1 1208} 1209 1210func (c *Conn) Address() string { 1211 return c.addr 1212} 1213 1214func (c *Conn) AvailableStreams() int { 1215 return c.streams.Available() 1216} 1217 1218func (c *Conn) UseKeyspace(keyspace string) error { 1219 q := &writeQueryFrame{statement: `USE "` + keyspace + `"`} 1220 q.params.consistency = Any 1221 1222 framer, err := c.exec(context.Background(), q, nil) 1223 if err != nil { 1224 return err 1225 } 1226 1227 resp, err := framer.parseFrame() 1228 if err != nil { 1229 return err 1230 } 1231 1232 switch x := resp.(type) { 1233 case *resultKeyspaceFrame: 1234 case error: 1235 return x 1236 default: 1237 return NewErrProtocol("unknown frame in response to USE: %v", x) 1238 } 1239 1240 c.currentKeyspace = keyspace 1241 1242 return nil 1243} 1244 1245func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { 1246 if c.version == protoVersion1 { 1247 return &Iter{err: ErrUnsupported} 1248 } 1249 1250 n := len(batch.Entries) 1251 req := &writeBatchFrame{ 1252 typ: batch.Type, 1253 statements: make([]batchStatment, n), 1254 consistency: batch.Cons, 1255 serialConsistency: batch.serialCons, 1256 defaultTimestamp: batch.defaultTimestamp, 1257 defaultTimestampValue: batch.defaultTimestampValue, 1258 customPayload: batch.CustomPayload, 1259 } 1260 1261 stmts := make(map[string]string, len(batch.Entries)) 1262 1263 for i := 0; i < n; i++ { 1264 entry := &batch.Entries[i] 1265 b := &req.statements[i] 1266 1267 if len(entry.Args) > 0 || entry.binding != nil { 1268 info, err := c.prepareStatement(batch.Context(), entry.Stmt, nil) 1269 if err != nil { 1270 return &Iter{err: err} 1271 } 1272 1273 var values []interface{} 1274 if entry.binding == nil { 1275 values = entry.Args 1276 } else { 1277 values, err = entry.binding(&QueryInfo{ 1278 Id: info.id, 1279 Args: info.request.columns, 1280 Rval: info.response.columns, 1281 PKeyColumns: info.request.pkeyColumns, 1282 }) 1283 if err != nil { 1284 return &Iter{err: err} 1285 } 1286 } 1287 1288 if len(values) != info.request.actualColCount { 1289 return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} 1290 } 1291 1292 b.preparedID = info.id 1293 stmts[string(info.id)] = entry.Stmt 1294 1295 b.values = make([]queryValues, info.request.actualColCount) 1296 1297 for j := 0; j < info.request.actualColCount; j++ { 1298 v := &b.values[j] 1299 value := values[j] 1300 typ := info.request.columns[j].TypeInfo 1301 if err := marshalQueryValue(typ, value, v); err != nil { 1302 return &Iter{err: err} 1303 } 1304 } 1305 } else { 1306 b.statement = entry.Stmt 1307 } 1308 } 1309 1310 // TODO: should batch support tracing? 1311 framer, err := c.exec(batch.Context(), req, nil) 1312 if err != nil { 1313 return &Iter{err: err} 1314 } 1315 1316 resp, err := framer.parseFrame() 1317 if err != nil { 1318 return &Iter{err: err, framer: framer} 1319 } 1320 1321 switch x := resp.(type) { 1322 case *resultVoidFrame: 1323 return &Iter{} 1324 case *RequestErrUnprepared: 1325 stmt, found := stmts[string(x.StatementId)] 1326 if found { 1327 key := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt) 1328 c.session.stmtsLRU.remove(key) 1329 } 1330 1331 if found { 1332 return c.executeBatch(ctx, batch) 1333 } else { 1334 return &Iter{err: x, framer: framer} 1335 } 1336 case *resultRowsFrame: 1337 iter := &Iter{ 1338 meta: x.meta, 1339 framer: framer, 1340 numRows: x.numRows, 1341 } 1342 1343 return iter 1344 case error: 1345 return &Iter{err: x, framer: framer} 1346 default: 1347 return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} 1348 } 1349} 1350 1351func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) { 1352 q := c.session.Query(statement, values...).Consistency(One) 1353 q.trace = nil 1354 return c.executeQuery(ctx, q) 1355} 1356 1357func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { 1358 const ( 1359 peerSchemas = "SELECT schema_version, peer FROM system.peers" 1360 localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" 1361 ) 1362 1363 var versions map[string]struct{} 1364 1365 endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) 1366 for time.Now().Before(endDeadline) { 1367 iter := c.query(ctx, peerSchemas) 1368 1369 versions = make(map[string]struct{}) 1370 1371 var schemaVersion string 1372 var peer string 1373 for iter.Scan(&schemaVersion, &peer) { 1374 if schemaVersion == "" { 1375 Logger.Printf("skipping peer entry with empty schema_version: peer=%q", peer) 1376 continue 1377 } 1378 1379 versions[schemaVersion] = struct{}{} 1380 schemaVersion = "" 1381 } 1382 1383 if err = iter.Close(); err != nil { 1384 goto cont 1385 } 1386 1387 iter = c.query(ctx, localSchemas) 1388 for iter.Scan(&schemaVersion) { 1389 versions[schemaVersion] = struct{}{} 1390 schemaVersion = "" 1391 } 1392 1393 if err = iter.Close(); err != nil { 1394 goto cont 1395 } 1396 1397 if len(versions) <= 1 { 1398 return nil 1399 } 1400 1401 cont: 1402 select { 1403 case <-ctx.Done(): 1404 return ctx.Err() 1405 case <-time.After(200 * time.Millisecond): 1406 } 1407 } 1408 1409 if err != nil { 1410 return err 1411 } 1412 1413 schemas := make([]string, 0, len(versions)) 1414 for schema := range versions { 1415 schemas = append(schemas, schema) 1416 } 1417 1418 // not exported 1419 return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) 1420} 1421 1422func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) { 1423 row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap() 1424 if err != nil { 1425 return nil, err 1426 } 1427 1428 port := c.conn.RemoteAddr().(*net.TCPAddr).Port 1429 1430 // TODO(zariel): avoid doing this here 1431 host, err := c.session.hostInfoFromMap(row, port) 1432 if err != nil { 1433 return nil, err 1434 } 1435 1436 return c.session.ring.addOrUpdate(host), nil 1437} 1438 1439var ( 1440 ErrQueryArgLength = errors.New("gocql: query argument length mismatch") 1441 ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") 1442 ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection") 1443 ErrConnectionClosed = errors.New("gocql: connection closed waiting for response") 1444 ErrNoStreams = errors.New("gocql: no streams available on connection") 1445) 1446