1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"encoding/base64"
11	"encoding/hex"
12	"errors"
13	"fmt"
14	"io"
15	"strings"
16
17	"go.mongodb.org/mongo-driver/bson/bsontype"
18)
19
// maxNestingDepth is the nesting limit that newExtJSONParser installs as the
// parser's maxDepth; advanceState moves the parser to an error state when the
// input is nested more deeply than this.
const maxNestingDepth = 200

// ErrInvalidJSON indicates the JSON input is invalid.
var ErrInvalidJSON = errors.New("invalid JSON input")
24
// jsonParseState is the current state of the extJSONParser state machine.
// Apart from jpsStartState, jpsDoneState and jpsInvalidState, each state is
// named after the kind of token the parser consumed most recently: jpsSawKey
// means a string in key position, jpsSawValue means a scalar value.
type jsonParseState byte

const (
	jpsStartState jsonParseState = iota
	jpsSawBeginObject
	jpsSawEndObject
	jpsSawBeginArray
	jpsSawEndArray
	jpsSawColon
	jpsSawComma
	jpsSawKey
	jpsSawValue
	// jpsDoneState is reached at end of input with no open containers.
	// It is terminal: advanceState does nothing once it is reached.
	jpsDoneState
	// jpsInvalidState is the terminal error state; the cause is stored in
	// the parser's err field.
	jpsInvalidState
)
40
// jsonParseMode identifies the kind of container (object or array) the parser
// is currently inside. The parser keeps a stack of modes, one per open
// container. jpmInvalidMode is what peekMode and popMode report when that
// stack is empty.
type jsonParseMode byte

