1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"bytes"
11	"encoding/base64"
12	"fmt"
13	"go.mongodb.org/mongo-driver/bson/primitive"
14	"io"
15	"math"
16	"sort"
17	"strconv"
18	"strings"
19	"sync"
20	"time"
21	"unicode/utf8"
22)
23
24var ejvwPool = sync.Pool{
25	New: func() interface{} {
26		return new(extJSONValueWriter)
27	},
28}
29
30// ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters.
31type ExtJSONValueWriterPool struct {
32	pool sync.Pool
33}
34
35// NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON.
36func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool {
37	return &ExtJSONValueWriterPool{
38		pool: sync.Pool{
39			New: func() interface{} {
40				return new(extJSONValueWriter)
41			},
42		},
43	}
44}
45
46// Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination.
47func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter {
48	vw := bvwp.pool.Get().(*extJSONValueWriter)
49	if writer, ok := w.(*SliceWriter); ok {
50		vw.reset(*writer, canonical, escapeHTML)
51		vw.w = writer
52		return vw
53	}
54	vw.buf = vw.buf[:0]
55	vw.w = w
56	return vw
57}
58
59// Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing
60// happens and ok will be false.
61func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
62	bvw, ok := vw.(*extJSONValueWriter)
63	if !ok {
64		return false
65	}
66
67	if _, ok := bvw.w.(*SliceWriter); ok {
68		bvw.buf = nil
69	}
70	bvw.w = nil
71
72	bvwp.pool.Put(bvw)
73	return true
74}
75
76type ejvwState struct {
77	mode mode
78}
79
80type extJSONValueWriter struct {
81	w   io.Writer
82	buf []byte
83
84	stack      []ejvwState
85	frame      int64
86	canonical  bool
87	escapeHTML bool
88}
89
90// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
91func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) {
92	if w == nil {
93		return nil, errNilWriter
94	}
95
96	return newExtJSONWriter(w, canonical, escapeHTML), nil
97}
98
99func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter {
100	stack := make([]ejvwState, 1, 5)
101	stack[0] = ejvwState{mode: mTopLevel}
102
103	return &extJSONValueWriter{
104		w:          w,
105		buf:        []byte{},
106		stack:      stack,
107		canonical:  canonical,
108		escapeHTML: escapeHTML,
109	}
110}
111
112func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
113	stack := make([]ejvwState, 1, 5)
114	stack[0] = ejvwState{mode: mTopLevel}
115
116	return &extJSONValueWriter{
117		buf:        buf,
118		stack:      stack,
119		canonical:  canonical,
120		escapeHTML: escapeHTML,
121	}
122}
123
124func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
125	if ejvw.stack == nil {
126		ejvw.stack = make([]ejvwState, 1, 5)
127	}
128
129	ejvw.stack = ejvw.stack[:1]
130	ejvw.stack[0] = ejvwState{mode: mTopLevel}
131	ejvw.canonical = canonical
132	ejvw.escapeHTML = escapeHTML
133	ejvw.frame = 0
134	ejvw.buf = buf
135	ejvw.w = nil
136}
137
138func (ejvw *extJSONValueWriter) advanceFrame() {
139	if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
140		length := len(ejvw.stack)
141		if length+1 >= cap(ejvw.stack) {
142			// double it
143			buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
144			copy(buf, ejvw.stack)
145			ejvw.stack = buf
146		}
147		ejvw.stack = ejvw.stack[:length+1]
148	}
149	ejvw.frame++
150}
151
152func (ejvw *extJSONValueWriter) push(m mode) {
153	ejvw.advanceFrame()
154
155	ejvw.stack[ejvw.frame].mode = m
156}
157
158func (ejvw *extJSONValueWriter) pop() {
159	switch ejvw.stack[ejvw.frame].mode {
160	case mElement, mValue:
161		ejvw.frame--
162	case mDocument, mArray, mCodeWithScope:
163		ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
164	}
165}
166
167func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
168	te := TransitionError{
169		name:        name,
170		current:     ejvw.stack[ejvw.frame].mode,
171		destination: destination,
172		modes:       modes,
173		action:      "write",
174	}
175	if ejvw.frame != 0 {
176		te.parent = ejvw.stack[ejvw.frame-1].mode
177	}
178	return te
179}
180
181func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
182	switch ejvw.stack[ejvw.frame].mode {
183	case mElement, mValue:
184	default:
185		modes := []mode{mElement, mValue}
186		if addmodes != nil {
187			modes = append(modes, addmodes...)
188		}
189		return ejvw.invalidTransitionErr(destination, callerName, modes)
190	}
191
192	return nil
193}
194
195func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
196	var s string
197	if quotes {
198		s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
199	} else {
200		s = fmt.Sprintf(`{"$%s":%s}`, key, value)
201	}
202
203	ejvw.buf = append(ejvw.buf, []byte(s)...)
204}
205
206func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
207	if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
208		return nil, err
209	}
210
211	ejvw.buf = append(ejvw.buf, '[')
212
213	ejvw.push(mArray)
214	return ejvw, nil
215}
216
217func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
218	return ejvw.WriteBinaryWithSubtype(b, 0x00)
219}
220
221func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
222	if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
223		return err
224	}
225
226	var buf bytes.Buffer
227	buf.WriteString(`{"$binary":{"base64":"`)
228	buf.WriteString(base64.StdEncoding.EncodeToString(b))
229	buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
230
231	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
232
233	ejvw.pop()
234	return nil
235}
236
237func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
238	if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
239		return err
240	}
241
242	ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
243	ejvw.buf = append(ejvw.buf, ',')
244
245	ejvw.pop()
246	return nil
247}
248
249func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
250	if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
251		return nil, err
252	}
253
254	var buf bytes.Buffer
255	buf.WriteString(`{"$code":`)
256	writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
257	buf.WriteString(`,"$scope":{`)
258
259	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
260
261	ejvw.push(mCodeWithScope)
262	return ejvw, nil
263}
264
265func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
266	if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
267		return err
268	}
269
270	var buf bytes.Buffer
271	buf.WriteString(`{"$dbPointer":{"$ref":"`)
272	buf.WriteString(ns)
273	buf.WriteString(`","$id":{"$oid":"`)
274	buf.WriteString(oid.Hex())
275	buf.WriteString(`"}}},`)
276
277	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
278
279	ejvw.pop()
280	return nil
281}
282
283func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
284	if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
285		return err
286	}
287
288	t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
289
290	if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
291		s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
292		ejvw.writeExtendedSingleValue("date", s, false)
293	} else {
294		ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
295	}
296
297	ejvw.buf = append(ejvw.buf, ',')
298
299	ejvw.pop()
300	return nil
301}
302
303func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error {
304	if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
305		return err
306	}
307
308	ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
309	ejvw.buf = append(ejvw.buf, ',')
310
311	ejvw.pop()
312	return nil
313}
314
315func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
316	if ejvw.stack[ejvw.frame].mode == mTopLevel {
317		ejvw.buf = append(ejvw.buf, '{')
318		return ejvw, nil
319	}
320
321	if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
322		return nil, err
323	}
324
325	ejvw.buf = append(ejvw.buf, '{')
326	ejvw.push(mDocument)
327	return ejvw, nil
328}
329
330func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
331	if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
332		return err
333	}
334
335	s := formatDouble(f)
336
337	if ejvw.canonical {
338		ejvw.writeExtendedSingleValue("numberDouble", s, true)
339	} else {
340		switch s {
341		case "Infinity":
342			fallthrough
343		case "-Infinity":
344			fallthrough
345		case "NaN":
346			s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
347		}
348		ejvw.buf = append(ejvw.buf, []byte(s)...)
349	}
350
351	ejvw.buf = append(ejvw.buf, ',')
352
353	ejvw.pop()
354	return nil
355}
356
357func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
358	if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
359		return err
360	}
361
362	s := strconv.FormatInt(int64(i), 10)
363
364	if ejvw.canonical {
365		ejvw.writeExtendedSingleValue("numberInt", s, true)
366	} else {
367		ejvw.buf = append(ejvw.buf, []byte(s)...)
368	}
369
370	ejvw.buf = append(ejvw.buf, ',')
371
372	ejvw.pop()
373	return nil
374}
375
376func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
377	if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
378		return err
379	}
380
381	s := strconv.FormatInt(i, 10)
382
383	if ejvw.canonical {
384		ejvw.writeExtendedSingleValue("numberLong", s, true)
385	} else {
386		ejvw.buf = append(ejvw.buf, []byte(s)...)
387	}
388
389	ejvw.buf = append(ejvw.buf, ',')
390
391	ejvw.pop()
392	return nil
393}
394
395func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
396	if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
397		return err
398	}
399
400	var buf bytes.Buffer
401	writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
402
403	ejvw.writeExtendedSingleValue("code", buf.String(), false)
404	ejvw.buf = append(ejvw.buf, ',')
405
406	ejvw.pop()
407	return nil
408}
409
410func (ejvw *extJSONValueWriter) WriteMaxKey() error {
411	if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
412		return err
413	}
414
415	ejvw.writeExtendedSingleValue("maxKey", "1", false)
416	ejvw.buf = append(ejvw.buf, ',')
417
418	ejvw.pop()
419	return nil
420}
421
422func (ejvw *extJSONValueWriter) WriteMinKey() error {
423	if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
424		return err
425	}
426
427	ejvw.writeExtendedSingleValue("minKey", "1", false)
428	ejvw.buf = append(ejvw.buf, ',')
429
430	ejvw.pop()
431	return nil
432}
433
434func (ejvw *extJSONValueWriter) WriteNull() error {
435	if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
436		return err
437	}
438
439	ejvw.buf = append(ejvw.buf, []byte("null")...)
440	ejvw.buf = append(ejvw.buf, ',')
441
442	ejvw.pop()
443	return nil
444}
445
446func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error {
447	if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
448		return err
449	}
450
451	ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
452	ejvw.buf = append(ejvw.buf, ',')
453
454	ejvw.pop()
455	return nil
456}
457
458func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
459	if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
460		return err
461	}
462
463	var buf bytes.Buffer
464	buf.WriteString(`{"$regularExpression":{"pattern":`)
465	writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
466	buf.WriteString(`,"options":"`)
467	buf.WriteString(sortStringAlphebeticAscending(options))
468	buf.WriteString(`"}},`)
469
470	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
471
472	ejvw.pop()
473	return nil
474}
475
476func (ejvw *extJSONValueWriter) WriteString(s string) error {
477	if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
478		return err
479	}
480
481	var buf bytes.Buffer
482	writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
483
484	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
485	ejvw.buf = append(ejvw.buf, ',')
486
487	ejvw.pop()
488	return nil
489}
490
491func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
492	if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
493		return err
494	}
495
496	var buf bytes.Buffer
497	writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
498
499	ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
500	ejvw.buf = append(ejvw.buf, ',')
501
502	ejvw.pop()
503	return nil
504}
505
506func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
507	if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
508		return err
509	}
510
511	var buf bytes.Buffer
512	buf.WriteString(`{"$timestamp":{"t":`)
513	buf.WriteString(strconv.FormatUint(uint64(t), 10))
514	buf.WriteString(`,"i":`)
515	buf.WriteString(strconv.FormatUint(uint64(i), 10))
516	buf.WriteString(`}},`)
517
518	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
519
520	ejvw.pop()
521	return nil
522}
523
524func (ejvw *extJSONValueWriter) WriteUndefined() error {
525	if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
526		return err
527	}
528
529	ejvw.writeExtendedSingleValue("undefined", "true", false)
530	ejvw.buf = append(ejvw.buf, ',')
531
532	ejvw.pop()
533	return nil
534}
535
536func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
537	switch ejvw.stack[ejvw.frame].mode {
538	case mDocument, mTopLevel, mCodeWithScope:
539		var buf bytes.Buffer
540		writeStringWithEscapes(key, &buf, ejvw.escapeHTML)
541
542		ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...)
543		ejvw.push(mElement)
544	default:
545		return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
546	}
547
548	return ejvw, nil
549}
550
551func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
552	switch ejvw.stack[ejvw.frame].mode {
553	case mDocument, mTopLevel, mCodeWithScope:
554	default:
555		return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
556	}
557
558	// close the document
559	if ejvw.buf[len(ejvw.buf)-1] == ',' {
560		ejvw.buf[len(ejvw.buf)-1] = '}'
561	} else {
562		ejvw.buf = append(ejvw.buf, '}')
563	}
564
565	switch ejvw.stack[ejvw.frame].mode {
566	case mCodeWithScope:
567		ejvw.buf = append(ejvw.buf, '}')
568		fallthrough
569	case mDocument:
570		ejvw.buf = append(ejvw.buf, ',')
571	case mTopLevel:
572		if ejvw.w != nil {
573			if _, err := ejvw.w.Write(ejvw.buf); err != nil {
574				return err
575			}
576			ejvw.buf = ejvw.buf[:0]
577		}
578	}
579
580	ejvw.pop()
581	return nil
582}
583
584func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
585	switch ejvw.stack[ejvw.frame].mode {
586	case mArray:
587		ejvw.push(mValue)
588	default:
589		return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
590	}
591
592	return ejvw, nil
593}
594
595func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
596	switch ejvw.stack[ejvw.frame].mode {
597	case mArray:
598		// close the array
599		if ejvw.buf[len(ejvw.buf)-1] == ',' {
600			ejvw.buf[len(ejvw.buf)-1] = ']'
601		} else {
602			ejvw.buf = append(ejvw.buf, ']')
603		}
604
605		ejvw.buf = append(ejvw.buf, ',')
606
607		ejvw.pop()
608	default:
609		return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
610	}
611
612	return nil
613}
614
// formatDouble renders f as text for Extended JSON.
//
// The non-finite values become "Infinity", "-Infinity" and "NaN". A finite
// value uses the shortest round-tripping 'G' representation. If that text has
// neither an exponent nor a decimal point, ".0" is appended so the value still
// reads as a double (for example 1 becomes "1.0").
func formatDouble(f float64) string {
	switch {
	case math.IsInf(f, 1):
		return "Infinity"
	case math.IsInf(f, -1):
		return "-Infinity"
	case math.IsNaN(f):
		return "NaN"
	}

	out := strconv.FormatFloat(f, 'G', -1, 64)
	if strings.ContainsAny(out, "E.") {
		return out
	}
	return out + ".0"
}
634
// hexChars maps a nibble value (0-15) to its lowercase hexadecimal digit. It is
// used below to build the \u00XX and \u202X escapes.
var hexChars = "0123456789abcdef"

