1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"bytes"
11	"errors"
12	"fmt"
13	"io"
14	"math"
15	"strconv"
16	"strings"
17	"unicode"
18)
19
// jsonTokenType identifies the kind of a token produced by the jsonScanner.
type jsonTokenType byte

// Token kinds. Values are assigned sequentially by iota, so the declaration
// order below fixes each constant's numeric value.
const (
	// jttBeginObject is the '{' character.
	jttBeginObject jsonTokenType = iota
	// jttEndObject is the '}' character.
	jttEndObject
	// jttBeginArray is the '[' character.
	jttBeginArray
	// jttEndArray is the ']' character.
	jttEndArray
	// jttColon is the ':' character.
	jttColon
	// jttComma is the ',' character.
	jttComma
	// jttInt32 is an integer that fits in an int32.
	jttInt32
	// jttInt64 is an integer that fits in an int64 but not in an int32.
	jttInt64
	// jttDouble is a number with a fraction or exponent, or an integer that
	// does not fit in an int64.
	jttDouble
	// jttString is a double-quoted string.
	jttString
	// jttBool is one of the literals true and false.
	jttBool
	// jttNull is the literal null.
	jttNull
	// jttEOF marks the end of the input.
	jttEOF
)
37
// jsonToken is a single token produced by jsonScanner.nextToken.
type jsonToken struct {
	// t is the kind of token.
	t jsonTokenType
	// v is the token's value: the character itself (as a byte) for structural
	// tokens, an int32, int64 or float64 for numbers, a string for strings,
	// a bool for true/false, and nil for null and EOF.
	v interface{}
	// p is the scanner's buffer index at which the token started. It is left
	// at zero for jttEOF tokens.
	p int
}
43
// jsonScanner splits JSON text read from an io.Reader into jsonTokens.
type jsonScanner struct {
	// r is the source of the JSON text.
	r io.Reader
	// buf holds the bytes delivered by the most recent read of r.
	buf []byte
	// pos is the index in buf of the next byte to hand out.
	pos int
	// lastReadErr is the error returned by r.Read; once set, no further reads
	// are attempted and it is reported when buf has been exhausted.
	lastReadErr error
}
50
51// nextToken returns the next JSON token if one exists. A token is a character
52// of the JSON grammar, a number, a string, or a literal.
53func (js *jsonScanner) nextToken() (*jsonToken, error) {
54	c, err := js.readNextByte()
55
56	// keep reading until a non-space is encountered (break on read error or EOF)
57	for isWhiteSpace(c) && err == nil {
58		c, err = js.readNextByte()
59	}
60
61	if err == io.EOF {
62		return &jsonToken{t: jttEOF}, nil
63	} else if err != nil {
64		return nil, err
65	}
66
67	// switch on the character
68	switch c {
69	case '{':
70		return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil
71	case '}':
72		return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil
73	case '[':
74		return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil
75	case ']':
76		return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil
77	case ':':
78		return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil
79	case ',':
80		return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil
81	case '"': // RFC-8259 only allows for double quotes (") not single (')
82		return js.scanString()
83	default:
84		// check if it's a number
85		if c == '-' || isDigit(c) {
86			return js.scanNumber(c)
87		} else if c == 't' || c == 'f' || c == 'n' {
88			// maybe a literal
89			return js.scanLiteral(c)
90		} else {
91			return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
92		}
93	}
94}
95
96// readNextByte attempts to read the next byte from the buffer. If the buffer
97// has been exhausted, this function calls readIntoBuf, thus refilling the
98// buffer and resetting the read position to 0
99func (js *jsonScanner) readNextByte() (byte, error) {
100	if js.pos >= len(js.buf) {
101		err := js.readIntoBuf()
102
103		if err != nil {
104			return 0, err
105		}
106	}
107
108	b := js.buf[js.pos]
109	js.pos++
110
111	return b, nil
112}
113
114// readNNextBytes reads n bytes into dst, starting at offset
115func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
116	var err error
117
118	for i := 0; i < n; i++ {
119		dst[i+offset], err = js.readNextByte()
120		if err != nil {
121			return err
122		}
123	}
124
125	return nil
126}
127
128// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer
129func (js *jsonScanner) readIntoBuf() error {
130	if js.lastReadErr != nil {
131		js.buf = js.buf[:0]
132		js.pos = 0
133		return js.lastReadErr
134	}
135
136	if cap(js.buf) == 0 {
137		js.buf = make([]byte, 0, 512)
138	}
139
140	n, err := js.r.Read(js.buf[:cap(js.buf)])
141	if err != nil {
142		js.lastReadErr = err
143		if n > 0 {
144			err = nil
145		}
146	}
147	js.buf = js.buf[:n]
148	js.pos = 0
149
150	return err
151}
152
// isWhiteSpace reports whether c is one of the four whitespace characters
// recognized between JSON tokens: space, tab, carriage return or line feed.
func isWhiteSpace(c byte) bool {
	switch c {
	case ' ', '\t', '\r', '\n':
		return true
	}

	return false
}
156
// isDigit reports whether c is a decimal digit. The byte is widened to a rune
// and classified with unicode.IsDigit.
func isDigit(c byte) bool {
	r := rune(c)

	return unicode.IsDigit(r)
}
160
161func isValueTerminator(c byte) bool {
162	return c == ',' || c == '}' || c == ']' || isWhiteSpace(c)
163}
164
165// scanString reads from an opening '"' to a closing '"' and handles escaped characters
166func (js *jsonScanner) scanString() (*jsonToken, error) {
167	var b bytes.Buffer
168	var c byte
169	var err error
170
171	p := js.pos - 1
172
173	for {
174		c, err = js.readNextByte()
175		if err != nil {
176			if err == io.EOF {
177				return nil, errors.New("end of input in JSON string")
178			}
179			return nil, err
180		}
181
182		switch c {
183		case '\\':
184			c, err = js.readNextByte()
185			switch c {
186			case '"', '\\', '/':
187				b.WriteByte(c)
188			case 'b':
189				b.WriteByte('\b')
190			case 'f':
191				b.WriteByte('\f')
192			case 'n':
193				b.WriteByte('\n')
194			case 'r':
195				b.WriteByte('\r')
196			case 't':
197				b.WriteByte('\t')
198			case 'u':
199				us := make([]byte, 4)
200				err = js.readNNextBytes(us, 4, 0)
201				if err != nil {
202					return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
203				}
204
205				s := fmt.Sprintf(`\u%s`, us)
206				s, err = strconv.Unquote(strings.Replace(strconv.Quote(s), `\\u`, `\u`, 1))
207				if err != nil {
208					return nil, err
209				}
210
211				b.WriteString(s)
212			default:
213				return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
214			}
215		case '"':
216			return &jsonToken{t: jttString, v: b.String(), p: p}, nil
217		default:
218			b.WriteByte(c)
219		}
220	}
221}
222
223// scanLiteral reads an unquoted sequence of characters and determines if it is one of
224// three valid JSON literals (true, false, null); if so, it returns the appropriate
225// jsonToken; otherwise, it returns an error
226func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
227	p := js.pos - 1
228
229	lit := make([]byte, 4)
230	lit[0] = first
231
232	err := js.readNNextBytes(lit, 3, 1)
233	if err != nil {
234		return nil, err
235	}
236
237	c5, err := js.readNextByte()
238
239	if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) {
240		js.pos = int(math.Max(0, float64(js.pos-1)))
241		return &jsonToken{t: jttBool, v: true, p: p}, nil
242	} else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) {
243		js.pos = int(math.Max(0, float64(js.pos-1)))
244		return &jsonToken{t: jttNull, v: nil, p: p}, nil
245	} else if bytes.Equal([]byte("fals"), lit) {
246		if c5 == 'e' {
247			c5, err = js.readNextByte()
248
249			if isValueTerminator(c5) || err == io.EOF {
250				js.pos = int(math.Max(0, float64(js.pos-1)))
251				return &jsonToken{t: jttBool, v: false, p: p}, nil
252			}
253		}
254	}
255
256	return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
257}
258
// numberScanState is a state of the state machine that scanNumber uses to
// validate a JSON number one byte at a time.
type numberScanState byte

