1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"errors"
11	"fmt"
12	"io"
13
14	"go.mongodb.org/mongo-driver/bson/bsontype"
15)
16
// maxNestingDepth is the default limit on how deeply objects may be nested
// before the parser reports a nesting-depth error.
const maxNestingDepth = 200

// ErrInvalidJSON is the error reported when the JSON input is not valid.
var ErrInvalidJSON = errors.New("invalid JSON input")
21
// jsonParseState identifies what the parser consumed most recently;
// advanceState moves between these states as tokens arrive. The numeric
// values come from iota, so the declaration order must not change.
type jsonParseState byte

const (
	// jpsStartState: nothing has been consumed yet.
	jpsStartState jsonParseState = iota
	// jpsSawBeginObject: consumed '{'.
	jpsSawBeginObject
	// jpsSawEndObject: consumed '}'.
	jpsSawEndObject
	// jpsSawBeginArray: consumed '['.
	jpsSawBeginArray
	// jpsSawEndArray: consumed ']'.
	jpsSawEndArray
	// jpsSawColon: consumed ':'.
	jpsSawColon
	// jpsSawComma: consumed ','.
	jpsSawComma
	// jpsSawKey: consumed a string in key position (after '{', or after ','
	// inside an object); the key is held in the parser's k field.
	jpsSawKey
	// jpsSawValue: consumed a scalar (number, string, bool or null); it is
	// held in the parser's v field.
	jpsSawValue
	// jpsDoneState: reached end of input with no containers left open.
	jpsDoneState
	// jpsInvalidState: parsing failed; the cause is in the parser's err field.
	jpsInvalidState
)
37
// jsonParseMode records which kind of container is open. extJSONParser keeps
// a stack of modes: advanceState pushes on '{' / '[' and pops on '}' / ']'.
type jsonParseMode byte

const (
	// jpmInvalidMode is the zero value; peekMode and popMode return it when
	// the mode stack is empty.
	jpmInvalidMode jsonParseMode = iota
	// jpmObjectMode: the innermost open container is an object.
	jpmObjectMode
	// jpmArrayMode: the innermost open container is an array.
	jpmArrayMode
)
45
// extJSONValue is a parsed value tagged with a BSON type. For scalar tokens
// (see extendJSONToken) v holds the scanner's token value. For multi-field
// wrappers built by readValue, t is bsontype.EmbeddedDocument and v is an
// *extJSONObject holding the wrapper's fields.
type extJSONValue struct {
	t bsontype.Type
	v interface{}
}

