1package gocql 2 3import ( 4 "context" 5 crand "crypto/rand" 6 "errors" 7 "fmt" 8 "math/rand" 9 "net" 10 "os" 11 "regexp" 12 "strconv" 13 "sync" 14 "sync/atomic" 15 "time" 16) 17 18var ( 19 randr *rand.Rand 20 mutRandr sync.Mutex 21) 22 23func init() { 24 b := make([]byte, 4) 25 if _, err := crand.Read(b); err != nil { 26 panic(fmt.Sprintf("unable to seed random number generator: %v", err)) 27 } 28 29 randr = rand.New(rand.NewSource(int64(readInt(b)))) 30} 31 32// Ensure that the atomic variable is aligned to a 64bit boundary 33// so that atomic operations can be applied on 32bit architectures. 34type controlConn struct { 35 started int32 36 reconnecting int32 37 38 session *Session 39 conn atomic.Value 40 41 retry RetryPolicy 42 43 quit chan struct{} 44} 45 46func createControlConn(session *Session) *controlConn { 47 control := &controlConn{ 48 session: session, 49 quit: make(chan struct{}), 50 retry: &SimpleRetryPolicy{NumRetries: 3}, 51 } 52 53 control.conn.Store((*connHost)(nil)) 54 55 return control 56} 57 58func (c *controlConn) heartBeat() { 59 if !atomic.CompareAndSwapInt32(&c.started, 0, 1) { 60 return 61 } 62 63 sleepTime := 1 * time.Second 64 timer := time.NewTimer(sleepTime) 65 defer timer.Stop() 66 67 for { 68 timer.Reset(sleepTime) 69 70 select { 71 case <-c.quit: 72 return 73 case <-timer.C: 74 } 75 76 resp, err := c.writeFrame(&writeOptionsFrame{}) 77 if err != nil { 78 goto reconn 79 } 80 81 switch resp.(type) { 82 case *supportedFrame: 83 // Everything ok 84 sleepTime = 5 * time.Second 85 continue 86 case error: 87 goto reconn 88 default: 89 panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp)) 90 } 91 92 reconn: 93 // try to connect a bit faster 94 sleepTime = 1 * time.Second 95 c.reconnect(true) 96 continue 97 } 98} 99 100var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true" 101 102func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) { 103 var port int 104 host, portStr, err := net.SplitHostPort(addr) 105 if err != nil { 106 host = addr 107 port = defaultPort 108 } else { 109 port, err = strconv.Atoi(portStr) 110 if err != nil { 111 return nil, err 112 } 113 } 114 115 var hosts []*HostInfo 116 117 // Check if host is a literal IP address 118 if ip := net.ParseIP(host); ip != nil { 119 hosts = append(hosts, &HostInfo{connectAddress: ip, port: port}) 120 return hosts, nil 121 } 122 123 // Look up host in DNS 124 ips, err := net.LookupIP(host) 125 if err != nil { 126 return nil, err 127 } else if len(ips) == 0 { 128 return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr) 129 } 130 131 // Filter to v4 addresses if any present 132 if hostLookupPreferV4 { 133 var preferredIPs []net.IP 134 for _, v := range ips { 135 if v4 := v.To4(); v4 != nil { 136 preferredIPs = append(preferredIPs, v4) 137 } 138 } 139 if len(preferredIPs) != 0 { 140 ips = preferredIPs 141 } 142 } 143 144 for _, ip := range ips { 145 hosts = append(hosts, &HostInfo{connectAddress: ip, port: port}) 146 } 147 148 return hosts, nil 149} 150 151func shuffleHosts(hosts []*HostInfo) []*HostInfo { 152 mutRandr.Lock() 153 perm := randr.Perm(len(hosts)) 154 mutRandr.Unlock() 155 shuffled := make([]*HostInfo, len(hosts)) 156 157 for i, host := range hosts { 158 shuffled[perm[i]] = host 159 } 160 161 return shuffled 162} 163 164func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) { 165 // shuffle endpoints so not all drivers will connect to the same initial 166 // node. 167 shuffled := shuffleHosts(endpoints) 168 169 var err error 170 for _, host := range shuffled { 171 var conn *Conn 172 conn, err = c.session.connect(host, c) 173 if err == nil { 174 return conn, nil 175 } 176 177 Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err) 178 } 179 180 return nil, err 181} 182 183// this is going to be version dependant and a nightmare to maintain :( 184var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`) 185 186func parseProtocolFromError(err error) int { 187 // I really wish this had the actual info in the error frame... 188 matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1) 189 if len(matches) != 1 || len(matches[0]) != 2 { 190 if verr, ok := err.(*protocolError); ok { 191 return int(verr.frame.Header().version.version()) 192 } 193 return 0 194 } 195 196 max, err := strconv.Atoi(matches[0][1]) 197 if err != nil { 198 return 0 199 } 200 201 return max 202} 203 204func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { 205 hosts = shuffleHosts(hosts) 206 207 connCfg := *c.session.connCfg 208 connCfg.ProtoVersion = 4 // TODO: define maxProtocol 209 210 handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { 211 // we should never get here, but if we do it means we connected to a 212 // host successfully which means our attempted protocol version worked 213 if !closed { 214 c.Close() 215 } 216 }) 217 218 var err error 219 for _, host := range hosts { 220 var conn *Conn 221 conn, err = c.session.dial(host, &connCfg, handler) 222 if conn != nil { 223 conn.Close() 224 } 225 226 if err == nil { 227 return connCfg.ProtoVersion, nil 228 } 229 230 if proto := parseProtocolFromError(err); proto > 0 { 231 return proto, nil 232 } 233 } 234 235 return 0, err 236} 237 238func (c *controlConn) connect(hosts []*HostInfo) error { 239 if len(hosts) == 0 { 240 return errors.New("control: no endpoints specified") 241 } 242 243 conn, err := c.shuffleDial(hosts) 244 if err != nil { 245 return fmt.Errorf("control: unable to connect to initial hosts: %v", err) 246 } 247 248 if err := c.setupConn(conn); err != nil { 249 conn.Close() 250 return fmt.Errorf("control: unable to setup connection: %v", err) 251 } 252 253 // we could fetch the initial ring here and update initial host data. So that 254 // when we return from here we have a ring topology ready to go. 255 256 go c.heartBeat() 257 258 return nil 259} 260 261type connHost struct { 262 conn *Conn 263 host *HostInfo 264} 265 266func (c *controlConn) setupConn(conn *Conn) error { 267 if err := c.registerEvents(conn); err != nil { 268 conn.Close() 269 return err 270 } 271 272 // TODO(zariel): do we need to fetch host info everytime 273 // the control conn connects? Surely we have it cached? 274 host, err := conn.localHostInfo() 275 if err != nil { 276 return err 277 } 278 279 ch := &connHost{ 280 conn: conn, 281 host: host, 282 } 283 284 c.conn.Store(ch) 285 c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false) 286 287 return nil 288} 289 290func (c *controlConn) registerEvents(conn *Conn) error { 291 var events []string 292 293 if !c.session.cfg.Events.DisableTopologyEvents { 294 events = append(events, "TOPOLOGY_CHANGE") 295 } 296 if !c.session.cfg.Events.DisableNodeStatusEvents { 297 events = append(events, "STATUS_CHANGE") 298 } 299 if !c.session.cfg.Events.DisableSchemaEvents { 300 events = append(events, "SCHEMA_CHANGE") 301 } 302 303 if len(events) == 0 { 304 return nil 305 } 306 307 framer, err := conn.exec(context.Background(), 308 &writeRegisterFrame{ 309 events: events, 310 }, nil) 311 if err != nil { 312 return err 313 } 314 315 frame, err := framer.parseFrame() 316 if err != nil { 317 return err 318 } else if _, ok := frame.(*readyFrame); !ok { 319 return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame) 320 } 321 322 return nil 323} 324 325func (c *controlConn) reconnect(refreshring bool) { 326 if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) { 327 return 328 } 329 defer atomic.StoreInt32(&c.reconnecting, 0) 330 // TODO: simplify this function, use session.ring to get hosts instead of the 331 // connection pool 332 333 var host *HostInfo 334 ch := c.getConn() 335 if ch != nil { 336 host = ch.host 337 ch.conn.Close() 338 } 339 340 var newConn *Conn 341 if host != nil { 342 // try to connect to the old host 343 conn, err := c.session.connect(host, c) 344 if err != nil { 345 // host is dead 346 // TODO: this is replicated in a few places 347 if c.session.cfg.ConvictionPolicy.AddFailure(err, host) { 348 c.session.handleNodeDown(host.ConnectAddress(), host.Port()) 349 } 350 } else { 351 newConn = conn 352 } 353 } 354 355 // TODO: should have our own round-robin for hosts so that we can try each 356 // in succession and guarantee that we get a different host each time. 357 if newConn == nil { 358 host := c.session.ring.rrHost() 359 if host == nil { 360 c.connect(c.session.ring.endpoints) 361 return 362 } 363 364 var err error 365 newConn, err = c.session.connect(host, c) 366 if err != nil { 367 // TODO: add log handler for things like this 368 return 369 } 370 } 371 372 if err := c.setupConn(newConn); err != nil { 373 newConn.Close() 374 Logger.Printf("gocql: control unable to register events: %v\n", err) 375 return 376 } 377 378 if refreshring { 379 c.session.hostSource.refreshRing() 380 } 381} 382 383func (c *controlConn) HandleError(conn *Conn, err error, closed bool) { 384 if !closed { 385 return 386 } 387 388 oldConn := c.getConn() 389 if oldConn.conn != conn { 390 return 391 } 392 393 c.reconnect(false) 394} 395 396func (c *controlConn) getConn() *connHost { 397 return c.conn.Load().(*connHost) 398} 399 400func (c *controlConn) writeFrame(w frameWriter) (frame, error) { 401 ch := c.getConn() 402 if ch == nil { 403 return nil, errNoControl 404 } 405 406 framer, err := ch.conn.exec(context.Background(), w, nil) 407 if err != nil { 408 return nil, err 409 } 410 411 return framer.parseFrame() 412} 413 414func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { 415 const maxConnectAttempts = 5 416 connectAttempts := 0 417 418 for i := 0; i < maxConnectAttempts; i++ { 419 ch := c.getConn() 420 if ch == nil { 421 if connectAttempts > maxConnectAttempts { 422 break 423 } 424 425 connectAttempts++ 426 427 c.reconnect(false) 428 continue 429 } 430 431 return fn(ch) 432 } 433 434 return &Iter{err: errNoControl} 435} 436 437func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { 438 return c.withConnHost(func(ch *connHost) *Iter { 439 return fn(ch.conn) 440 }) 441} 442 443// query will return nil if the connection is closed or nil 444func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) { 445 q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil) 446 447 for { 448 iter = c.withConn(func(conn *Conn) *Iter { 449 return conn.executeQuery(q) 450 }) 451 452 if gocqlDebug && iter.err != nil { 453 Logger.Printf("control: error executing %q: %v\n", statement, iter.err) 454 } 455 456 q.attempts++ 457 if iter.err == nil || !c.retry.Attempt(q) { 458 break 459 } 460 } 461 462 return 463} 464 465func (c *controlConn) awaitSchemaAgreement() error { 466 return c.withConn(func(conn *Conn) *Iter { 467 return &Iter{err: conn.awaitSchemaAgreement()} 468 }).err 469} 470 471func (c *controlConn) close() { 472 if atomic.CompareAndSwapInt32(&c.started, 1, -1) { 473 c.quit <- struct{}{} 474 } 475 476 ch := c.getConn() 477 if ch != nil { 478 ch.conn.Close() 479 } 480} 481 482var errNoControl = errors.New("gocql: no control connection available") 483