1// Copyright 2012 Gary Burd
2//
3// Licensed under the Apache License, Version 2.0 (the "License"): you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations
13// under the License.
14
15package redis
16
17import (
18	"bufio"
19	"bytes"
20	"crypto/tls"
21	"errors"
22	"fmt"
23	"io"
24	"net"
25	"net/url"
26	"regexp"
27	"strconv"
28	"sync"
29	"time"
30)
31
// conn is the low-level implementation of Conn. It wraps one network
// connection together with buffered read/write state and scratch arrays
// that are reused when encoding command arguments.
type conn struct {
	// Shared state; mu guards pending and err.
	mu      sync.Mutex
	pending int      // commands queued by Send whose replies are still outstanding
	err     error    // first fatal error (or the closed error); only set while nil
	conn    net.Conn // underlying network connection

	// Read side.
	readTimeout time.Duration // per-reply read deadline; zero means none
	br          *bufio.Reader

	// Write side.
	writeTimeout time.Duration // per-write deadline; zero means none
	bw           *bufio.Writer

	// lenScratch holds a formatted length header:
	// '*' or '$', the decimal length, then "\r\n".
	lenScratch [32]byte

	// numScratch holds the text form of integer and float arguments.
	numScratch [40]byte
}
56
57// DialTimeout acts like Dial but takes timeouts for establishing the
58// connection to the server, writing a command and reading a reply.
59//
60// Deprecated: Use Dial with options instead.
61func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
62	return Dial(network, address,
63		DialConnectTimeout(connectTimeout),
64		DialReadTimeout(readTimeout),
65		DialWriteTimeout(writeTimeout))
66}
67
68// DialOption specifies an option for dialing a Redis server.
69type DialOption struct {
70	f func(*dialOptions)
71}
72
// dialOptions collects the settings produced by DialOption values before
// Dial opens a connection.
type dialOptions struct {
	// Deadlines copied onto the resulting connection.
	readTimeout  time.Duration
	writeTimeout time.Duration

	// dial opens the underlying network connection.
	dial func(network, addr string) (net.Conn, error)

	// Sent as SELECT / AUTH after connecting when non-zero / non-empty.
	db       int
	password string

	// TLS settings; dialTLS turns TLS on.
	dialTLS    bool
	skipVerify bool
	tlsConfig  *tls.Config
}
83
84// DialReadTimeout specifies the timeout for reading a single command reply.
85func DialReadTimeout(d time.Duration) DialOption {
86	return DialOption{func(do *dialOptions) {
87		do.readTimeout = d
88	}}
89}
90
91// DialWriteTimeout specifies the timeout for writing a single command.
92func DialWriteTimeout(d time.Duration) DialOption {
93	return DialOption{func(do *dialOptions) {
94		do.writeTimeout = d
95	}}
96}
97
98// DialConnectTimeout specifies the timeout for connecting to the Redis server.
99func DialConnectTimeout(d time.Duration) DialOption {
100	return DialOption{func(do *dialOptions) {
101		dialer := net.Dialer{Timeout: d}
102		do.dial = dialer.Dial
103	}}
104}
105
106// DialNetDial specifies a custom dial function for creating TCP
107// connections. If this option is left out, then net.Dial is
108// used. DialNetDial overrides DialConnectTimeout.
109func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
110	return DialOption{func(do *dialOptions) {
111		do.dial = dial
112	}}
113}
114
115// DialDatabase specifies the database to select when dialing a connection.
116func DialDatabase(db int) DialOption {
117	return DialOption{func(do *dialOptions) {
118		do.db = db
119	}}
120}
121
122// DialPassword specifies the password to use when connecting to
123// the Redis server.
124func DialPassword(password string) DialOption {
125	return DialOption{func(do *dialOptions) {
126		do.password = password
127	}}
128}
129
130// DialTLSConfig specifies the config to use when a TLS connection is dialed.
131// Has no effect when not dialing a TLS connection.
132func DialTLSConfig(c *tls.Config) DialOption {
133	return DialOption{func(do *dialOptions) {
134		do.tlsConfig = c
135	}}
136}
137
138// DialTLSSkipVerify to disable server name verification when connecting
139// over TLS. Has no effect when not dialing a TLS connection.
140func DialTLSSkipVerify(skip bool) DialOption {
141	return DialOption{func(do *dialOptions) {
142		do.skipVerify = skip
143	}}
144}
145
146// Dial connects to the Redis server at the given network and
147// address using the specified options.
148func Dial(network, address string, options ...DialOption) (Conn, error) {
149	do := dialOptions{
150		dial: net.Dial,
151	}
152	for _, option := range options {
153		option.f(&do)
154	}
155
156	netConn, err := do.dial(network, address)
157	if err != nil {
158		return nil, err
159	}
160
161	if do.dialTLS {
162		tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify)
163		if tlsConfig.ServerName == "" {
164			host, _, err := net.SplitHostPort(address)
165			if err != nil {
166				netConn.Close()
167				return nil, err
168			}
169			tlsConfig.ServerName = host
170		}
171
172		tlsConn := tls.Client(netConn, tlsConfig)
173		if err := tlsConn.Handshake(); err != nil {
174			netConn.Close()
175			return nil, err
176		}
177		netConn = tlsConn
178	}
179
180	c := &conn{
181		conn:         netConn,
182		bw:           bufio.NewWriter(netConn),
183		br:           bufio.NewReader(netConn),
184		readTimeout:  do.readTimeout,
185		writeTimeout: do.writeTimeout,
186	}
187
188	if do.password != "" {
189		if _, err := c.Do("AUTH", do.password); err != nil {
190			netConn.Close()
191			return nil, err
192		}
193	}
194
195	if do.db != 0 {
196		if _, err := c.Do("SELECT", do.db); err != nil {
197			netConn.Close()
198			return nil, err
199		}
200	}
201
202	return c, nil
203}
204
205func dialTLS(do *dialOptions) {
206	do.dialTLS = true
207}
208
209var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
210
211// DialURL connects to a Redis server at the given URL using the Redis
212// URI scheme. URLs should follow the draft IANA specification for the
213// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
214func DialURL(rawurl string, options ...DialOption) (Conn, error) {
215	u, err := url.Parse(rawurl)
216	if err != nil {
217		return nil, err
218	}
219
220	if u.Scheme != "redis" && u.Scheme != "rediss" {
221		return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
222	}
223
224	// As per the IANA draft spec, the host defaults to localhost and
225	// the port defaults to 6379.
226	host, port, err := net.SplitHostPort(u.Host)
227	if err != nil {
228		// assume port is missing
229		host = u.Host
230		port = "6379"
231	}
232	if host == "" {
233		host = "localhost"
234	}
235	address := net.JoinHostPort(host, port)
236
237	if u.User != nil {
238		password, isSet := u.User.Password()
239		if isSet {
240			options = append(options, DialPassword(password))
241		}
242	}
243
244	match := pathDBRegexp.FindStringSubmatch(u.Path)
245	if len(match) == 2 {
246		db := 0
247		if len(match[1]) > 0 {
248			db, err = strconv.Atoi(match[1])
249			if err != nil {
250				return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
251			}
252		}
253		if db != 0 {
254			options = append(options, DialDatabase(db))
255		}
256	} else if u.Path != "" {
257		return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
258	}
259
260	if u.Scheme == "rediss" {
261		options = append([]DialOption{{dialTLS}}, options...)
262	}
263
264	return Dial("tcp", address, options...)
265}
266
267// NewConn returns a new Redigo connection for the given net connection.
268func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
269	return &conn{
270		conn:         netConn,
271		bw:           bufio.NewWriter(netConn),
272		br:           bufio.NewReader(netConn),
273		readTimeout:  readTimeout,
274		writeTimeout: writeTimeout,
275	}
276}
277
278func (c *conn) Close() error {
279	c.mu.Lock()
280	err := c.err
281	if c.err == nil {
282		c.err = errors.New("redigo: closed")
283		err = c.conn.Close()
284	}
285	c.mu.Unlock()
286	return err
287}
288
289func (c *conn) fatal(err error) error {
290	c.mu.Lock()
291	if c.err == nil {
292		c.err = err
293		// Close connection to force errors on subsequent calls and to unblock
294		// other reader or writer.
295		c.conn.Close()
296	}
297	c.mu.Unlock()
298	return err
299}
300
301func (c *conn) Err() error {
302	c.mu.Lock()
303	err := c.err
304	c.mu.Unlock()
305	return err
306}
307
308func (c *conn) writeLen(prefix byte, n int) error {
309	c.lenScratch[len(c.lenScratch)-1] = '\n'
310	c.lenScratch[len(c.lenScratch)-2] = '\r'
311	i := len(c.lenScratch) - 3
312	for {
313		c.lenScratch[i] = byte('0' + n%10)
314		i -= 1
315		n = n / 10
316		if n == 0 {
317			break
318		}
319	}
320	c.lenScratch[i] = prefix
321	_, err := c.bw.Write(c.lenScratch[i:])
322	return err
323}
324
325func (c *conn) writeString(s string) error {
326	c.writeLen('$', len(s))
327	c.bw.WriteString(s)
328	_, err := c.bw.WriteString("\r\n")
329	return err
330}
331
332func (c *conn) writeBytes(p []byte) error {
333	c.writeLen('$', len(p))
334	c.bw.Write(p)
335	_, err := c.bw.WriteString("\r\n")
336	return err
337}
338
339func (c *conn) writeInt64(n int64) error {
340	return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
341}
342
343func (c *conn) writeFloat64(n float64) error {
344	return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
345}
346
347func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
348	c.writeLen('*', 1+len(args))
349	err = c.writeString(cmd)
350	for _, arg := range args {
351		if err != nil {
352			break
353		}
354		switch arg := arg.(type) {
355		case string:
356			err = c.writeString(arg)
357		case []byte:
358			err = c.writeBytes(arg)
359		case int:
360			err = c.writeInt64(int64(arg))
361		case int64:
362			err = c.writeInt64(arg)
363		case float64:
364			err = c.writeFloat64(arg)
365		case bool:
366			if arg {
367				err = c.writeString("1")
368			} else {
369				err = c.writeString("0")
370			}
371		case nil:
372			err = c.writeString("")
373		case Argument:
374			var buf bytes.Buffer
375			fmt.Fprint(&buf, arg.RedisArg())
376			err = c.writeBytes(buf.Bytes())
377		default:
378			var buf bytes.Buffer
379			fmt.Fprint(&buf, arg)
380			err = c.writeBytes(buf.Bytes())
381		}
382	}
383	return err
384}
385
// protocolError reports server output that does not follow the Redis
// protocol. Such output can also be caused by an application reading from
// the same connection concurrently, which the message points out.
type protocolError string

