1// Copyright 2012 Gary Burd 2// 3// Licensed under the Apache License, Version 2.0 (the "License"): you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12// License for the specific language governing permissions and limitations 13// under the License. 14 15package redis 16 17import ( 18 "bufio" 19 "bytes" 20 "errors" 21 "fmt" 22 "io" 23 "net" 24 "strconv" 25 "sync" 26 "time" 27) 28 29// conn is the low-level implementation of Conn 30type conn struct { 31 32 // Shared 33 mu sync.Mutex 34 pending int 35 err error 36 conn net.Conn 37 38 // Read 39 readTimeout time.Duration 40 br *bufio.Reader 41 42 // Write 43 writeTimeout time.Duration 44 bw *bufio.Writer 45 46 // Scratch space for formatting argument length. 47 // '*' or '$', length, "\r\n" 48 lenScratch [32]byte 49 50 // Scratch space for formatting integers and floats. 51 numScratch [40]byte 52} 53 54// Dial connects to the Redis server at the given network and address. 55func Dial(network, address string) (Conn, error) { 56 dialer := xDialer{} 57 return dialer.Dial(network, address) 58} 59 60// DialTimeout acts like Dial but takes timeouts for establishing the 61// connection to the server, writing a command and reading a reply. 62func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { 63 netDialer := net.Dialer{Timeout: connectTimeout} 64 dialer := xDialer{ 65 NetDial: netDialer.Dial, 66 ReadTimeout: readTimeout, 67 WriteTimeout: writeTimeout, 68 } 69 return dialer.Dial(network, address) 70} 71 72// A Dialer specifies options for connecting to a Redis server. 73type xDialer struct { 74 // NetDial specifies the dial function for creating TCP connections. If 75 // NetDial is nil, then net.Dial is used. 76 NetDial func(network, addr string) (net.Conn, error) 77 78 // ReadTimeout specifies the timeout for reading a single command 79 // reply. If ReadTimeout is zero, then no timeout is used. 80 ReadTimeout time.Duration 81 82 // WriteTimeout specifies the timeout for writing a single command. If 83 // WriteTimeout is zero, then no timeout is used. 84 WriteTimeout time.Duration 85} 86 87// Dial connects to the Redis server at address on the named network. 88func (d *xDialer) Dial(network, address string) (Conn, error) { 89 dial := d.NetDial 90 if dial == nil { 91 dial = net.Dial 92 } 93 netConn, err := dial(network, address) 94 if err != nil { 95 return nil, err 96 } 97 return &conn{ 98 conn: netConn, 99 bw: bufio.NewWriter(netConn), 100 br: bufio.NewReader(netConn), 101 readTimeout: d.ReadTimeout, 102 writeTimeout: d.WriteTimeout, 103 }, nil 104} 105 106// NewConn returns a new Redigo connection for the given net connection. 107func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { 108 return &conn{ 109 conn: netConn, 110 bw: bufio.NewWriter(netConn), 111 br: bufio.NewReader(netConn), 112 readTimeout: readTimeout, 113 writeTimeout: writeTimeout, 114 } 115} 116 117func (c *conn) Close() error { 118 c.mu.Lock() 119 err := c.err 120 if c.err == nil { 121 c.err = errors.New("redigo: closed") 122 err = c.conn.Close() 123 } 124 c.mu.Unlock() 125 return err 126} 127 128func (c *conn) fatal(err error) error { 129 c.mu.Lock() 130 if c.err == nil { 131 c.err = err 132 // Close connection to force errors on subsequent calls and to unblock 133 // other reader or writer. 134 c.conn.Close() 135 } 136 c.mu.Unlock() 137 return err 138} 139 140func (c *conn) Err() error { 141 c.mu.Lock() 142 err := c.err 143 c.mu.Unlock() 144 return err 145} 146 147func (c *conn) writeLen(prefix byte, n int) error { 148 c.lenScratch[len(c.lenScratch)-1] = '\n' 149 c.lenScratch[len(c.lenScratch)-2] = '\r' 150 i := len(c.lenScratch) - 3 151 for { 152 c.lenScratch[i] = byte('0' + n%10) 153 i -= 1 154 n = n / 10 155 if n == 0 { 156 break 157 } 158 } 159 c.lenScratch[i] = prefix 160 _, err := c.bw.Write(c.lenScratch[i:]) 161 return err 162} 163 164func (c *conn) writeString(s string) error { 165 c.writeLen('$', len(s)) 166 c.bw.WriteString(s) 167 _, err := c.bw.WriteString("\r\n") 168 return err 169} 170 171func (c *conn) writeBytes(p []byte) error { 172 c.writeLen('$', len(p)) 173 c.bw.Write(p) 174 _, err := c.bw.WriteString("\r\n") 175 return err 176} 177 178func (c *conn) writeInt64(n int64) error { 179 return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)) 180} 181 182func (c *conn) writeFloat64(n float64) error { 183 return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) 184} 185 186func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { 187 c.writeLen('*', 1+len(args)) 188 err = c.writeString(cmd) 189 for _, arg := range args { 190 if err != nil { 191 break 192 } 193 switch arg := arg.(type) { 194 case string: 195 err = c.writeString(arg) 196 case []byte: 197 err = c.writeBytes(arg) 198 case int: 199 err = c.writeInt64(int64(arg)) 200 case int64: 201 err = c.writeInt64(arg) 202 case float64: 203 err = c.writeFloat64(arg) 204 case bool: 205 if arg { 206 err = c.writeString("1") 207 } else { 208 err = c.writeString("0") 209 } 210 case nil: 211 err = c.writeString("") 212 default: 213 var buf bytes.Buffer 214 fmt.Fprint(&buf, arg) 215 err = c.writeBytes(buf.Bytes()) 216 } 217 } 218 return err 219} 220 221type protocolError string 222 223func (pe protocolError) Error() string { 224 return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe)) 225} 226 227func (c *conn) readLine() ([]byte, error) { 228 p, err := c.br.ReadSlice('\n') 229 if err == bufio.ErrBufferFull { 230 return nil, protocolError("long response line") 231 } 232 if err != nil { 233 return nil, err 234 } 235 i := len(p) - 2 236 if i < 0 || p[i] != '\r' { 237 return nil, protocolError("bad response line terminator") 238 } 239 return p[:i], nil 240} 241 242// parseLen parses bulk string and array lengths. 243func parseLen(p []byte) (int, error) { 244 if len(p) == 0 { 245 return -1, protocolError("malformed length") 246 } 247 248 if p[0] == '-' && len(p) == 2 && p[1] == '1' { 249 // handle $-1 and $-1 null replies. 250 return -1, nil 251 } 252 253 var n int 254 for _, b := range p { 255 n *= 10 256 if b < '0' || b > '9' { 257 return -1, protocolError("illegal bytes in length") 258 } 259 n += int(b - '0') 260 } 261 262 return n, nil 263} 264 265// parseInt parses an integer reply. 266func parseInt(p []byte) (interface{}, error) { 267 if len(p) == 0 { 268 return 0, protocolError("malformed integer") 269 } 270 271 var negate bool 272 if p[0] == '-' { 273 negate = true 274 p = p[1:] 275 if len(p) == 0 { 276 return 0, protocolError("malformed integer") 277 } 278 } 279 280 var n int64 281 for _, b := range p { 282 n *= 10 283 if b < '0' || b > '9' { 284 return 0, protocolError("illegal bytes in length") 285 } 286 n += int64(b - '0') 287 } 288 289 if negate { 290 n = -n 291 } 292 return n, nil 293} 294 295var ( 296 okReply interface{} = "OK" 297 pongReply interface{} = "PONG" 298) 299 300func (c *conn) readReply() (interface{}, error) { 301 line, err := c.readLine() 302 if err != nil { 303 return nil, err 304 } 305 if len(line) == 0 { 306 return nil, protocolError("short response line") 307 } 308 switch line[0] { 309 case '+': 310 switch { 311 case len(line) == 3 && line[1] == 'O' && line[2] == 'K': 312 // Avoid allocation for frequent "+OK" response. 313 return okReply, nil 314 case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G': 315 // Avoid allocation in PING command benchmarks :) 316 return pongReply, nil 317 default: 318 return string(line[1:]), nil 319 } 320 case '-': 321 return Error(string(line[1:])), nil 322 case ':': 323 return parseInt(line[1:]) 324 case '$': 325 n, err := parseLen(line[1:]) 326 if n < 0 || err != nil { 327 return nil, err 328 } 329 p := make([]byte, n) 330 _, err = io.ReadFull(c.br, p) 331 if err != nil { 332 return nil, err 333 } 334 if line, err := c.readLine(); err != nil { 335 return nil, err 336 } else if len(line) != 0 { 337 return nil, protocolError("bad bulk string format") 338 } 339 return p, nil 340 case '*': 341 n, err := parseLen(line[1:]) 342 if n < 0 || err != nil { 343 return nil, err 344 } 345 r := make([]interface{}, n) 346 for i := range r { 347 r[i], err = c.readReply() 348 if err != nil { 349 return nil, err 350 } 351 } 352 return r, nil 353 } 354 return nil, protocolError("unexpected response line") 355} 356 357func (c *conn) Send(cmd string, args ...interface{}) error { 358 c.mu.Lock() 359 c.pending += 1 360 c.mu.Unlock() 361 if c.writeTimeout != 0 { 362 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) 363 } 364 if err := c.writeCommand(cmd, args); err != nil { 365 return c.fatal(err) 366 } 367 return nil 368} 369 370func (c *conn) Flush() error { 371 if c.writeTimeout != 0 { 372 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) 373 } 374 if err := c.bw.Flush(); err != nil { 375 return c.fatal(err) 376 } 377 return nil 378} 379 380func (c *conn) Receive() (reply interface{}, err error) { 381 if c.readTimeout != 0 { 382 c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) 383 } 384 if reply, err = c.readReply(); err != nil { 385 return nil, c.fatal(err) 386 } 387 // When using pub/sub, the number of receives can be greater than the 388 // number of sends. To enable normal use of the connection after 389 // unsubscribing from all channels, we do not decrement pending to a 390 // negative value. 391 // 392 // The pending field is decremented after the reply is read to handle the 393 // case where Receive is called before Send. 394 c.mu.Lock() 395 if c.pending > 0 { 396 c.pending -= 1 397 } 398 c.mu.Unlock() 399 if err, ok := reply.(Error); ok { 400 return nil, err 401 } 402 return 403} 404 405func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { 406 c.mu.Lock() 407 pending := c.pending 408 c.pending = 0 409 c.mu.Unlock() 410 411 if cmd == "" && pending == 0 { 412 return nil, nil 413 } 414 415 if c.writeTimeout != 0 { 416 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) 417 } 418 419 if cmd != "" { 420 c.writeCommand(cmd, args) 421 } 422 423 if err := c.bw.Flush(); err != nil { 424 return nil, c.fatal(err) 425 } 426 427 if c.readTimeout != 0 { 428 c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) 429 } 430 431 if cmd == "" { 432 reply := make([]interface{}, pending) 433 for i := range reply { 434 r, e := c.readReply() 435 if e != nil { 436 return nil, c.fatal(e) 437 } 438 reply[i] = r 439 } 440 return reply, nil 441 } 442 443 var err error 444 var reply interface{} 445 for i := 0; i <= pending; i++ { 446 var e error 447 if reply, e = c.readReply(); e != nil { 448 return nil, c.fatal(e) 449 } 450 if e, ok := reply.(Error); ok && err == nil { 451 err = e 452 } 453 } 454 return reply, err 455} 456