// extJSONObject stores an object's fields as parallel slices in the order
// they were read: values[i] is the value for keys[i].
type extJSONObject struct {
	keys   []string
	values []*extJSONValue
}
55
// extJSONParser is a pull parser for extended JSON. Callers drive it through
// peekType, readKey and readValue. Each call advances a small state machine
// (advanceState) over the tokens produced by a jsonScanner.
type extJSONParser struct {
	js *jsonScanner    // token source
	s  jsonParseState  // current state; changed only by advanceState
	m  []jsonParseMode // stack of open containers, innermost last
	k  string          // most recently read object key
	v  *extJSONValue   // most recently read scalar value

	err       error // cause of failure once s is jpsInvalidState
	canonical bool  // canonical mode; e.g. a bare $date value is rejected
	depth     int   // current object nesting depth (arrays are not counted)
	maxDepth  int   // limit on depth; exceeding it is an error

	emptyObject bool // set by peekType on "{}" so the next readKey returns ErrEOD
}
70
71// newExtJSONParser returns a new extended JSON parser, ready to to begin
72// parsing from the first character of the argued json input. It will not
73// perform any read-ahead and will therefore not report any errors about
74// malformed JSON at this point.
75func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser {
76	return &extJSONParser{
77		js:        &jsonScanner{r: r},
78		s:         jpsStartState,
79		m:         []jsonParseMode{},
80		canonical: canonical,
81		maxDepth:  maxNestingDepth,
82	}
83}
84
85// peekType examines the next value and returns its BSON Type
86func (ejp *extJSONParser) peekType() (bsontype.Type, error) {
87	var t bsontype.Type
88	var err error
89
90	ejp.advanceState()
91	switch ejp.s {
92	case jpsSawValue:
93		t = ejp.v.t
94	case jpsSawBeginArray:
95		t = bsontype.Array
96	case jpsInvalidState:
97		err = ejp.err
98	case jpsSawComma:
99		// in array mode, seeing a comma means we need to progress again to actually observe a type
100		if ejp.peekMode() == jpmArrayMode {
101			return ejp.peekType()
102		}
103	case jpsSawEndArray:
104		// this would only be a valid state if we were in array mode, so return end-of-array error
105		err = ErrEOA
106	case jpsSawBeginObject:
107		// peek key to determine type
108		ejp.advanceState()
109		switch ejp.s {
110		case jpsSawEndObject: // empty embedded document
111			t = bsontype.EmbeddedDocument
112			ejp.emptyObject = true
113		case jpsInvalidState:
114			err = ejp.err
115		case jpsSawKey:
116			t = wrapperKeyBSONType(ejp.k)
117
118			if t == bsontype.JavaScript {
119				// just saw $code, need to check for $scope at same level
120				_, err := ejp.readValue(bsontype.JavaScript)
121
122				if err != nil {
123					break
124				}
125
126				switch ejp.s {
127				case jpsSawEndObject: // type is TypeJavaScript
128				case jpsSawComma:
129					ejp.advanceState()
130					if ejp.s == jpsSawKey && ejp.k == "$scope" {
131						t = bsontype.CodeWithScope
132					} else {
133						err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
134					}
135				case jpsInvalidState:
136					err = ejp.err
137				default:
138					err = ErrInvalidJSON
139				}
140			}
141		}
142	}
143
144	return t, err
145}
146
// readKey advances to the next key of the current object, consumes the colon
// that follows it, and peeks at the key's value with peekType. The value
// itself is not consumed here:
//   - a scalar is left in ejp.v (state jpsSawValue);
//   - anything else is left for readValue.
//
// It returns:
//   - ErrEOD when the object ends, including immediately after peekType
//     reported an empty "{}";
//   - io.EOF when the input is exhausted;
//   - the parser's recorded error once it has failed;
//   - ErrInvalidJSON when a completed value is followed by an unexpected token;
//   - an "invalid request" error when called from a state in which no key can
//     follow.
func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) {
	// peekType saw "{}" and already consumed its closing '}'
	if ejp.emptyObject {
		ejp.emptyObject = false
		return "", 0, ErrEOD
	}

	// advance to key (or return with error)
	switch ejp.s {
	case jpsStartState:
		// top-level document: consume '{' and then the first key
		ejp.advanceState()
		if ejp.s == jpsSawBeginObject {
			ejp.advanceState()
		}
	case jpsSawBeginObject:
		ejp.advanceState()
	case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
		// the previous value is complete; look at what follows it
		ejp.advanceState()
		switch ejp.s {
		case jpsSawBeginObject, jpsSawComma:
			// ',' separates key/value pairs; '{' can only get here when
			// validateToken admits another document after a closing '}'
			// at depth zero
			ejp.advanceState()
		case jpsSawEndObject:
			return "", 0, ErrEOD
		case jpsDoneState:
			return "", 0, io.EOF
		case jpsInvalidState:
			return "", 0, ejp.err
		default:
			return "", 0, ErrInvalidJSON
		}
	case jpsSawKey: // do nothing (key was peeked before)
	default:
		return "", 0, invalidRequestError("key")
	}

	// read key
	var key string

	switch ejp.s {
	case jpsSawKey:
		key = ejp.k
	case jpsSawEndObject:
		// '}' where a key was expected: the object is finished
		return "", 0, ErrEOD
	case jpsInvalidState:
		return "", 0, ejp.err
	default:
		return "", 0, invalidRequestError("key")
	}

	// check for colon
	ejp.advanceState()
	if err := ensureColon(ejp.s, key); err != nil {
		return "", 0, err
	}

	// peek at the value to determine type
	t, err := ejp.peekType()
	if err != nil {
		return "", 0, err
	}

	return key, t, nil
}
210
// readValue consumes and returns the value whose type t was just reported by
// peekType (typically via readKey). The parser's current state decides how
// much input is still left to read:
//
//   - Null, Boolean, String (and, in relaxed form, Int32, Int64, Double and
//     DateTime) values already sitting in ejp.v (state jpsSawValue) are
//     returned as-is.
//   - Single-key wrappers such as {"$numberInt": "1"} are entered with the
//     wrapper key already read. The colon, the value (checked with
//     ensureExtValueType) and the closing '}' are consumed here.
//   - Binary, Regex, Timestamp and DBPointer wrappers, and DateTime in object
//     form, are returned as an EmbeddedDocument value holding an
//     *extJSONObject of their fields.
//   - The legacy {"$binary": <base64>, "$type": <subtype>} form is converted
//     to that same shape, with keys "base64" and "subType".
//   - For JavaScript, only the $code value is consumed. The following ',' or
//     '}' is read, but it is left for the caller to inspect via ejp.s.
//   - For CodeWithScope, the $code string is returned and the parser is left
//     just after the '{' that opens $scope. The scope's contents are not
//     consumed here.
//
// EmbeddedDocument and Array yield an "invalid request" error, as does any
// type that does not match the parser's state. Types not matched by any case
// return a nil value and a nil error.
func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
	// a previous step already failed; surface its error
	if ejp.s == jpsInvalidState {
		return nil, ejp.err
	}

	var v *extJSONValue

	switch t {
	case bsontype.Null, bsontype.Boolean, bsontype.String:
		// plain JSON scalars: peekType already consumed the token
		if ejp.s != jpsSawValue {
			return nil, invalidRequestError(t.String())
		}
		v = ejp.v
	case bsontype.Int32, bsontype.Int64, bsontype.Double:
		// relaxed version allows these to be literal number values
		if ejp.s == jpsSawValue {
			v = ejp.v
			break
		}
		fallthrough
	case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined:
		// single-key wrapper object: the key has been read, so consume
		// `: value }` here
		switch ejp.s {
		case jpsSawKey:
			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			// read value
			ejp.advanceState()
			if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
				return nil, invalidJSONErrorForType("value", t)
			}

			v = ejp.v

			// read end object
			ejp.advanceState()
			if ejp.s != jpsSawEndObject {
				return nil, invalidJSONErrorForType("} after value", t)
			}
		default:
			return nil, invalidRequestError(t.String())
		}
	case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer:
		// multi-field wrappers: the wrapper key's value is an object whose
		// fields are returned as an embedded document
		if ejp.s != jpsSawKey {
			return nil, invalidRequestError(t.String())
		}
		// read colon
		ejp.advanceState()
		if err := ensureColon(ejp.s, ejp.k); err != nil {
			return nil, err
		}

		ejp.advanceState()
		if t == bsontype.Binary && ejp.s == jpsSawValue {
			// convert legacy $binary format
			// ({"$binary": <base64>, "$type": <subtype>}) into the same
			// embedded-document shape as the object form
			base64 := ejp.v

			ejp.advanceState()
			if ejp.s != jpsSawComma {
				return nil, invalidJSONErrorForType(",", bsontype.Binary)
			}

			// after ',' inside an object the next string is a key, so
			// readKey picks it up without advancing further
			ejp.advanceState()
			// note: t shadows the outer type here; it is the $type value's type
			key, t, err := ejp.readKey()
			if err != nil {
				return nil, err
			}
			if key != "$type" {
				return nil, invalidJSONErrorForType("$type", bsontype.Binary)
			}

			subType, err := ejp.readValue(t)
			if err != nil {
				return nil, err
			}

			// the legacy form must end right after $type
			ejp.advanceState()
			if ejp.s != jpsSawEndObject {
				return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary)
			}

			v = &extJSONValue{
				t: bsontype.EmbeddedDocument,
				v: &extJSONObject{
					keys:   []string{"base64", "subType"},
					values: []*extJSONValue{base64, subType},
				},
			}
			break
		}

		// read KV pairs
		if ejp.s != jpsSawBeginObject {
			return nil, invalidJSONErrorForType("{", t)
		}

		keys, vals, err := ejp.readObject(2, true)
		if err != nil {
			return nil, err
		}

		// close the outer wrapper object
		ejp.advanceState()
		if ejp.s != jpsSawEndObject {
			return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
		}

		v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}

	case bsontype.DateTime:
		switch ejp.s {
		case jpsSawValue:
			v = ejp.v
		case jpsSawKey:
			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			ejp.advanceState()
			switch ejp.s {
			case jpsSawBeginObject:
				// object form: a single-field object, returned as an embedded document
				keys, vals, err := ejp.readObject(1, true)
				if err != nil {
					return nil, err
				}
				v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
			case jpsSawValue:
				// a bare value is only accepted in relaxed (non-canonical) mode
				if ejp.canonical {
					return nil, invalidJSONError("{")
				}
				v = ejp.v
			default:
				if ejp.canonical {
					return nil, invalidJSONErrorForType("object", t)
				}
				return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as decribed in RFC-3339", t)
			}

			// close the wrapper object
			ejp.advanceState()
			if ejp.s != jpsSawEndObject {
				return nil, invalidJSONErrorForType("value and then }", t)
			}
		default:
			return nil, invalidRequestError(t.String())
		}
	case bsontype.JavaScript:
		switch ejp.s {
		case jpsSawKey:
			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			// read value
			ejp.advanceState()
			if ejp.s != jpsSawValue {
				return nil, invalidJSONErrorForType("value", t)
			}
			v = ejp.v

			// read end object or comma and just return
			ejp.advanceState()
		case jpsSawEndObject:
			// peekType already consumed {"$code": ...}; '}' does not touch
			// ejp.v, so it still holds the code value
			v = ejp.v
		default:
			return nil, invalidRequestError(t.String())
		}
	case bsontype.CodeWithScope:
		if ejp.s == jpsSawKey && ejp.k == "$scope" {
			v = ejp.v // this is the $code string from earlier

			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			// read {
			ejp.advanceState()
			if ejp.s != jpsSawBeginObject {
				return nil, invalidJSONError("$scope to be embedded document")
			}
		} else {
			return nil, invalidRequestError(t.String())
		}
	case bsontype.EmbeddedDocument, bsontype.Array:
		// containers cannot be read as a single value here
		return nil, invalidRequestError(t.String())
	}

	// types not matched above leave v nil and report no error
	return v, nil
}
408
409// readObject is a utility method for reading full objects of known (or expected) size
410// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
411func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
412	keys := make([]string, numKeys)
413	vals := make([]*extJSONValue, numKeys)
414
415	if !started {
416		ejp.advanceState()
417		if ejp.s != jpsSawBeginObject {
418			return nil, nil, invalidJSONError("{")
419		}
420	}
421
422	for i := 0; i < numKeys; i++ {
423		key, t, err := ejp.readKey()
424		if err != nil {
425			return nil, nil, err
426		}
427
428		switch ejp.s {
429		case jpsSawKey:
430			v, err := ejp.readValue(t)
431			if err != nil {
432				return nil, nil, err
433			}
434
435			keys[i] = key
436			vals[i] = v
437		case jpsSawValue:
438			keys[i] = key
439			vals[i] = ejp.v
440		default:
441			return nil, nil, invalidJSONError("value")
442		}
443	}
444
445	ejp.advanceState()
446	if ejp.s != jpsSawEndObject {
447		return nil, nil, invalidJSONError("}")
448	}
449
450	return keys, vals, nil
451}
452
// advanceState pulls the next token from the scanner and moves the state
// machine accordingly. Once the parser is in jpsDoneState or jpsInvalidState
// it does nothing, so those states are sticky.
//
// The following failures store their cause in ejp.err and set
// jpsInvalidState:
//   - a scanner error;
//   - a token rejected by validateToken;
//   - a closing bracket that does not match the innermost open container;
//   - object nesting deeper than maxDepth;
//   - EOF while containers are still open.
//
// Opening and closing tokens push and pop the mode stack; only objects count
// toward depth. A string is treated as a key when it follows '{' or follows
// ',' inside an object; otherwise it is a value.
func (ejp *extJSONParser) advanceState() {
	// terminal states are sticky
	if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
		return
	}

	jt, err := ejp.js.nextToken()

	if err != nil {
		ejp.err = err
		ejp.s = jpsInvalidState
		return
	}

	// reject tokens that cannot follow the current state
	valid := ejp.validateToken(jt.t)
	if !valid {
		ejp.err = unexpectedTokenError(jt)
		ejp.s = jpsInvalidState
		return
	}

	switch jt.t {
	case jttBeginObject:
		ejp.s = jpsSawBeginObject
		ejp.pushMode(jpmObjectMode)
		ejp.depth++

		// only objects count toward the nesting limit
		if ejp.depth > ejp.maxDepth {
			ejp.err = nestingDepthError(jt.p, ejp.depth)
			ejp.s = jpsInvalidState
		}
	case jttEndObject:
		ejp.s = jpsSawEndObject
		ejp.depth--

		// '}' must close an object, not an array
		if ejp.popMode() != jpmObjectMode {
			ejp.err = unexpectedTokenError(jt)
			ejp.s = jpsInvalidState
		}
	case jttBeginArray:
		ejp.s = jpsSawBeginArray
		ejp.pushMode(jpmArrayMode)
	case jttEndArray:
		ejp.s = jpsSawEndArray

		// ']' must close an array, not an object
		if ejp.popMode() != jpmArrayMode {
			ejp.err = unexpectedTokenError(jt)
			ejp.s = jpsInvalidState
		}
	case jttColon:
		ejp.s = jpsSawColon
	case jttComma:
		ejp.s = jpsSawComma
	case jttEOF:
		ejp.s = jpsDoneState
		// running out of input with containers still open is an error
		if len(ejp.m) != 0 {
			ejp.err = unexpectedTokenError(jt)
			ejp.s = jpsInvalidState
		}
	case jttString:
		// a string is a key after '{' or after ',' inside an object;
		// in every other position it falls through to the value branch
		switch ejp.s {
		case jpsSawComma:
			if ejp.peekMode() == jpmArrayMode {
				ejp.s = jpsSawValue
				ejp.v = extendJSONToken(jt)
				return
			}
			fallthrough
		case jpsSawBeginObject:
			ejp.s = jpsSawKey
			ejp.k = jt.v.(string)
			return
		}
		fallthrough
	default:
		// remaining tokens admitted by jpsValidTransitionTokens are scalars
		// (numbers, strings, bools, null)
		ejp.s = jpsSawValue
		ejp.v = extendJSONToken(jt)
	}
}
533
// jpsValidTransitionTokens lists, for each parser state, the token types that
// may legally come next. validateToken applies its own context-dependent
// checks first and then consults this table. A state/token pair missing from
// the table is rejected by advanceState as an unexpected token.
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
	// At the very start: any JSON value, or EOF for empty input.
	jpsStartState: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
		jttEOF:         true,
	},
	// After '{': a key or an immediate '}'.
	jpsSawBeginObject: {
		jttEndObject: true,
		jttString:    true,
	},
	// After '}': another closer, a separator, or EOF. (validateToken also
	// allows '{' here at depth zero when no array is open.)
	jpsSawEndObject: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	// After '[': any value, or an immediate ']'.
	jpsSawBeginArray: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttEndArray:    true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	// After ']': another closer, a separator, or EOF.
	jpsSawEndArray: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	// After ':': the key's value.
	jpsSawColon: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	// After ',': any value. validateToken narrows this to a string (key)
	// inside an object and rejects it outside any container.
	jpsSawComma: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	// After a key: its colon.
	jpsSawKey: {
		jttColon: true,
	},
	// After a scalar: a closer, a separator, or EOF.
	jpsSawValue: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	// Terminal states accept nothing (advanceState returns before reading).
	jpsDoneState:    {},
	jpsInvalidState: {},
}
605
606func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
607	switch ejp.s {
608	case jpsSawEndObject:
609		// if we are at depth zero and the next token is a '{',
610		// we can consider it valid only if we are not in array mode.
611		if jtt == jttBeginObject && ejp.depth == 0 {
612			return ejp.peekMode() != jpmArrayMode
613		}
614	case jpsSawComma:
615		switch ejp.peekMode() {
616		// the only valid next token after a comma inside a document is a string (a key)
617		case jpmObjectMode:
618			return jtt == jttString
619		case jpmInvalidMode:
620			return false
621		}
622	}
623
624	_, ok := jpsValidTransitionTokens[ejp.s][jtt]
625	return ok
626}
627
628// ensureExtValueType returns true if the current value has the expected
629// value type for single-key extended JSON types. For example,
630// {"$numberInt": v} v must be TypeString
631func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool {
632	switch t {
633	case bsontype.MinKey, bsontype.MaxKey:
634		return ejp.v.t == bsontype.Int32
635	case bsontype.Undefined:
636		return ejp.v.t == bsontype.Boolean
637	case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID:
638		return ejp.v.t == bsontype.String
639	default:
640		return false
641	}
642}
643
644func (ejp *extJSONParser) pushMode(m jsonParseMode) {
645	ejp.m = append(ejp.m, m)
646}
647
648func (ejp *extJSONParser) popMode() jsonParseMode {
649	l := len(ejp.m)
650	if l == 0 {
651		return jpmInvalidMode
652	}
653
654	m := ejp.m[l-1]
655	ejp.m = ejp.m[:l-1]
656
657	return m
658}
659
660func (ejp *extJSONParser) peekMode() jsonParseMode {
661	l := len(ejp.m)
662	if l == 0 {
663		return jpmInvalidMode
664	}
665
666	return ejp.m[l-1]
667}
668
669func extendJSONToken(jt *jsonToken) *extJSONValue {
670	var t bsontype.Type
671
672	switch jt.t {
673	case jttInt32:
674		t = bsontype.Int32
675	case jttInt64:
676		t = bsontype.Int64
677	case jttDouble:
678		t = bsontype.Double
679	case jttString:
680		t = bsontype.String
681	case jttBool:
682		t = bsontype.Boolean
683	case jttNull:
684		t = bsontype.Null
685	default:
686		return nil
687	}
688
689	return &extJSONValue{t: t, v: jt.v}
690}
691
692func ensureColon(s jsonParseState, key string) error {
693	if s != jpsSawColon {
694		return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
695	}
696
697	return nil
698}
699
// invalidRequestError reports an attempt to read s when the parser's current
// state cannot supply it.
func invalidRequestError(s string) error {
	return errors.New("invalid request to read " + s)
}
703
// invalidJSONError reports malformed input, naming what was expected instead.
func invalidJSONError(expected string) error {
	return errors.New("invalid JSON input; expected " + expected)
}
707
708func invalidJSONErrorForType(expected string, t bsontype.Type) error {
709	return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
710}
711
712func unexpectedTokenError(jt *jsonToken) error {
713	switch jt.t {
714	case jttInt32, jttInt64, jttDouble:
715		return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
716	case jttString:
717		return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
718	case jttBool:
719		return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
720	case jttNull:
721		return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
722	case jttEOF:
723		return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
724	default:
725		return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
726	}
727}
728
// nestingDepthError reports that object nesting reached levels deep at input
// position pos, exceeding the parser's limit.
func nestingDepthError(pos, levels int) error {
	const format = "invalid JSON input; nesting too deep (%d levels) at position %d"
	return fmt.Errorf(format, levels, pos)
}
732