const (
	jpmInvalidMode jsonParseMode = iota
	jpmObjectMode
	jpmArrayMode
)
48
// extJSONValue is a value produced by the parser. t is the BSON type the value
// was read as. v holds the Go value: for scalars it is the raw token value
// (see extendJSONToken), and for bsontype.EmbeddedDocument values built by
// readValue it is an *extJSONObject.
type extJSONValue struct {
	t bsontype.Type
	v interface{}
}
53
// extJSONObject is an ordered list of key/value pairs; keys[i] is the key for
// values[i].
type extJSONObject struct {
	keys   []string
	values []*extJSONValue
}
58
// extJSONParser is a pull parser for extended JSON. advanceState consumes one
// token at a time and updates the state s; the peek/read methods build on it
// to report keys, value types and values.
type extJSONParser struct {
	// js is the token source.
	js *jsonScanner
	// s is the current parse state.
	s  jsonParseState
	// m is the stack of open containers, innermost last.
	m  []jsonParseMode
	// k is the most recently scanned key.
	k  string
	// v is the most recently scanned scalar value.
	v  *extJSONValue

	// err records the cause when s becomes jpsInvalidState.
	err       error
	// canonical, when true, makes readValue reject relaxed-only forms such
	// as a $date whose value is not an object.
	canonical bool
	// depth is the current nesting depth, checked against maxDepth by
	// advanceState when a container is opened.
	depth     int
	maxDepth  int

	// emptyObject is set by peekType after it consumed an empty "{}", so
	// that the next readKey reports ErrEOD.
	emptyObject bool
	// relaxedUUID is set by peekType on a "$uuid" key so that readValue
	// decodes the following value as binary subtype 04.
	relaxedUUID bool
}
74
75// newExtJSONParser returns a new extended JSON parser, ready to to begin
76// parsing from the first character of the argued json input. It will not
77// perform any read-ahead and will therefore not report any errors about
78// malformed JSON at this point.
79func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser {
80	return &extJSONParser{
81		js:        &jsonScanner{r: r},
82		s:         jpsStartState,
83		m:         []jsonParseMode{},
84		canonical: canonical,
85		maxDepth:  maxNestingDepth,
86	}
87}
88
89// peekType examines the next value and returns its BSON Type
90func (ejp *extJSONParser) peekType() (bsontype.Type, error) {
91	var t bsontype.Type
92	var err error
93	initialState := ejp.s
94
95	ejp.advanceState()
96	switch ejp.s {
97	case jpsSawValue:
98		t = ejp.v.t
99	case jpsSawBeginArray:
100		t = bsontype.Array
101	case jpsInvalidState:
102		err = ejp.err
103	case jpsSawComma:
104		// in array mode, seeing a comma means we need to progress again to actually observe a type
105		if ejp.peekMode() == jpmArrayMode {
106			return ejp.peekType()
107		}
108	case jpsSawEndArray:
109		// this would only be a valid state if we were in array mode, so return end-of-array error
110		err = ErrEOA
111	case jpsSawBeginObject:
112		// peek key to determine type
113		ejp.advanceState()
114		switch ejp.s {
115		case jpsSawEndObject: // empty embedded document
116			t = bsontype.EmbeddedDocument
117			ejp.emptyObject = true
118		case jpsInvalidState:
119			err = ejp.err
120		case jpsSawKey:
121			if initialState == jpsStartState {
122				return bsontype.EmbeddedDocument, nil
123			}
124			t = wrapperKeyBSONType(ejp.k)
125
126			// if $uuid is encountered, parse as binary subtype 4
127			if ejp.k == "$uuid" {
128				ejp.relaxedUUID = true
129				t = bsontype.Binary
130			}
131
132			switch t {
133			case bsontype.JavaScript:
134				// just saw $code, need to check for $scope at same level
135				_, err = ejp.readValue(bsontype.JavaScript)
136				if err != nil {
137					break
138				}
139
140				switch ejp.s {
141				case jpsSawEndObject: // type is TypeJavaScript
142				case jpsSawComma:
143					ejp.advanceState()
144
145					if ejp.s == jpsSawKey && ejp.k == "$scope" {
146						t = bsontype.CodeWithScope
147					} else {
148						err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
149					}
150				case jpsInvalidState:
151					err = ejp.err
152				default:
153					err = ErrInvalidJSON
154				}
155			case bsontype.CodeWithScope:
156				err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
157			}
158		}
159	}
160
161	return t, err
162}
163
164// readKey parses the next key and its type and returns them
165func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) {
166	if ejp.emptyObject {
167		ejp.emptyObject = false
168		return "", 0, ErrEOD
169	}
170
171	// advance to key (or return with error)
172	switch ejp.s {
173	case jpsStartState:
174		ejp.advanceState()
175		if ejp.s == jpsSawBeginObject {
176			ejp.advanceState()
177		}
178	case jpsSawBeginObject:
179		ejp.advanceState()
180	case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
181		ejp.advanceState()
182		switch ejp.s {
183		case jpsSawBeginObject, jpsSawComma:
184			ejp.advanceState()
185		case jpsSawEndObject:
186			return "", 0, ErrEOD
187		case jpsDoneState:
188			return "", 0, io.EOF
189		case jpsInvalidState:
190			return "", 0, ejp.err
191		default:
192			return "", 0, ErrInvalidJSON
193		}
194	case jpsSawKey: // do nothing (key was peeked before)
195	default:
196		return "", 0, invalidRequestError("key")
197	}
198
199	// read key
200	var key string
201
202	switch ejp.s {
203	case jpsSawKey:
204		key = ejp.k
205	case jpsSawEndObject:
206		return "", 0, ErrEOD
207	case jpsInvalidState:
208		return "", 0, ejp.err
209	default:
210		return "", 0, invalidRequestError("key")
211	}
212
213	// check for colon
214	ejp.advanceState()
215	if err := ensureColon(ejp.s, key); err != nil {
216		return "", 0, err
217	}
218
219	// peek at the value to determine type
220	t, err := ejp.peekType()
221	if err != nil {
222		return "", 0, err
223	}
224
225	return key, t, nil
226}
227
228// readValue returns the value corresponding to the Type returned by peekType
229func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
230	if ejp.s == jpsInvalidState {
231		return nil, ejp.err
232	}
233
234	var v *extJSONValue
235
236	switch t {
237	case bsontype.Null, bsontype.Boolean, bsontype.String:
238		if ejp.s != jpsSawValue {
239			return nil, invalidRequestError(t.String())
240		}
241		v = ejp.v
242	case bsontype.Int32, bsontype.Int64, bsontype.Double:
243		// relaxed version allows these to be literal number values
244		if ejp.s == jpsSawValue {
245			v = ejp.v
246			break
247		}
248		fallthrough
249	case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined:
250		switch ejp.s {
251		case jpsSawKey:
252			// read colon
253			ejp.advanceState()
254			if err := ensureColon(ejp.s, ejp.k); err != nil {
255				return nil, err
256			}
257
258			// read value
259			ejp.advanceState()
260			if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
261				return nil, invalidJSONErrorForType("value", t)
262			}
263
264			v = ejp.v
265
266			// read end object
267			ejp.advanceState()
268			if ejp.s != jpsSawEndObject {
269				return nil, invalidJSONErrorForType("} after value", t)
270			}
271		default:
272			return nil, invalidRequestError(t.String())
273		}
274	case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer:
275		if ejp.s != jpsSawKey {
276			return nil, invalidRequestError(t.String())
277		}
278		// read colon
279		ejp.advanceState()
280		if err := ensureColon(ejp.s, ejp.k); err != nil {
281			return nil, err
282		}
283
284		ejp.advanceState()
285		if t == bsontype.Binary && ejp.s == jpsSawValue {
286			// convert relaxed $uuid format
287			if ejp.relaxedUUID {
288				defer func() { ejp.relaxedUUID = false }()
289				uuid, err := ejp.v.parseSymbol()
290				if err != nil {
291					return nil, err
292				}
293
294				// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
295				// in the 8th, 13th, 18th, and 23rd characters.
296				//
297				// See https://tools.ietf.org/html/rfc4122#section-3
298				valid := len(uuid) == 36 &&
299					string(uuid[8]) == "-" &&
300					string(uuid[13]) == "-" &&
301					string(uuid[18]) == "-" &&
302					string(uuid[23]) == "-"
303				if !valid {
304					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
305				}
306
307				// remove hyphens
308				uuidNoHyphens := strings.Replace(uuid, "-", "", -1)
309				if len(uuidNoHyphens) != 32 {
310					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
311				}
312
313				// convert hex to bytes
314				bytes, err := hex.DecodeString(uuidNoHyphens)
315				if err != nil {
316					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err)
317				}
318
319				ejp.advanceState()
320				if ejp.s != jpsSawEndObject {
321					return nil, invalidJSONErrorForType("$uuid and value and then }", bsontype.Binary)
322				}
323
324				base64 := &extJSONValue{
325					t: bsontype.String,
326					v: base64.StdEncoding.EncodeToString(bytes),
327				}
328				subType := &extJSONValue{
329					t: bsontype.String,
330					v: "04",
331				}
332
333				v = &extJSONValue{
334					t: bsontype.EmbeddedDocument,
335					v: &extJSONObject{
336						keys:   []string{"base64", "subType"},
337						values: []*extJSONValue{base64, subType},
338					},
339				}
340
341				break
342			}
343
344			// convert legacy $binary format
345			base64 := ejp.v
346
347			ejp.advanceState()
348			if ejp.s != jpsSawComma {
349				return nil, invalidJSONErrorForType(",", bsontype.Binary)
350			}
351
352			ejp.advanceState()
353			key, t, err := ejp.readKey()
354			if err != nil {
355				return nil, err
356			}
357			if key != "$type" {
358				return nil, invalidJSONErrorForType("$type", bsontype.Binary)
359			}
360
361			subType, err := ejp.readValue(t)
362			if err != nil {
363				return nil, err
364			}
365
366			ejp.advanceState()
367			if ejp.s != jpsSawEndObject {
368				return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary)
369			}
370
371			v = &extJSONValue{
372				t: bsontype.EmbeddedDocument,
373				v: &extJSONObject{
374					keys:   []string{"base64", "subType"},
375					values: []*extJSONValue{base64, subType},
376				},
377			}
378			break
379		}
380
381		// read KV pairs
382		if ejp.s != jpsSawBeginObject {
383			return nil, invalidJSONErrorForType("{", t)
384		}
385
386		keys, vals, err := ejp.readObject(2, true)
387		if err != nil {
388			return nil, err
389		}
390
391		ejp.advanceState()
392		if ejp.s != jpsSawEndObject {
393			return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
394		}
395
396		v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
397
398	case bsontype.DateTime:
399		switch ejp.s {
400		case jpsSawValue:
401			v = ejp.v
402		case jpsSawKey:
403			// read colon
404			ejp.advanceState()
405			if err := ensureColon(ejp.s, ejp.k); err != nil {
406				return nil, err
407			}
408
409			ejp.advanceState()
410			switch ejp.s {
411			case jpsSawBeginObject:
412				keys, vals, err := ejp.readObject(1, true)
413				if err != nil {
414					return nil, err
415				}
416				v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
417			case jpsSawValue:
418				if ejp.canonical {
419					return nil, invalidJSONError("{")
420				}
421				v = ejp.v
422			default:
423				if ejp.canonical {
424					return nil, invalidJSONErrorForType("object", t)
425				}
426				return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as decribed in RFC-3339", t)
427			}
428
429			ejp.advanceState()
430			if ejp.s != jpsSawEndObject {
431				return nil, invalidJSONErrorForType("value and then }", t)
432			}
433		default:
434			return nil, invalidRequestError(t.String())
435		}
436	case bsontype.JavaScript:
437		switch ejp.s {
438		case jpsSawKey:
439			// read colon
440			ejp.advanceState()
441			if err := ensureColon(ejp.s, ejp.k); err != nil {
442				return nil, err
443			}
444
445			// read value
446			ejp.advanceState()
447			if ejp.s != jpsSawValue {
448				return nil, invalidJSONErrorForType("value", t)
449			}
450			v = ejp.v
451
452			// read end object or comma and just return
453			ejp.advanceState()
454		case jpsSawEndObject:
455			v = ejp.v
456		default:
457			return nil, invalidRequestError(t.String())
458		}
459	case bsontype.CodeWithScope:
460		if ejp.s == jpsSawKey && ejp.k == "$scope" {
461			v = ejp.v // this is the $code string from earlier
462
463			// read colon
464			ejp.advanceState()
465			if err := ensureColon(ejp.s, ejp.k); err != nil {
466				return nil, err
467			}
468
469			// read {
470			ejp.advanceState()
471			if ejp.s != jpsSawBeginObject {
472				return nil, invalidJSONError("$scope to be embedded document")
473			}
474		} else {
475			return nil, invalidRequestError(t.String())
476		}
477	case bsontype.EmbeddedDocument, bsontype.Array:
478		return nil, invalidRequestError(t.String())
479	}
480
481	return v, nil
482}
483
484// readObject is a utility method for reading full objects of known (or expected) size
485// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
486func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
487	keys := make([]string, numKeys)
488	vals := make([]*extJSONValue, numKeys)
489
490	if !started {
491		ejp.advanceState()
492		if ejp.s != jpsSawBeginObject {
493			return nil, nil, invalidJSONError("{")
494		}
495	}
496
497	for i := 0; i < numKeys; i++ {
498		key, t, err := ejp.readKey()
499		if err != nil {
500			return nil, nil, err
501		}
502
503		switch ejp.s {
504		case jpsSawKey:
505			v, err := ejp.readValue(t)
506			if err != nil {
507				return nil, nil, err
508			}
509
510			keys[i] = key
511			vals[i] = v
512		case jpsSawValue:
513			keys[i] = key
514			vals[i] = ejp.v
515		default:
516			return nil, nil, invalidJSONError("value")
517		}
518	}
519
520	ejp.advanceState()
521	if ejp.s != jpsSawEndObject {
522		return nil, nil, invalidJSONError("}")
523	}
524
525	return keys, vals, nil
526}
527
528// advanceState reads the next JSON token from the scanner and transitions
529// from the current state based on that token's type
530func (ejp *extJSONParser) advanceState() {
531	if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
532		return
533	}
534
535	jt, err := ejp.js.nextToken()
536
537	if err != nil {
538		ejp.err = err
539		ejp.s = jpsInvalidState
540		return
541	}
542
543	valid := ejp.validateToken(jt.t)
544	if !valid {
545		ejp.err = unexpectedTokenError(jt)
546		ejp.s = jpsInvalidState
547		return
548	}
549
550	switch jt.t {
551	case jttBeginObject:
552		ejp.s = jpsSawBeginObject
553		ejp.pushMode(jpmObjectMode)
554		ejp.depth++
555
556		if ejp.depth > ejp.maxDepth {
557			ejp.err = nestingDepthError(jt.p, ejp.depth)
558			ejp.s = jpsInvalidState
559		}
560	case jttEndObject:
561		ejp.s = jpsSawEndObject
562		ejp.depth--
563
564		if ejp.popMode() != jpmObjectMode {
565			ejp.err = unexpectedTokenError(jt)
566			ejp.s = jpsInvalidState
567		}
568	case jttBeginArray:
569		ejp.s = jpsSawBeginArray
570		ejp.pushMode(jpmArrayMode)
571	case jttEndArray:
572		ejp.s = jpsSawEndArray
573
574		if ejp.popMode() != jpmArrayMode {
575			ejp.err = unexpectedTokenError(jt)
576			ejp.s = jpsInvalidState
577		}
578	case jttColon:
579		ejp.s = jpsSawColon
580	case jttComma:
581		ejp.s = jpsSawComma
582	case jttEOF:
583		ejp.s = jpsDoneState
584		if len(ejp.m) != 0 {
585			ejp.err = unexpectedTokenError(jt)
586			ejp.s = jpsInvalidState
587		}
588	case jttString:
589		switch ejp.s {
590		case jpsSawComma:
591			if ejp.peekMode() == jpmArrayMode {
592				ejp.s = jpsSawValue
593				ejp.v = extendJSONToken(jt)
594				return
595			}
596			fallthrough
597		case jpsSawBeginObject:
598			ejp.s = jpsSawKey
599			ejp.k = jt.v.(string)
600			return
601		}
602		fallthrough
603	default:
604		ejp.s = jpsSawValue
605		ejp.v = extendJSONToken(jt)
606	}
607}
608
// jpsValidTransitionTokens lists, for each parse state, the token types that
// may legally come next. validateToken handles a few context-dependent cases
// itself (a '{' following '}' at depth zero, and what may follow ',' inside an
// object or outside any container) and falls back to this table otherwise.
// The done and invalid states accept no further tokens.
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
	// jttEOF is allowed here so that empty input parses to jpsDoneState.
	jpsStartState: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
		jttEOF:         true,
	},
	// After '{' only a key or an immediate '}' may follow.
	jpsSawBeginObject: {
		jttEndObject: true,
		jttString:    true,
	},
	jpsSawEndObject: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	jpsSawBeginArray: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttEndArray:    true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	jpsSawEndArray: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	jpsSawColon: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	// This entry applies to commas inside arrays; validateToken restricts
	// commas inside objects to a following key.
	jpsSawComma: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	jpsSawKey: {
		jttColon: true,
	},
	jpsSawValue: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	jpsDoneState:    {},
	jpsInvalidState: {},
}
680
681func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
682	switch ejp.s {
683	case jpsSawEndObject:
684		// if we are at depth zero and the next token is a '{',
685		// we can consider it valid only if we are not in array mode.
686		if jtt == jttBeginObject && ejp.depth == 0 {
687			return ejp.peekMode() != jpmArrayMode
688		}
689	case jpsSawComma:
690		switch ejp.peekMode() {
691		// the only valid next token after a comma inside a document is a string (a key)
692		case jpmObjectMode:
693			return jtt == jttString
694		case jpmInvalidMode:
695			return false
696		}
697	}
698
699	_, ok := jpsValidTransitionTokens[ejp.s][jtt]
700	return ok
701}
702
703// ensureExtValueType returns true if the current value has the expected
704// value type for single-key extended JSON types. For example,
705// {"$numberInt": v} v must be TypeString
706func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool {
707	switch t {
708	case bsontype.MinKey, bsontype.MaxKey:
709		return ejp.v.t == bsontype.Int32
710	case bsontype.Undefined:
711		return ejp.v.t == bsontype.Boolean
712	case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID:
713		return ejp.v.t == bsontype.String
714	default:
715		return false
716	}
717}
718
719func (ejp *extJSONParser) pushMode(m jsonParseMode) {
720	ejp.m = append(ejp.m, m)
721}
722
723func (ejp *extJSONParser) popMode() jsonParseMode {
724	l := len(ejp.m)
725	if l == 0 {
726		return jpmInvalidMode
727	}
728
729	m := ejp.m[l-1]
730	ejp.m = ejp.m[:l-1]
731
732	return m
733}
734
735func (ejp *extJSONParser) peekMode() jsonParseMode {
736	l := len(ejp.m)
737	if l == 0 {
738		return jpmInvalidMode
739	}
740
741	return ejp.m[l-1]
742}
743
744func extendJSONToken(jt *jsonToken) *extJSONValue {
745	var t bsontype.Type
746
747	switch jt.t {
748	case jttInt32:
749		t = bsontype.Int32
750	case jttInt64:
751		t = bsontype.Int64
752	case jttDouble:
753		t = bsontype.Double
754	case jttString:
755		t = bsontype.String
756	case jttBool:
757		t = bsontype.Boolean
758	case jttNull:
759		t = bsontype.Null
760	default:
761		return nil
762	}
763
764	return &extJSONValue{t: t, v: jt.v}
765}
766
767func ensureColon(s jsonParseState, key string) error {
768	if s != jpsSawColon {
769		return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
770	}
771
772	return nil
773}
774
// invalidRequestError reports that the caller asked to read s while the parser
// was not positioned on something of that kind.
func invalidRequestError(s string) error {
	return errors.New("invalid request to read " + s)
}
778
// invalidJSONError reports malformed input, naming what the parser expected
// to find instead.
func invalidJSONError(expected string) error {
	msg := "invalid JSON input; expected " + expected
	return errors.New(msg)
}
782
783func invalidJSONErrorForType(expected string, t bsontype.Type) error {
784	return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
785}
786
787func unexpectedTokenError(jt *jsonToken) error {
788	switch jt.t {
789	case jttInt32, jttInt64, jttDouble:
790		return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
791	case jttString:
792		return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
793	case jttBool:
794		return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
795	case jttNull:
796		return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
797	case jttEOF:
798		return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
799	default:
800		return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
801	}
802}
803
// nestingDepthError reports that the input became nested depth levels deep,
// beyond the allowed limit, at input position p.
func nestingDepthError(p, depth int) error {
	return fmt.Errorf("invalid JSON input; nesting too deep (%[2]d levels) at position %[1]d", p, depth)
}
807