func (pe protocolError) Error() string {
	const hint = "possible server error or unsupported concurrent read by application"
	return fmt.Sprintf("redigo: %s (%s)", string(pe), hint)
}
391
392func (c *conn) readLine() ([]byte, error) {
393	p, err := c.br.ReadSlice('\n')
394	if err == bufio.ErrBufferFull {
395		return nil, protocolError("long response line")
396	}
397	if err != nil {
398		return nil, err
399	}
400	i := len(p) - 2
401	if i < 0 || p[i] != '\r' {
402		return nil, protocolError("bad response line terminator")
403	}
404	return p[:i], nil
405}
406
407// parseLen parses bulk string and array lengths.
408func parseLen(p []byte) (int, error) {
409	if len(p) == 0 {
410		return -1, protocolError("malformed length")
411	}
412
413	if p[0] == '-' && len(p) == 2 && p[1] == '1' {
414		// handle $-1 and $-1 null replies.
415		return -1, nil
416	}
417
418	var n int
419	for _, b := range p {
420		n *= 10
421		if b < '0' || b > '9' {
422			return -1, protocolError("illegal bytes in length")
423		}
424		n += int(b - '0')
425	}
426
427	return n, nil
428}
429
430// parseInt parses an integer reply.
431func parseInt(p []byte) (interface{}, error) {
432	if len(p) == 0 {
433		return 0, protocolError("malformed integer")
434	}
435
436	var negate bool
437	if p[0] == '-' {
438		negate = true
439		p = p[1:]
440		if len(p) == 0 {
441			return 0, protocolError("malformed integer")
442		}
443	}
444
445	var n int64
446	for _, b := range p {
447		n *= 10
448		if b < '0' || b > '9' {
449			return 0, protocolError("illegal bytes in length")
450		}
451		n += int64(b - '0')
452	}
453
454	if negate {
455		n = -n
456	}
457	return n, nil
458}
459
// okReply is the shared value readReply returns for "+OK", avoiding a new
// string allocation for this frequent status reply.
var okReply interface{} = "OK"