const (
	// nssSawLeadingMinus: the number started with '-'; a digit must follow.
	nssSawLeadingMinus numberScanState = iota
	// nssSawLeadingZero: the integer part is a single '0'; no further integer
	// digits are allowed.
	nssSawLeadingZero
	// nssSawIntegerDigits: the integer part started with a non-zero digit and
	// may continue with more digits.
	nssSawIntegerDigits
	// nssSawDecimalPoint: a '.' was consumed; a fraction digit must follow.
	nssSawDecimalPoint
	// nssSawFractionDigits: at least one fraction digit was consumed.
	nssSawFractionDigits
	// nssSawExponentLetter: an 'e' or 'E' was consumed; a sign or digit must
	// follow.
	nssSawExponentLetter
	// nssSawExponentSign: the exponent's '+' or '-' was consumed; a digit
	// must follow.
	nssSawExponentSign
	// nssSawExponentDigits: at least one exponent digit was consumed.
	nssSawExponentDigits
	// nssDone: the number ended at a terminator or at the end of the input.
	nssDone
	// nssInvalid: a byte was read that is not allowed in the current state.
	nssInvalid
)
273
274// scanNumber reads a JSON number (according to RFC-8259)
275func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
276	var b bytes.Buffer
277	var s numberScanState
278	var c byte
279	var err error
280
281	t := jttInt64 // assume it's an int64 until the type can be determined
282	start := js.pos - 1
283
284	b.WriteByte(first)
285
286	switch first {
287	case '-':
288		s = nssSawLeadingMinus
289	case '0':
290		s = nssSawLeadingZero
291	default:
292		s = nssSawIntegerDigits
293	}
294
295	for {
296		c, err = js.readNextByte()
297
298		if err != nil && err != io.EOF {
299			return nil, err
300		}
301
302		switch s {
303		case nssSawLeadingMinus:
304			switch c {
305			case '0':
306				s = nssSawLeadingZero
307				b.WriteByte(c)
308			default:
309				if isDigit(c) {
310					s = nssSawIntegerDigits
311					b.WriteByte(c)
312				} else {
313					s = nssInvalid
314				}
315			}
316		case nssSawLeadingZero:
317			switch c {
318			case '.':
319				s = nssSawDecimalPoint
320				b.WriteByte(c)
321			case 'e', 'E':
322				s = nssSawExponentLetter
323				b.WriteByte(c)
324			case '}', ']', ',':
325				s = nssDone
326			default:
327				if isWhiteSpace(c) || err == io.EOF {
328					s = nssDone
329				} else {
330					s = nssInvalid
331				}
332			}
333		case nssSawIntegerDigits:
334			switch c {
335			case '.':
336				s = nssSawDecimalPoint
337				b.WriteByte(c)
338			case 'e', 'E':
339				s = nssSawExponentLetter
340				b.WriteByte(c)
341			case '}', ']', ',':
342				s = nssDone
343			default:
344				if isWhiteSpace(c) || err == io.EOF {
345					s = nssDone
346				} else if isDigit(c) {
347					s = nssSawIntegerDigits
348					b.WriteByte(c)
349				} else {
350					s = nssInvalid
351				}
352			}
353		case nssSawDecimalPoint:
354			t = jttDouble
355			if isDigit(c) {
356				s = nssSawFractionDigits
357				b.WriteByte(c)
358			} else {
359				s = nssInvalid
360			}
361		case nssSawFractionDigits:
362			switch c {
363			case 'e', 'E':
364				s = nssSawExponentLetter
365				b.WriteByte(c)
366			case '}', ']', ',':
367				s = nssDone
368			default:
369				if isWhiteSpace(c) || err == io.EOF {
370					s = nssDone
371				} else if isDigit(c) {
372					s = nssSawFractionDigits
373					b.WriteByte(c)
374				} else {
375					s = nssInvalid
376				}
377			}
378		case nssSawExponentLetter:
379			t = jttDouble
380			switch c {
381			case '+', '-':
382				s = nssSawExponentSign
383				b.WriteByte(c)
384			default:
385				if isDigit(c) {
386					s = nssSawExponentDigits
387					b.WriteByte(c)
388				} else {
389					s = nssInvalid
390				}
391			}
392		case nssSawExponentSign:
393			if isDigit(c) {
394				s = nssSawExponentDigits
395				b.WriteByte(c)
396			} else {
397				s = nssInvalid
398			}
399		case nssSawExponentDigits:
400			switch c {
401			case '}', ']', ',':
402				s = nssDone
403			default:
404				if isWhiteSpace(c) || err == io.EOF {
405					s = nssDone
406				} else if isDigit(c) {
407					s = nssSawExponentDigits
408					b.WriteByte(c)
409				} else {
410					s = nssInvalid
411				}
412			}
413		}
414
415		switch s {
416		case nssInvalid:
417			return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
418		case nssDone:
419			js.pos = int(math.Max(0, float64(js.pos-1)))
420			if t != jttDouble {
421				v, err := strconv.ParseInt(b.String(), 10, 64)
422				if err == nil {
423					if v < math.MinInt32 || v > math.MaxInt32 {
424						return &jsonToken{t: jttInt64, v: v, p: start}, nil
425					}
426
427					return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
428				}
429			}
430
431			v, err := strconv.ParseFloat(b.String(), 64)
432			if err != nil {
433				return nil, err
434			}
435
436			return &jsonToken{t: jttDouble, v: v, p: start}, nil
437		}
438	}
439}
440