// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"bytes"
11	"errors"
12	"fmt"
13	"io"
14	"math"
15	"strconv"
16	"unicode"
17	"unicode/utf16"
18)
19
// jsonTokenType identifies the kind of token produced by the JSON scanner.
type jsonTokenType byte

// The token kinds the scanner can emit. The numeric values are assigned
// sequentially starting at 0.
const (
	// jttBeginObject is the structural character '{'.
	jttBeginObject jsonTokenType = iota
	// jttEndObject is the structural character '}'.
	jttEndObject
	// jttBeginArray is the structural character '['.
	jttBeginArray
	// jttEndArray is the structural character ']'.
	jttEndArray
	// jttColon is the structural character ':'.
	jttColon
	// jttComma is the structural character ','.
	jttComma
	// jttInt32 is an integer number that fits in an int32.
	jttInt32
	// jttInt64 is an integer number outside the int32 range.
	jttInt64
	// jttDouble is a number carried as a float64.
	jttDouble
	// jttString is a double-quoted string.
	jttString
	// jttBool is one of the literals true or false.
	jttBool
	// jttNull is the literal null.
	jttNull
	// jttEOF marks the end of the input.
	jttEOF
)
37
38type jsonToken struct {
39	t jsonTokenType
40	v interface{}
41	p int
42}
43
// jsonScanner splits JSON text read from an io.Reader into tokens.
type jsonScanner struct {
	// r is the source of the JSON bytes.
	r io.Reader
	// buf holds the bytes most recently read from r.
	buf []byte
	// pos is the index in buf of the next byte to hand out.
	pos int
	// lastReadErr is the error returned by the most recent failed Read on r.
	// Once set, no further reads are attempted and it is reported as soon as
	// buf has been consumed.
	lastReadErr error
}
50
51// nextToken returns the next JSON token if one exists. A token is a character
52// of the JSON grammar, a number, a string, or a literal.
53func (js *jsonScanner) nextToken() (*jsonToken, error) {
54	c, err := js.readNextByte()
55
56	// keep reading until a non-space is encountered (break on read error or EOF)
57	for isWhiteSpace(c) && err == nil {
58		c, err = js.readNextByte()
59	}
60
61	if err == io.EOF {
62		return &jsonToken{t: jttEOF}, nil
63	} else if err != nil {
64		return nil, err
65	}
66
67	// switch on the character
68	switch c {
69	case '{':
70		return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil
71	case '}':
72		return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil
73	case '[':
74		return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil
75	case ']':
76		return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil
77	case ':':
78		return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil
79	case ',':
80		return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil
81	case '"': // RFC-8259 only allows for double quotes (") not single (')
82		return js.scanString()
83	default:
84		// check if it's a number
85		if c == '-' || isDigit(c) {
86			return js.scanNumber(c)
87		} else if c == 't' || c == 'f' || c == 'n' {
88			// maybe a literal
89			return js.scanLiteral(c)
90		} else {
91			return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
92		}
93	}
94}
95
96// readNextByte attempts to read the next byte from the buffer. If the buffer
97// has been exhausted, this function calls readIntoBuf, thus refilling the
98// buffer and resetting the read position to 0
99func (js *jsonScanner) readNextByte() (byte, error) {
100	if js.pos >= len(js.buf) {
101		err := js.readIntoBuf()
102
103		if err != nil {
104			return 0, err
105		}
106	}
107
108	b := js.buf[js.pos]
109	js.pos++
110
111	return b, nil
112}
113
114// readNNextBytes reads n bytes into dst, starting at offset
115func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
116	var err error
117
118	for i := 0; i < n; i++ {
119		dst[i+offset], err = js.readNextByte()
120		if err != nil {
121			return err
122		}
123	}
124
125	return nil
126}
127
128// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer
129func (js *jsonScanner) readIntoBuf() error {
130	if js.lastReadErr != nil {
131		js.buf = js.buf[:0]
132		js.pos = 0
133		return js.lastReadErr
134	}
135
136	if cap(js.buf) == 0 {
137		js.buf = make([]byte, 0, 512)
138	}
139
140	n, err := js.r.Read(js.buf[:cap(js.buf)])
141	if err != nil {
142		js.lastReadErr = err
143		if n > 0 {
144			err = nil
145		}
146	}
147	js.buf = js.buf[:n]
148	js.pos = 0
149
150	return err
151}
152
// isWhiteSpace reports whether c is a space, horizontal tab, carriage return
// or line feed.
func isWhiteSpace(c byte) bool {
	switch c {
	case ' ', '\t', '\r', '\n':
		return true
	}
	return false
}
156
// isDigit reports whether c is an ASCII decimal digit ('0' through '9').
func isDigit(c byte) bool {
	return '0' <= c && c <= '9'
}
160
161func isValueTerminator(c byte) bool {
162	return c == ',' || c == '}' || c == ']' || isWhiteSpace(c)
163}
164
// getu4 decodes the 4-byte hex sequence from the beginning of s, returning the hex value as a rune,
// or it returns -1 if s is shorter than 4 bytes or any of the first 4 bytes is not a hex digit.
// Both upper- and lower-case hex digits are accepted; bytes after the first 4 are ignored.
// Note that the "\u" from the unicode escape sequence should not be present.
// It is copied and lightly modified from the Go JSON decode function at
// https://github.com/golang/go/blob/1b0a0316802b8048d69da49dc23c5a5ab08e8ae8/src/encoding/json/decode.go#L1169-L1188
func getu4(s []byte) rune {
	if len(s) < 4 {
		return -1
	}

	var r rune
	for _, c := range s[:4] {
		// Map the hex digit to its numeric value (0-15).
		switch {
		case '0' <= c && c <= '9':
			c = c - '0'
		case 'a' <= c && c <= 'f':
			c = c - 'a' + 10
		case 'A' <= c && c <= 'F':
			c = c - 'A' + 10
		default:
			return -1
		}
		r = r*16 + rune(c)
	}
	return r
}
189
190// scanString reads from an opening '"' to a closing '"' and handles escaped characters
191func (js *jsonScanner) scanString() (*jsonToken, error) {
192	var b bytes.Buffer
193	var c byte
194	var err error
195
196	p := js.pos - 1
197
198	for {
199		c, err = js.readNextByte()
200		if err != nil {
201			if err == io.EOF {
202				return nil, errors.New("end of input in JSON string")
203			}
204			return nil, err
205		}
206
207	evalNextChar:
208		switch c {
209		case '\\':
210			c, err = js.readNextByte()
211			if err != nil {
212				if err == io.EOF {
213					return nil, errors.New("end of input in JSON string")
214				}
215				return nil, err
216			}
217
218		evalNextEscapeChar:
219			switch c {
220			case '"', '\\', '/':
221				b.WriteByte(c)
222			case 'b':
223				b.WriteByte('\b')
224			case 'f':
225				b.WriteByte('\f')
226			case 'n':
227				b.WriteByte('\n')
228			case 'r':
229				b.WriteByte('\r')
230			case 't':
231				b.WriteByte('\t')
232			case 'u':
233				us := make([]byte, 4)
234				err = js.readNNextBytes(us, 4, 0)
235				if err != nil {
236					return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
237				}
238
239				rn := getu4(us)
240
241				// If the rune we just decoded is the high or low value of a possible surrogate pair,
242				// try to decode the next sequence as the low value of a surrogate pair. We're
243				// expecting the next sequence to be another Unicode escape sequence (e.g. "\uDD1E"),
244				// but need to handle cases where the input is not a valid surrogate pair.
245				// For more context on unicode surrogate pairs, see:
246				// https://www.christianfscott.com/rust-chars-vs-go-runes/
247				// https://www.unicode.org/glossary/#high_surrogate_code_point
248				if utf16.IsSurrogate(rn) {
249					c, err = js.readNextByte()
250					if err != nil {
251						if err == io.EOF {
252							return nil, errors.New("end of input in JSON string")
253						}
254						return nil, err
255					}
256
257					// If the next value isn't the beginning of a backslash escape sequence, write
258					// the Unicode replacement character for the surrogate value and goto the
259					// beginning of the next char eval block.
260					if c != '\\' {
261						b.WriteRune(unicode.ReplacementChar)
262						goto evalNextChar
263					}
264
265					c, err = js.readNextByte()
266					if err != nil {
267						if err == io.EOF {
268							return nil, errors.New("end of input in JSON string")
269						}
270						return nil, err
271					}
272
273					// If the next value isn't the beginning of a unicode escape sequence, write the
274					// Unicode replacement character for the surrogate value and goto the beginning
275					// of the next escape char eval block.
276					if c != 'u' {
277						b.WriteRune(unicode.ReplacementChar)
278						goto evalNextEscapeChar
279					}
280
281					err = js.readNNextBytes(us, 4, 0)
282					if err != nil {
283						return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
284					}
285
286					rn2 := getu4(us)
287
288					// Try to decode the pair of runes as a utf16 surrogate pair. If that fails, write
289					// the Unicode replacement character for the surrogate value and the 2nd decoded rune.
290					if rnPair := utf16.DecodeRune(rn, rn2); rnPair != unicode.ReplacementChar {
291						b.WriteRune(rnPair)
292					} else {
293						b.WriteRune(unicode.ReplacementChar)
294						b.WriteRune(rn2)
295					}
296
297					break
298				}
299
300				b.WriteRune(rn)
301			default:
302				return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
303			}
304		case '"':
305			return &jsonToken{t: jttString, v: b.String(), p: p}, nil
306		default:
307			b.WriteByte(c)
308		}
309	}
310}
311
312// scanLiteral reads an unquoted sequence of characters and determines if it is one of
313// three valid JSON literals (true, false, null); if so, it returns the appropriate
314// jsonToken; otherwise, it returns an error
315func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
316	p := js.pos - 1
317
318	lit := make([]byte, 4)
319	lit[0] = first
320
321	err := js.readNNextBytes(lit, 3, 1)
322	if err != nil {
323		return nil, err
324	}
325
326	c5, err := js.readNextByte()
327
328	if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) {
329		js.pos = int(math.Max(0, float64(js.pos-1)))
330		return &jsonToken{t: jttBool, v: true, p: p}, nil
331	} else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) {
332		js.pos = int(math.Max(0, float64(js.pos-1)))
333		return &jsonToken{t: jttNull, v: nil, p: p}, nil
334	} else if bytes.Equal([]byte("fals"), lit) {
335		if c5 == 'e' {
336			c5, err = js.readNextByte()
337
338			if isValueTerminator(c5) || err == io.EOF {
339				js.pos = int(math.Max(0, float64(js.pos-1)))
340				return &jsonToken{t: jttBool, v: false, p: p}, nil
341			}
342		}
343	}
344
345	return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
346}
347
// numberScanState is a state of the state machine scanNumber uses to validate
// a JSON number one character at a time.
type numberScanState byte

