1// Copyright 2012 Gary Burd
2//
3// Licensed under the Apache License, Version 2.0 (the "License"): you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations
13// under the License.
14
15package redis
16
17import (
18	"bufio"
19	"bytes"
20	"errors"
21	"fmt"
22	"io"
23	"net"
24	"strconv"
25	"sync"
26	"time"
27)
28
// conn is the low-level implementation of Conn. It wraps one network
// connection together with a buffered reader and a buffered writer.
type conn struct {
	// mu guards pending and err, which both the reading and the writing
	// paths update.
	mu sync.Mutex
	// pending counts commands queued by Send whose replies have not yet
	// been consumed by Receive or Do.
	pending int
	// err records the first fatal error (set by fatal) or the closed state
	// (set by Close).
	err error
	// conn is the underlying network connection.
	conn net.Conn

	// Read side.
	readTimeout time.Duration // deadline applied before reading a reply; zero disables it
	br          *bufio.Reader

	// Write side.
	writeTimeout time.Duration // deadline applied before writing; zero disables it
	bw           *bufio.Writer

	// lenScratch holds a formatted length header: the '*' or '$' prefix,
	// the decimal length and the trailing "\r\n".
	lenScratch [32]byte

	// numScratch holds the text form of integer and float arguments.
	numScratch [40]byte
}
53
54// Dial connects to the Redis server at the given network and address.
55func Dial(network, address string) (Conn, error) {
56	dialer := xDialer{}
57	return dialer.Dial(network, address)
58}
59
60// DialTimeout acts like Dial but takes timeouts for establishing the
61// connection to the server, writing a command and reading a reply.
62func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
63	netDialer := net.Dialer{Timeout: connectTimeout}
64	dialer := xDialer{
65		NetDial:      netDialer.Dial,
66		ReadTimeout:  readTimeout,
67		WriteTimeout: writeTimeout,
68	}
69	return dialer.Dial(network, address)
70}
71
// xDialer specifies options for connecting to a Redis server. The zero
// value dials with net.Dial and applies no read or write timeouts.
type xDialer struct {
	// NetDial specifies the dial function for creating network connections.
	// If NetDial is nil, then net.Dial is used.
	NetDial func(network, addr string) (net.Conn, error)

	// ReadTimeout specifies the timeout for reading a single command
	// reply. If ReadTimeout is zero, then no timeout is used.
	ReadTimeout time.Duration

	// WriteTimeout specifies the timeout for writing a single command. If
	// WriteTimeout is zero, then no timeout is used.
	WriteTimeout time.Duration
}
86
87// Dial connects to the Redis server at address on the named network.
88func (d *xDialer) Dial(network, address string) (Conn, error) {
89	dial := d.NetDial
90	if dial == nil {
91		dial = net.Dial
92	}
93	netConn, err := dial(network, address)
94	if err != nil {
95		return nil, err
96	}
97	return &conn{
98		conn:         netConn,
99		bw:           bufio.NewWriter(netConn),
100		br:           bufio.NewReader(netConn),
101		readTimeout:  d.ReadTimeout,
102		writeTimeout: d.WriteTimeout,
103	}, nil
104}
105
106// NewConn returns a new Redigo connection for the given net connection.
107func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
108	return &conn{
109		conn:         netConn,
110		bw:           bufio.NewWriter(netConn),
111		br:           bufio.NewReader(netConn),
112		readTimeout:  readTimeout,
113		writeTimeout: writeTimeout,
114	}
115}
116
117func (c *conn) Close() error {
118	c.mu.Lock()
119	err := c.err
120	if c.err == nil {
121		c.err = errors.New("redigo: closed")
122		err = c.conn.Close()
123	}
124	c.mu.Unlock()
125	return err
126}
127
128func (c *conn) fatal(err error) error {
129	c.mu.Lock()
130	if c.err == nil {
131		c.err = err
132		// Close connection to force errors on subsequent calls and to unblock
133		// other reader or writer.
134		c.conn.Close()
135	}
136	c.mu.Unlock()
137	return err
138}
139
140func (c *conn) Err() error {
141	c.mu.Lock()
142	err := c.err
143	c.mu.Unlock()
144	return err
145}
146
147func (c *conn) writeLen(prefix byte, n int) error {
148	c.lenScratch[len(c.lenScratch)-1] = '\n'
149	c.lenScratch[len(c.lenScratch)-2] = '\r'
150	i := len(c.lenScratch) - 3
151	for {
152		c.lenScratch[i] = byte('0' + n%10)
153		i -= 1
154		n = n / 10
155		if n == 0 {
156			break
157		}
158	}
159	c.lenScratch[i] = prefix
160	_, err := c.bw.Write(c.lenScratch[i:])
161	return err
162}
163
164func (c *conn) writeString(s string) error {
165	c.writeLen('$', len(s))
166	c.bw.WriteString(s)
167	_, err := c.bw.WriteString("\r\n")
168	return err
169}
170
171func (c *conn) writeBytes(p []byte) error {
172	c.writeLen('$', len(p))
173	c.bw.Write(p)
174	_, err := c.bw.WriteString("\r\n")
175	return err
176}
177
178func (c *conn) writeInt64(n int64) error {
179	return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
180}
181
182func (c *conn) writeFloat64(n float64) error {
183	return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
184}
185
186func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
187	c.writeLen('*', 1+len(args))
188	err = c.writeString(cmd)
189	for _, arg := range args {
190		if err != nil {
191			break
192		}
193		switch arg := arg.(type) {
194		case string:
195			err = c.writeString(arg)
196		case []byte:
197			err = c.writeBytes(arg)
198		case int:
199			err = c.writeInt64(int64(arg))
200		case int64:
201			err = c.writeInt64(arg)
202		case float64:
203			err = c.writeFloat64(arg)
204		case bool:
205			if arg {
206				err = c.writeString("1")
207			} else {
208				err = c.writeString("0")
209			}
210		case nil:
211			err = c.writeString("")
212		default:
213			var buf bytes.Buffer
214			fmt.Fprint(&buf, arg)
215			err = c.writeBytes(buf.Bytes())
216		}
217	}
218	return err
219}
220
// protocolError describes a reply that could not be parsed as valid RESP.
type protocolError string