// writeStringWithEscapes writes s to buf as a double-quoted JSON string.
//
// Escaping rules:
//   - ASCII bytes found in htmlSafeSet pass through unchanged. When escapeHTML
//     is false, bytes found in safeSet also pass through.
//   - Backslash and double quote are backslash-escaped.
//   - \n, \r, \t, \b and \f use their two-character escapes.
//   - Any other ASCII byte not allowed through is written as \u00XX.
//   - A byte that is not valid UTF-8 is replaced by the escape \ufffd.
//   - U+2028 and U+2029 are always escaped.
//   - Everything else (valid multi-byte UTF-8) is copied as-is.
//
// Runs of bytes that need no escaping are copied in bulk via the [start, i)
// window rather than byte by byte.
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
	buf.WriteByte('"')
	start := 0 // beginning of the pending run of bytes that need no escaping
	for i := 0; i < len(s); {
		if b := s[i]; b < utf8.RuneSelf {
			// Single-byte (ASCII) character.
			if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
				i++
				continue
			}
			// Flush the pending unescaped run before writing the escape.
			if start < i {
				buf.WriteString(s[start:i])
			}
			switch b {
			case '\\', '"':
				buf.WriteByte('\\')
				buf.WriteByte(b)
			case '\n':
				buf.WriteByte('\\')
				buf.WriteByte('n')
			case '\r':
				buf.WriteByte('\\')
				buf.WriteByte('r')
			case '\t':
				buf.WriteByte('\\')
				buf.WriteByte('t')
			case '\b':
				buf.WriteByte('\\')
				buf.WriteByte('b')
			case '\f':
				buf.WriteByte('\\')
				buf.WriteByte('f')
			default:
				// This encodes, as \u00XX, every remaining ASCII byte that was
				// not allowed through above. That covers bytes < 0x20 other
				// than \t, \n, \r, \b and \f, which have short escapes.
				// If escapeHTML is set, it presumably also covers <, >, and &
				// (the bytes htmlSafeSet is expected to exclude), because they
				// can lead to security holes when user-controlled strings are
				// rendered into JSON and served to some browsers.
				buf.WriteString(`\u00`)
				buf.WriteByte(hexChars[b>>4])
				buf.WriteByte(hexChars[b&0xF])
			}
			i++
			start = i
			continue
		}
		// Multi-byte sequence: decode one rune.
		c, size := utf8.DecodeRuneInString(s[i:])
		if c == utf8.RuneError && size == 1 {
			// Invalid UTF-8 byte: replace it with an escaped U+FFFD.
			if start < i {
				buf.WriteString(s[start:i])
			}
			buf.WriteString(`\ufffd`)
			i += size
			start = i
			continue
		}
		// U+2028 is LINE SEPARATOR.
		// U+2029 is PARAGRAPH SEPARATOR.
		// They are both technically valid characters in JSON strings,
		// but don't work in JSONP, which has to be evaluated as JavaScript,
		// and can lead to security holes there. It is valid JSON to
		// escape them, so we do so unconditionally.
		// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
		if c == '\u2028' || c == '\u2029' {
			if start < i {
				buf.WriteString(s[start:i])
			}
			buf.WriteString(`\u202`)
			buf.WriteByte(hexChars[c&0xF])
			i += size
			start = i
			continue
		}
		i += size
	}
	// Flush whatever unescaped run remains at the end.
	if start < len(s) {
		buf.WriteString(s[start:])
	}
	buf.WriteByte('"')
}
716
// sortableString is a slice of runes that implements sort.Interface, ordering
// runes by ascending code point.
type sortableString []rune

// Len reports the number of runes.
func (ss sortableString) Len() int { return len(ss) }

// Less orders runes by ascending code point.
func (ss sortableString) Less(i, j int) bool { return ss[i] < ss[j] }

// Swap exchanges the runes at positions i and j.
func (ss sortableString) Swap(i, j int) { ss[i], ss[j] = ss[j], ss[i] }

// sortStringAlphebeticAscending returns s with its runes sorted by ascending
// code point. WriteRegex uses it to normalize regular-expression option strings.
func sortStringAlphebeticAscending(s string) string {
	runes := []rune(s)
	sort.Sort(sortableString(runes))
	return string(runes)
}
738