// Each "saw" state names the part of the number that was read last.
const (
	// nssSawLeadingMinus: a leading '-' was read.
	nssSawLeadingMinus numberScanState = iota
	// nssSawLeadingZero: the integer part is a single '0'.
	nssSawLeadingZero
	// nssSawIntegerDigits: one or more digits of a non-zero integer part were read.
	nssSawIntegerDigits
	// nssSawDecimalPoint: the '.' was read.
	nssSawDecimalPoint
	// nssSawFractionDigits: one or more digits after the '.' were read.
	nssSawFractionDigits
	// nssSawExponentLetter: the 'e' or 'E' was read.
	nssSawExponentLetter
	// nssSawExponentSign: the '+' or '-' of the exponent was read.
	nssSawExponentSign
	// nssSawExponentDigits: one or more exponent digits were read.
	nssSawExponentDigits
	// nssDone: the number ended at a terminator or the end of the input.
	nssDone
	// nssInvalid: a character that cannot continue the number was read.
	nssInvalid
)
362
363// scanNumber reads a JSON number (according to RFC-8259)
364func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
365	var b bytes.Buffer
366	var s numberScanState
367	var c byte
368	var err error
369
370	t := jttInt64 // assume it's an int64 until the type can be determined
371	start := js.pos - 1
372
373	b.WriteByte(first)
374
375	switch first {
376	case '-':
377		s = nssSawLeadingMinus
378	case '0':
379		s = nssSawLeadingZero
380	default:
381		s = nssSawIntegerDigits
382	}
383
384	for {
385		c, err = js.readNextByte()
386
387		if err != nil && err != io.EOF {
388			return nil, err
389		}
390
391		switch s {
392		case nssSawLeadingMinus:
393			switch c {
394			case '0':
395				s = nssSawLeadingZero
396				b.WriteByte(c)
397			default:
398				if isDigit(c) {
399					s = nssSawIntegerDigits
400					b.WriteByte(c)
401				} else {
402					s = nssInvalid
403				}
404			}
405		case nssSawLeadingZero:
406			switch c {
407			case '.':
408				s = nssSawDecimalPoint
409				b.WriteByte(c)
410			case 'e', 'E':
411				s = nssSawExponentLetter
412				b.WriteByte(c)
413			case '}', ']', ',':
414				s = nssDone
415			default:
416				if isWhiteSpace(c) || err == io.EOF {
417					s = nssDone
418				} else {
419					s = nssInvalid
420				}
421			}
422		case nssSawIntegerDigits:
423			switch c {
424			case '.':
425				s = nssSawDecimalPoint
426				b.WriteByte(c)
427			case 'e', 'E':
428				s = nssSawExponentLetter
429				b.WriteByte(c)
430			case '}', ']', ',':
431				s = nssDone
432			default:
433				if isWhiteSpace(c) || err == io.EOF {
434					s = nssDone
435				} else if isDigit(c) {
436					s = nssSawIntegerDigits
437					b.WriteByte(c)
438				} else {
439					s = nssInvalid
440				}
441			}
442		case nssSawDecimalPoint:
443			t = jttDouble
444			if isDigit(c) {
445				s = nssSawFractionDigits
446				b.WriteByte(c)
447			} else {
448				s = nssInvalid
449			}
450		case nssSawFractionDigits:
451			switch c {
452			case 'e', 'E':
453				s = nssSawExponentLetter
454				b.WriteByte(c)
455			case '}', ']', ',':
456				s = nssDone
457			default:
458				if isWhiteSpace(c) || err == io.EOF {
459					s = nssDone
460				} else if isDigit(c) {
461					s = nssSawFractionDigits
462					b.WriteByte(c)
463				} else {
464					s = nssInvalid
465				}
466			}
467		case nssSawExponentLetter:
468			t = jttDouble
469			switch c {
470			case '+', '-':
471				s = nssSawExponentSign
472				b.WriteByte(c)
473			default:
474				if isDigit(c) {
475					s = nssSawExponentDigits
476					b.WriteByte(c)
477				} else {
478					s = nssInvalid
479				}
480			}
481		case nssSawExponentSign:
482			if isDigit(c) {
483				s = nssSawExponentDigits
484				b.WriteByte(c)
485			} else {
486				s = nssInvalid
487			}
488		case nssSawExponentDigits:
489			switch c {
490			case '}', ']', ',':
491				s = nssDone
492			default:
493				if isWhiteSpace(c) || err == io.EOF {
494					s = nssDone
495				} else if isDigit(c) {
496					s = nssSawExponentDigits
497					b.WriteByte(c)
498				} else {
499					s = nssInvalid
500				}
501			}
502		}
503
504		switch s {
505		case nssInvalid:
506			return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
507		case nssDone:
508			js.pos = int(math.Max(0, float64(js.pos-1)))
509			if t != jttDouble {
510				v, err := strconv.ParseInt(b.String(), 10, 64)
511				if err == nil {
512					if v < math.MinInt32 || v > math.MaxInt32 {
513						return &jsonToken{t: jttInt64, v: v, p: start}, nil
514					}
515
516					return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
517				}
518			}
519
520			v, err := strconv.ParseFloat(b.String(), 64)
521			if err != nil {
522				return nil, err
523			}
524
525			return &jsonToken{t: jttDouble, v: v, p: start}, nil
526		}
527	}
528}
529