// pongReply is the shared value readReply returns for "+PONG".
var pongReply interface{} = "PONG"
464
465func (c *conn) readReply() (interface{}, error) {
466	line, err := c.readLine()
467	if err != nil {
468		return nil, err
469	}
470	if len(line) == 0 {
471		return nil, protocolError("short response line")
472	}
473	switch line[0] {
474	case '+':
475		switch {
476		case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
477			// Avoid allocation for frequent "+OK" response.
478			return okReply, nil
479		case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
480			// Avoid allocation in PING command benchmarks :)
481			return pongReply, nil
482		default:
483			return string(line[1:]), nil
484		}
485	case '-':
486		return Error(string(line[1:])), nil
487	case ':':
488		return parseInt(line[1:])
489	case '$':
490		n, err := parseLen(line[1:])
491		if n < 0 || err != nil {
492			return nil, err
493		}
494		p := make([]byte, n)
495		_, err = io.ReadFull(c.br, p)
496		if err != nil {
497			return nil, err
498		}
499		if line, err := c.readLine(); err != nil {
500			return nil, err
501		} else if len(line) != 0 {
502			return nil, protocolError("bad bulk string format")
503		}
504		return p, nil
505	case '*':
506		n, err := parseLen(line[1:])
507		if n < 0 || err != nil {
508			return nil, err
509		}
510		r := make([]interface{}, n)
511		for i := range r {
512			r[i], err = c.readReply()
513			if err != nil {
514				return nil, err
515			}
516		}
517		return r, nil
518	}
519	return nil, protocolError("unexpected response line")
520}
521
522func (c *conn) Send(cmd string, args ...interface{}) error {
523	c.mu.Lock()
524	c.pending += 1
525	c.mu.Unlock()
526	if c.writeTimeout != 0 {
527		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
528	}
529	if err := c.writeCommand(cmd, args); err != nil {
530		return c.fatal(err)
531	}
532	return nil
533}
534
535func (c *conn) Flush() error {
536	if c.writeTimeout != 0 {
537		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
538	}
539	if err := c.bw.Flush(); err != nil {
540		return c.fatal(err)
541	}
542	return nil
543}
544
545func (c *conn) Receive() (reply interface{}, err error) {
546	if c.readTimeout != 0 {
547		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
548	}
549	if reply, err = c.readReply(); err != nil {
550		return nil, c.fatal(err)
551	}
552	// When using pub/sub, the number of receives can be greater than the
553	// number of sends. To enable normal use of the connection after
554	// unsubscribing from all channels, we do not decrement pending to a
555	// negative value.
556	//
557	// The pending field is decremented after the reply is read to handle the
558	// case where Receive is called before Send.
559	c.mu.Lock()
560	if c.pending > 0 {
561		c.pending -= 1
562	}
563	c.mu.Unlock()
564	if err, ok := reply.(Error); ok {
565		return nil, err
566	}
567	return
568}
569
570func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
571	c.mu.Lock()
572	pending := c.pending
573	c.pending = 0
574	c.mu.Unlock()
575
576	if cmd == "" && pending == 0 {
577		return nil, nil
578	}
579
580	if c.writeTimeout != 0 {
581		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
582	}
583
584	if cmd != "" {
585		if err := c.writeCommand(cmd, args); err != nil {
586			return nil, c.fatal(err)
587		}
588	}
589
590	if err := c.bw.Flush(); err != nil {
591		return nil, c.fatal(err)
592	}
593
594	if c.readTimeout != 0 {
595		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
596	}
597
598	if cmd == "" {
599		reply := make([]interface{}, pending)
600		for i := range reply {
601			r, e := c.readReply()
602			if e != nil {
603				return nil, c.fatal(e)
604			}
605			reply[i] = r
606		}
607		return reply, nil
608	}
609
610	var err error
611	var reply interface{}
612	for i := 0; i <= pending; i++ {
613		var e error
614		if reply, e = c.readReply(); e != nil {
615			return nil, c.fatal(e)
616		}
617		if e, ok := reply.(Error); ok && err == nil {
618			err = e
619		}
620	}
621	return reply, err
622}
623