1// Copyright 2012 Gary Burd 2// 3// Licensed under the Apache License, Version 2.0 (the "License"): you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12// License for the specific language governing permissions and limitations 13// under the License. 14 15package redis 16 17import ( 18 "bufio" 19 "bytes" 20 "crypto/tls" 21 "errors" 22 "fmt" 23 "io" 24 "net" 25 "net/url" 26 "regexp" 27 "strconv" 28 "sync" 29 "time" 30) 31 32// conn is the low-level implementation of Conn 33type conn struct { 34 35 // Shared 36 mu sync.Mutex 37 pending int 38 err error 39 conn net.Conn 40 41 // Read 42 readTimeout time.Duration 43 br *bufio.Reader 44 45 // Write 46 writeTimeout time.Duration 47 bw *bufio.Writer 48 49 // Scratch space for formatting argument length. 50 // '*' or '$', length, "\r\n" 51 lenScratch [32]byte 52 53 // Scratch space for formatting integers and floats. 54 numScratch [40]byte 55} 56 57// DialTimeout acts like Dial but takes timeouts for establishing the 58// connection to the server, writing a command and reading a reply. 59// 60// Deprecated: Use Dial with options instead. 61func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { 62 return Dial(network, address, 63 DialConnectTimeout(connectTimeout), 64 DialReadTimeout(readTimeout), 65 DialWriteTimeout(writeTimeout)) 66} 67 68// DialOption specifies an option for dialing a Redis server. 
type DialOption struct {
	f func(*dialOptions)
}

// dialOptions accumulates the settings applied by DialOption values before
// a connection is dialed.
type dialOptions struct {
	readTimeout  time.Duration
	writeTimeout time.Duration
	dial         func(network, addr string) (net.Conn, error)
	db           int
	password     string
	dialTLS      bool
	skipVerify   bool
	tlsConfig    *tls.Config

	// netDialSet records that DialNetDial supplied dial, so that a
	// DialConnectTimeout applied afterwards does not replace it.
	netDialSet bool
}

// DialReadTimeout specifies the timeout for reading a single command reply.
func DialReadTimeout(d time.Duration) DialOption {
	return DialOption{func(do *dialOptions) {
		do.readTimeout = d
	}}
}

// DialWriteTimeout specifies the timeout for writing a single command.
func DialWriteTimeout(d time.Duration) DialOption {
	return DialOption{func(do *dialOptions) {
		do.writeTimeout = d
	}}
}

// DialConnectTimeout specifies the timeout for connecting to the Redis server.
// It has no effect when DialNetDial is also given, regardless of the order in
// which the two options are passed.
func DialConnectTimeout(d time.Duration) DialOption {
	return DialOption{func(do *dialOptions) {
		if do.netDialSet {
			// A caller-supplied dial function takes precedence; replacing
			// it here would silently discard the caller's DialNetDial.
			return
		}
		dialer := net.Dialer{Timeout: d}
		do.dial = dialer.Dial
	}}
}

// DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial (or a net.Dialer
// configured by DialConnectTimeout) is used. DialNetDial overrides
// DialConnectTimeout no matter which of the two is passed first.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
	return DialOption{func(do *dialOptions) {
		do.dial = dial
		do.netDialSet = true
	}}
}

// DialDatabase specifies the database to select when dialing a connection.
func DialDatabase(db int) DialOption {
	return DialOption{func(do *dialOptions) {
		do.db = db
	}}
}

// DialPassword specifies the password to use when connecting to
// the Redis server.
func DialPassword(password string) DialOption {
	return DialOption{func(do *dialOptions) {
		do.password = password
	}}
}

