1package proto
2
3import (
4	"bufio"
5	"fmt"
6	"io"
7
8	"github.com/go-redis/redis/v8/internal/util"
9)
10
// RESP reply type prefixes. The first byte of every reply line tells the
// reader how to interpret the rest of it. The constants are deliberately
// untyped so they compare directly against bytes read from the wire.
const (
	StatusReply = '+' // simple string, e.g. "+OK"
	ErrorReply  = '-' // error message, e.g. "-ERR unknown command"
	IntReply    = ':' // integer, e.g. ":1000"
	StringReply = '$' // bulk string header "$<len>", payload follows
	ArrayReply  = '*' // array header "*<count>", elements follow
)
18
19//------------------------------------------------------------------------------
20
// RedisError is an error whose text came from the protocol: either a server
// error reply (see ParseErrorReply) or a package sentinel such as Nil.
type RedisError string

// Error implements the error interface by returning the message verbatim.
func (e RedisError) Error() string {
	msg := string(e)
	return msg
}

// RedisError has no behavior; it exists so that RedisError values satisfy
// interfaces that require a RedisError method.
func (RedisError) RedisError() {}

// Nil is returned when the server sends a nil bulk string ("$-1") or a nil
// array ("*-1").
const Nil = RedisError("redis: nil")
28
29//------------------------------------------------------------------------------
30
type (
	// MultiBulkParse decodes the body of an array reply. It is called after
	// the "*<n>" header has been consumed and receives the element count n;
	// it is expected to read those elements from the Reader itself.
	MultiBulkParse func(*Reader, int64) (interface{}, error)

	// Reader decodes Redis protocol replies from a buffered byte stream.
	Reader struct {
		rd   *bufio.Reader
		_buf []byte // scratch space handed out by buf() and reused across reads
	}
)
37
38func NewReader(rd io.Reader) *Reader {
39	return &Reader{
40		rd:   bufio.NewReader(rd),
41		_buf: make([]byte, 64),
42	}
43}
44
45func (r *Reader) Buffered() int {
46	return r.rd.Buffered()
47}
48
49func (r *Reader) Peek(n int) ([]byte, error) {
50	return r.rd.Peek(n)
51}
52
53func (r *Reader) Reset(rd io.Reader) {
54	r.rd.Reset(rd)
55}
56
57func (r *Reader) ReadLine() ([]byte, error) {
58	line, err := r.readLine()
59	if err != nil {
60		return nil, err
61	}
62	if isNilReply(line) {
63		return nil, Nil
64	}
65	return line, nil
66}
67
68// readLine that returns an error if:
69//   - there is a pending read error;
70//   - or line does not end with \r\n.
71func (r *Reader) readLine() ([]byte, error) {
72	b, err := r.rd.ReadSlice('\n')
73	if err != nil {
74		if err != bufio.ErrBufferFull {
75			return nil, err
76		}
77
78		full := make([]byte, len(b))
79		copy(full, b)
80
81		b, err = r.rd.ReadBytes('\n')
82		if err != nil {
83			return nil, err
84		}
85
86		full = append(full, b...) //nolint:makezero
87		b = full
88	}
89	if len(b) <= 2 || b[len(b)-1] != '\n' || b[len(b)-2] != '\r' {
90		return nil, fmt.Errorf("redis: invalid reply: %q", b)
91	}
92	return b[:len(b)-2], nil
93}
94
95func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) {
96	line, err := r.ReadLine()
97	if err != nil {
98		return nil, err
99	}
100
101	switch line[0] {
102	case ErrorReply:
103		return nil, ParseErrorReply(line)
104	case StatusReply:
105		return string(line[1:]), nil
106	case IntReply:
107		return util.ParseInt(line[1:], 10, 64)
108	case StringReply:
109		return r.readStringReply(line)
110	case ArrayReply:
111		n, err := parseArrayLen(line)
112		if err != nil {
113			return nil, err
114		}
115		if m == nil {
116			err := fmt.Errorf("redis: got %.100q, but multi bulk parser is nil", line)
117			return nil, err
118		}
119		return m(r, n)
120	}
121	return nil, fmt.Errorf("redis: can't parse %.100q", line)
122}
123
124func (r *Reader) ReadIntReply() (int64, error) {
125	line, err := r.ReadLine()
126	if err != nil {
127		return 0, err
128	}
129	switch line[0] {
130	case ErrorReply:
131		return 0, ParseErrorReply(line)
132	case IntReply:
133		return util.ParseInt(line[1:], 10, 64)
134	default:
135		return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line)
136	}
137}
138
139func (r *Reader) ReadString() (string, error) {
140	line, err := r.ReadLine()
141	if err != nil {
142		return "", err
143	}
144	switch line[0] {
145	case ErrorReply:
146		return "", ParseErrorReply(line)
147	case StringReply:
148		return r.readStringReply(line)
149	case StatusReply:
150		return string(line[1:]), nil
151	case IntReply:
152		return string(line[1:]), nil
153	default:
154		return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line)
155	}
156}
157
158func (r *Reader) readStringReply(line []byte) (string, error) {
159	if isNilReply(line) {
160		return "", Nil
161	}
162
163	replyLen, err := util.Atoi(line[1:])
164	if err != nil {
165		return "", err
166	}
167
168	b := make([]byte, replyLen+2)
169	_, err = io.ReadFull(r.rd, b)
170	if err != nil {
171		return "", err
172	}
173
174	return util.BytesToString(b[:replyLen]), nil
175}
176
177func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) {
178	line, err := r.ReadLine()
179	if err != nil {
180		return nil, err
181	}
182	switch line[0] {
183	case ErrorReply:
184		return nil, ParseErrorReply(line)
185	case ArrayReply:
186		n, err := parseArrayLen(line)
187		if err != nil {
188			return nil, err
189		}
190		return m(r, n)
191	default:
192		return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line)
193	}
194}
195
196func (r *Reader) ReadArrayLen() (int, error) {
197	line, err := r.ReadLine()
198	if err != nil {
199		return 0, err
200	}
201	switch line[0] {
202	case ErrorReply:
203		return 0, ParseErrorReply(line)
204	case ArrayReply:
205		n, err := parseArrayLen(line)
206		if err != nil {
207			return 0, err
208		}
209		return int(n), nil
210	default:
211		return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line)
212	}
213}
214
215func (r *Reader) ReadScanReply() ([]string, uint64, error) {
216	n, err := r.ReadArrayLen()
217	if err != nil {
218		return nil, 0, err
219	}
220	if n != 2 {
221		return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n)
222	}
223
224	cursor, err := r.ReadUint()
225	if err != nil {
226		return nil, 0, err
227	}
228
229	n, err = r.ReadArrayLen()
230	if err != nil {
231		return nil, 0, err
232	}
233
234	keys := make([]string, n)
235
236	for i := 0; i < n; i++ {
237		key, err := r.ReadString()
238		if err != nil {
239			return nil, 0, err
240		}
241		keys[i] = key
242	}
243
244	return keys, cursor, err
245}
246
247func (r *Reader) ReadInt() (int64, error) {
248	b, err := r.readTmpBytesReply()
249	if err != nil {
250		return 0, err
251	}
252	return util.ParseInt(b, 10, 64)
253}
254
255func (r *Reader) ReadUint() (uint64, error) {
256	b, err := r.readTmpBytesReply()
257	if err != nil {
258		return 0, err
259	}
260	return util.ParseUint(b, 10, 64)
261}
262
263func (r *Reader) ReadFloatReply() (float64, error) {
264	b, err := r.readTmpBytesReply()
265	if err != nil {
266		return 0, err
267	}
268	return util.ParseFloat(b, 64)
269}
270
271func (r *Reader) readTmpBytesReply() ([]byte, error) {
272	line, err := r.ReadLine()
273	if err != nil {
274		return nil, err
275	}
276	switch line[0] {
277	case ErrorReply:
278		return nil, ParseErrorReply(line)
279	case StringReply:
280		return r._readTmpBytesReply(line)
281	case StatusReply:
282		return line[1:], nil
283	default:
284		return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line)
285	}
286}
287
288func (r *Reader) _readTmpBytesReply(line []byte) ([]byte, error) {
289	if isNilReply(line) {
290		return nil, Nil
291	}
292
293	replyLen, err := util.Atoi(line[1:])
294	if err != nil {
295		return nil, err
296	}
297
298	buf := r.buf(replyLen + 2)
299	_, err = io.ReadFull(r.rd, buf)
300	if err != nil {
301		return nil, err
302	}
303
304	return buf[:replyLen], nil
305}
306
307func (r *Reader) buf(n int) []byte {
308	if n <= cap(r._buf) {
309		return r._buf[:n]
310	}
311	d := n - cap(r._buf)
312	r._buf = append(r._buf, make([]byte, d)...)
313	return r._buf
314}
315
316func isNilReply(b []byte) bool {
317	return len(b) == 3 &&
318		(b[0] == StringReply || b[0] == ArrayReply) &&
319		b[1] == '-' && b[2] == '1'
320}
321
322func ParseErrorReply(line []byte) error {
323	return RedisError(string(line[1:]))
324}
325
326func parseArrayLen(line []byte) (int64, error) {
327	if isNilReply(line) {
328		return 0, Nil
329	}
330	return util.ParseInt(line[1:], 10, 64)
331}
332