// Error formats the message with the "redigo: " prefix and a hint about the
// likely causes.
func (e protocolError) Error() string {
	msg := string(e)
	return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", msg)
}
226
227func (c *conn) readLine() ([]byte, error) {
228	p, err := c.br.ReadSlice('\n')
229	if err == bufio.ErrBufferFull {
230		return nil, protocolError("long response line")
231	}
232	if err != nil {
233		return nil, err
234	}
235	i := len(p) - 2
236	if i < 0 || p[i] != '\r' {
237		return nil, protocolError("bad response line terminator")
238	}
239	return p[:i], nil
240}
241
242// parseLen parses bulk string and array lengths.
243func parseLen(p []byte) (int, error) {
244	if len(p) == 0 {
245		return -1, protocolError("malformed length")
246	}
247
248	if p[0] == '-' && len(p) == 2 && p[1] == '1' {
249		// handle $-1 and $-1 null replies.
250		return -1, nil
251	}
252
253	var n int
254	for _, b := range p {
255		n *= 10
256		if b < '0' || b > '9' {
257			return -1, protocolError("illegal bytes in length")
258		}
259		n += int(b - '0')
260	}
261
262	return n, nil
263}
264
265// parseInt parses an integer reply.
266func parseInt(p []byte) (interface{}, error) {
267	if len(p) == 0 {
268		return 0, protocolError("malformed integer")
269	}
270
271	var negate bool
272	if p[0] == '-' {
273		negate = true
274		p = p[1:]
275		if len(p) == 0 {
276			return 0, protocolError("malformed integer")
277		}
278	}
279
280	var n int64
281	for _, b := range p {
282		n *= 10
283		if b < '0' || b > '9' {
284			return 0, protocolError("illegal bytes in length")
285		}
286		n += int64(b - '0')
287	}
288
289	if negate {
290		n = -n
291	}
292	return n, nil
293}
294
// okReply and pongReply are preallocated interface values that readReply
// returns for the frequent "+OK" and "+PONG" status replies, so those
// replies do not allocate.
var okReply, pongReply interface{} = "OK", "PONG"
299
300func (c *conn) readReply() (interface{}, error) {
301	line, err := c.readLine()
302	if err != nil {
303		return nil, err
304	}
305	if len(line) == 0 {
306		return nil, protocolError("short response line")
307	}
308	switch line[0] {
309	case '+':
310		switch {
311		case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
312			// Avoid allocation for frequent "+OK" response.
313			return okReply, nil
314		case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
315			// Avoid allocation in PING command benchmarks :)
316			return pongReply, nil
317		default:
318			return string(line[1:]), nil
319		}
320	case '-':
321		return Error(string(line[1:])), nil
322	case ':':
323		return parseInt(line[1:])
324	case '$':
325		n, err := parseLen(line[1:])
326		if n < 0 || err != nil {
327			return nil, err
328		}
329		p := make([]byte, n)
330		_, err = io.ReadFull(c.br, p)
331		if err != nil {
332			return nil, err
333		}
334		if line, err := c.readLine(); err != nil {
335			return nil, err
336		} else if len(line) != 0 {
337			return nil, protocolError("bad bulk string format")
338		}
339		return p, nil
340	case '*':
341		n, err := parseLen(line[1:])
342		if n < 0 || err != nil {
343			return nil, err
344		}
345		r := make([]interface{}, n)
346		for i := range r {
347			r[i], err = c.readReply()
348			if err != nil {
349				return nil, err
350			}
351		}
352		return r, nil
353	}
354	return nil, protocolError("unexpected response line")
355}
356
357func (c *conn) Send(cmd string, args ...interface{}) error {
358	c.mu.Lock()
359	c.pending += 1
360	c.mu.Unlock()
361	if c.writeTimeout != 0 {
362		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
363	}
364	if err := c.writeCommand(cmd, args); err != nil {
365		return c.fatal(err)
366	}
367	return nil
368}
369
370func (c *conn) Flush() error {
371	if c.writeTimeout != 0 {
372		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
373	}
374	if err := c.bw.Flush(); err != nil {
375		return c.fatal(err)
376	}
377	return nil
378}
379
380func (c *conn) Receive() (reply interface{}, err error) {
381	if c.readTimeout != 0 {
382		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
383	}
384	if reply, err = c.readReply(); err != nil {
385		return nil, c.fatal(err)
386	}
387	// When using pub/sub, the number of receives can be greater than the
388	// number of sends. To enable normal use of the connection after
389	// unsubscribing from all channels, we do not decrement pending to a
390	// negative value.
391	//
392	// The pending field is decremented after the reply is read to handle the
393	// case where Receive is called before Send.
394	c.mu.Lock()
395	if c.pending > 0 {
396		c.pending -= 1
397	}
398	c.mu.Unlock()
399	if err, ok := reply.(Error); ok {
400		return nil, err
401	}
402	return
403}
404
405func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
406	c.mu.Lock()
407	pending := c.pending
408	c.pending = 0
409	c.mu.Unlock()
410
411	if cmd == "" && pending == 0 {
412		return nil, nil
413	}
414
415	if c.writeTimeout != 0 {
416		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
417	}
418
419	if cmd != "" {
420		c.writeCommand(cmd, args)
421	}
422
423	if err := c.bw.Flush(); err != nil {
424		return nil, c.fatal(err)
425	}
426
427	if c.readTimeout != 0 {
428		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
429	}
430
431	if cmd == "" {
432		reply := make([]interface{}, pending)
433		for i := range reply {
434			r, e := c.readReply()
435			if e != nil {
436				return nil, c.fatal(e)
437			}
438			reply[i] = r
439		}
440		return reply, nil
441	}
442
443	var err error
444	var reply interface{}
445	for i := 0; i <= pending; i++ {
446		var e error
447		if reply, e = c.readReply(); e != nil {
448			return nil, c.fatal(e)
449		}
450		if e, ok := reply.(Error); ok && err == nil {
451			err = e
452		}
453	}
454	return reply, err
455}
456