// DialTLSConfig specifies the config to use when a TLS connection is dialed.
// Has no effect when not dialing a TLS connection.
132func DialTLSConfig(c *tls.Config) DialOption { 133 return DialOption{func(do *dialOptions) { 134 do.tlsConfig = c 135 }} 136} 137 138// DialTLSSkipVerify to disable server name verification when connecting 139// over TLS. Has no effect when not dialing a TLS connection. 140func DialTLSSkipVerify(skip bool) DialOption { 141 return DialOption{func(do *dialOptions) { 142 do.skipVerify = skip 143 }} 144} 145 146// Dial connects to the Redis server at the given network and 147// address using the specified options. 148func Dial(network, address string, options ...DialOption) (Conn, error) { 149 do := dialOptions{ 150 dial: net.Dial, 151 } 152 for _, option := range options { 153 option.f(&do) 154 } 155 156 netConn, err := do.dial(network, address) 157 if err != nil { 158 return nil, err 159 } 160 161 if do.dialTLS { 162 tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify) 163 if tlsConfig.ServerName == "" { 164 host, _, err := net.SplitHostPort(address) 165 if err != nil { 166 netConn.Close() 167 return nil, err 168 } 169 tlsConfig.ServerName = host 170 } 171 172 tlsConn := tls.Client(netConn, tlsConfig) 173 if err := tlsConn.Handshake(); err != nil { 174 netConn.Close() 175 return nil, err 176 } 177 netConn = tlsConn 178 } 179 180 c := &conn{ 181 conn: netConn, 182 bw: bufio.NewWriter(netConn), 183 br: bufio.NewReader(netConn), 184 readTimeout: do.readTimeout, 185 writeTimeout: do.writeTimeout, 186 } 187 188 if do.password != "" { 189 if _, err := c.Do("AUTH", do.password); err != nil { 190 netConn.Close() 191 return nil, err 192 } 193 } 194 195 if do.db != 0 { 196 if _, err := c.Do("SELECT", do.db); err != nil { 197 netConn.Close() 198 return nil, err 199 } 200 } 201 202 return c, nil 203} 204 205func dialTLS(do *dialOptions) { 206 do.dialTLS = true 207} 208 209var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`) 210 211// DialURL connects to a Redis server at the given URL using the Redis 212// URI scheme. 
URLs should follow the draft IANA specification for the 213// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis). 214func DialURL(rawurl string, options ...DialOption) (Conn, error) { 215 u, err := url.Parse(rawurl) 216 if err != nil { 217 return nil, err 218 } 219 220 if u.Scheme != "redis" && u.Scheme != "rediss" { 221 return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) 222 } 223 224 // As per the IANA draft spec, the host defaults to localhost and 225 // the port defaults to 6379. 226 host, port, err := net.SplitHostPort(u.Host) 227 if err != nil { 228 // assume port is missing 229 host = u.Host 230 port = "6379" 231 } 232 if host == "" { 233 host = "localhost" 234 } 235 address := net.JoinHostPort(host, port) 236 237 if u.User != nil { 238 password, isSet := u.User.Password() 239 if isSet { 240 options = append(options, DialPassword(password)) 241 } 242 } 243 244 match := pathDBRegexp.FindStringSubmatch(u.Path) 245 if len(match) == 2 { 246 db := 0 247 if len(match[1]) > 0 { 248 db, err = strconv.Atoi(match[1]) 249 if err != nil { 250 return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) 251 } 252 } 253 if db != 0 { 254 options = append(options, DialDatabase(db)) 255 } 256 } else if u.Path != "" { 257 return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) 258 } 259 260 if u.Scheme == "rediss" { 261 options = append([]DialOption{{dialTLS}}, options...) 262 } 263 264 return Dial("tcp", address, options...) 265} 266 267// NewConn returns a new Redigo connection for the given net connection. 
268func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { 269 return &conn{ 270 conn: netConn, 271 bw: bufio.NewWriter(netConn), 272 br: bufio.NewReader(netConn), 273 readTimeout: readTimeout, 274 writeTimeout: writeTimeout, 275 } 276} 277 278func (c *conn) Close() error { 279 c.mu.Lock() 280 err := c.err 281 if c.err == nil { 282 c.err = errors.New("redigo: closed") 283 err = c.conn.Close() 284 } 285 c.mu.Unlock() 286 return err 287} 288 289func (c *conn) fatal(err error) error { 290 c.mu.Lock() 291 if c.err == nil { 292 c.err = err 293 // Close connection to force errors on subsequent calls and to unblock 294 // other reader or writer. 295 c.conn.Close() 296 } 297 c.mu.Unlock() 298 return err 299} 300 301func (c *conn) Err() error { 302 c.mu.Lock() 303 err := c.err 304 c.mu.Unlock() 305 return err 306} 307 308func (c *conn) writeLen(prefix byte, n int) error { 309 c.lenScratch[len(c.lenScratch)-1] = '\n' 310 c.lenScratch[len(c.lenScratch)-2] = '\r' 311 i := len(c.lenScratch) - 3 312 for { 313 c.lenScratch[i] = byte('0' + n%10) 314 i -= 1 315 n = n / 10 316 if n == 0 { 317 break 318 } 319 } 320 c.lenScratch[i] = prefix 321 _, err := c.bw.Write(c.lenScratch[i:]) 322 return err 323} 324 325func (c *conn) writeString(s string) error { 326 c.writeLen('$', len(s)) 327 c.bw.WriteString(s) 328 _, err := c.bw.WriteString("\r\n") 329 return err 330} 331 332func (c *conn) writeBytes(p []byte) error { 333 c.writeLen('$', len(p)) 334 c.bw.Write(p) 335 _, err := c.bw.WriteString("\r\n") 336 return err 337} 338 339func (c *conn) writeInt64(n int64) error { 340 return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)) 341} 342 343func (c *conn) writeFloat64(n float64) error { 344 return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) 345} 346 347func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { 348 c.writeLen('*', 1+len(args)) 349 err = c.writeString(cmd) 350 for _, arg := range args { 351 if err != nil { 
352 break 353 } 354 switch arg := arg.(type) { 355 case string: 356 err = c.writeString(arg) 357 case []byte: 358 err = c.writeBytes(arg) 359 case int: 360 err = c.writeInt64(int64(arg)) 361 case int64: 362 err = c.writeInt64(arg) 363 case float64: 364 err = c.writeFloat64(arg) 365 case bool: 366 if arg { 367 err = c.writeString("1") 368 } else { 369 err = c.writeString("0") 370 } 371 case nil: 372 err = c.writeString("") 373 case Argument: 374 var buf bytes.Buffer 375 fmt.Fprint(&buf, arg.RedisArg()) 376 err = c.writeBytes(buf.Bytes()) 377 default: 378 var buf bytes.Buffer 379 fmt.Fprint(&buf, arg) 380 err = c.writeBytes(buf.Bytes()) 381 } 382 } 383 return err 384} 385 386type protocolError string 387 388func (pe protocolError) Error() string { 389 return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe)) 390} 391 392func (c *conn) readLine() ([]byte, error) { 393 p, err := c.br.ReadSlice('\n') 394 if err == bufio.ErrBufferFull { 395 return nil, protocolError("long response line") 396 } 397 if err != nil { 398 return nil, err 399 } 400 i := len(p) - 2 401 if i < 0 || p[i] != '\r' { 402 return nil, protocolError("bad response line terminator") 403 } 404 return p[:i], nil 405} 406 407// parseLen parses bulk string and array lengths. 408func parseLen(p []byte) (int, error) { 409 if len(p) == 0 { 410 return -1, protocolError("malformed length") 411 } 412 413 if p[0] == '-' && len(p) == 2 && p[1] == '1' { 414 // handle $-1 and $-1 null replies. 415 return -1, nil 416 } 417 418 var n int 419 for _, b := range p { 420 n *= 10 421 if b < '0' || b > '9' { 422 return -1, protocolError("illegal bytes in length") 423 } 424 n += int(b - '0') 425 } 426 427 return n, nil 428} 429 430// parseInt parses an integer reply. 
431func parseInt(p []byte) (interface{}, error) { 432 if len(p) == 0 { 433 return 0, protocolError("malformed integer") 434 } 435 436 var negate bool 437 if p[0] == '-' { 438 negate = true 439 p = p[1:] 440 if len(p) == 0 { 441 return 0, protocolError("malformed integer") 442 } 443 } 444 445 var n int64 446 for _, b := range p { 447 n *= 10 448 if b < '0' || b > '9' { 449 return 0, protocolError("illegal bytes in length") 450 } 451 n += int64(b - '0') 452 } 453 454 if negate { 455 n = -n 456 } 457 return n, nil 458} 459 460var ( 461 okReply interface{} = "OK" 462 pongReply interface{} = "PONG" 463) 464 465func (c *conn) readReply() (interface{}, error) { 466 line, err := c.readLine() 467 if err != nil { 468 return nil, err 469 } 470 if len(line) == 0 { 471 return nil, protocolError("short response line") 472 } 473 switch line[0] { 474 case '+': 475 switch { 476 case len(line) == 3 && line[1] == 'O' && line[2] == 'K': 477 // Avoid allocation for frequent "+OK" response. 478 return okReply, nil 479 case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': 480 // Avoid allocation in PING command benchmarks :) 481 return pongReply, nil 482 default: 483 return string(line[1:]), nil 484 } 485 case '-': 486 return Error(string(line[1:])), nil 487 case ':': 488 return parseInt(line[1:]) 489 case '$': 490 n, err := parseLen(line[1:]) 491 if n < 0 || err != nil { 492 return nil, err 493 } 494 p := make([]byte, n) 495 _, err = io.ReadFull(c.br, p) 496 if err != nil { 497 return nil, err 498 } 499 if line, err := c.readLine(); err != nil { 500 return nil, err 501 } else if len(line) != 0 { 502 return nil, protocolError("bad bulk string format") 503 } 504 return p, nil 505 case '*': 506 n, err := parseLen(line[1:]) 507 if n < 0 || err != nil { 508 return nil, err 509 } 510 r := make([]interface{}, n) 511 for i := range r { 512 r[i], err = c.readReply() 513 if err != nil { 514 return nil, err 515 } 516 } 517 return r, nil 518 } 519 return nil, 
protocolError("unexpected response line") 520} 521 522func (c *conn) Send(cmd string, args ...interface{}) error { 523 c.mu.Lock() 524 c.pending += 1 525 c.mu.Unlock() 526 if c.writeTimeout != 0 { 527 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) 528 } 529 if err := c.writeCommand(cmd, args); err != nil { 530 return c.fatal(err) 531 } 532 return nil 533} 534 535func (c *conn) Flush() error { 536 if c.writeTimeout != 0 { 537 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) 538 } 539 if err := c.bw.Flush(); err != nil { 540 return c.fatal(err) 541 } 542 return nil 543} 544 545func (c *conn) Receive() (reply interface{}, err error) { 546 if c.readTimeout != 0 { 547 c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) 548 } 549 if reply, err = c.readReply(); err != nil { 550 return nil, c.fatal(err) 551 } 552 // When using pub/sub, the number of receives can be greater than the 553 // number of sends. To enable normal use of the connection after 554 // unsubscribing from all channels, we do not decrement pending to a 555 // negative value. 556 // 557 // The pending field is decremented after the reply is read to handle the 558 // case where Receive is called before Send. 
559 c.mu.Lock() 560 if c.pending > 0 { 561 c.pending -= 1 562 } 563 c.mu.Unlock() 564 if err, ok := reply.(Error); ok { 565 return nil, err 566 } 567 return 568} 569 570func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { 571 c.mu.Lock() 572 pending := c.pending 573 c.pending = 0 574 c.mu.Unlock() 575 576 if cmd == "" && pending == 0 { 577 return nil, nil 578 } 579 580 if c.writeTimeout != 0 { 581 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) 582 } 583 584 if cmd != "" { 585 if err := c.writeCommand(cmd, args); err != nil { 586 return nil, c.fatal(err) 587 } 588 } 589 590 if err := c.bw.Flush(); err != nil { 591 return nil, c.fatal(err) 592 } 593 594 if c.readTimeout != 0 { 595 c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) 596 } 597 598 if cmd == "" { 599 reply := make([]interface{}, pending) 600 for i := range reply { 601 r, e := c.readReply() 602 if e != nil { 603 return nil, c.fatal(e) 604 } 605 reply[i] = r 606 } 607 return reply, nil 608 } 609 610 var err error 611 var reply interface{} 612 for i := 0; i <= pending; i++ { 613 var e error 614 if reply, e = c.readReply(); e != nil { 615 return nil, c.fatal(e) 616 } 617 if e, ok := reply.(Error); ok && err == nil { 618 err = e 619 } 620 } 621 return reply, err 622} 623