1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"errors"
11	"fmt"
12	"io"
13	"math"
14	"strconv"
15	"strings"
16	"sync"
17
18	"go.mongodb.org/mongo-driver/bson/bsontype"
19	"go.mongodb.org/mongo-driver/bson/primitive"
20	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
21)
22
23var _ ValueWriter = (*valueWriter)(nil)
24
25var vwPool = sync.Pool{
26	New: func() interface{} {
27		return new(valueWriter)
28	},
29}
30
// BSONValueWriterPool is a pool for BSON ValueWriters.
//
// Writers are obtained with Get or GetAtModeElement and handed back with Put.
type BSONValueWriterPool struct {
	// pool holds the recycled *valueWriter instances.
	pool sync.Pool
}
35
36// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
37func NewBSONValueWriterPool() *BSONValueWriterPool {
38	return &BSONValueWriterPool{
39		pool: sync.Pool{
40			New: func() interface{} {
41				return new(valueWriter)
42			},
43		},
44	}
45}
46
47// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
48func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
49	vw := bvwp.pool.Get().(*valueWriter)
50	if writer, ok := w.(*SliceWriter); ok {
51		vw.reset(*writer)
52		vw.w = writer
53		return vw
54	}
55	vw.buf = vw.buf[:0]
56	vw.w = w
57	return vw
58}
59
60// GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination.
61func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher {
62	vw := bvwp.Get(w).(*valueWriter)
63	vw.push(mElement)
64	return vw
65}
66
67// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
68// happens and ok will be false.
69func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
70	bvw, ok := vw.(*valueWriter)
71	if !ok {
72		return false
73	}
74
75	if _, ok := bvw.w.(*SliceWriter); ok {
76		bvw.buf = nil
77	}
78	bvw.w = nil
79
80	bvwp.pool.Put(bvw)
81	return true
82}
83
var (
	// maxSize is the largest buffer length writeLength accepts. It is a
	// variable rather than a constant so that tests can lower it instead of
	// having to allocate a multi-gigabyte slice.
	maxSize = math.MaxInt32

	// errNilWriter is returned by NewBSONValueWriter when it is given a nil io.Writer.
	errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
)
89
// errMaxDocumentSizeExceeded is the error returned by writeLength when the
// buffered output is longer than maxSize.
type errMaxDocumentSizeExceeded struct {
	// size is the buffer length, in bytes, that exceeded the limit.
	size int64
}

// Error implements the error interface and reports the offending size.
func (e errMaxDocumentSizeExceeded) Error() string {
	const format = "document size (%d) is larger than the max int32"
	return fmt.Sprintf(format, e.size)
}
97
// vwMode enumerates value-writer modes.
//
// NOTE(review): the frame stack in this file stores values of the separate
// `mode` type rather than vwMode; confirm whether vwMode is still used elsewhere.
type vwMode int

// The zero value is deliberately skipped so that an unset vwMode is not a
// valid mode.
const (
	_ vwMode = iota
	vwTopLevel
	vwDocument
	vwArray
	vwValue
	vwElement
	vwCodeWithScope
)

// String returns the human-readable name of the mode, or "UnknownMode" for any
// value that is not one of the declared constants.
func (vm vwMode) String() string {
	switch vm {
	case vwTopLevel:
		return "TopLevel"
	case vwDocument:
		return "DocumentMode"
	case vwArray:
		return "ArrayMode"
	case vwValue:
		return "ValueMode"
	case vwElement:
		return "ElementMode"
	case vwCodeWithScope:
		return "CodeWithScopeMode"
	}
	return "UnknownMode"
}
132
// vwState is one frame of the valueWriter's stack.
type vwState struct {
	// mode is the kind of frame (top level, document, array, element, value, ...).
	mode mode
	// key is the element key recorded by WriteDocumentElement for mElement frames.
	key string
	// arrkey is, for an mArray frame, the index the next element will receive;
	// for an mValue frame, the index of the element being written.
	arrkey int
	// start is the offset in the buffer of the 4 bytes reserved for this
	// frame's length prefix (set by reserveLength).
	start int32
}
139
// valueWriter builds BSON into an in-memory buffer and, when the top-level
// document ends, flushes that buffer to its destination.
type valueWriter struct {
	// w is the flush destination; it may be nil, in which case Flush does nothing.
	w io.Writer
	// buf accumulates the encoded bytes.
	buf []byte

	// stack holds one vwState per nesting level; stack[0] is the top-level frame.
	stack []vwState
	// frame is the index into stack of the current frame.
	frame int64
}
147
148func (vw *valueWriter) advanceFrame() {
149	if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
150		length := len(vw.stack)
151		if length+1 >= cap(vw.stack) {
152			// double it
153			buf := make([]vwState, 2*cap(vw.stack)+1)
154			copy(buf, vw.stack)
155			vw.stack = buf
156		}
157		vw.stack = vw.stack[:length+1]
158	}
159	vw.frame++
160}
161
162func (vw *valueWriter) push(m mode) {
163	vw.advanceFrame()
164
165	// Clean the stack
166	vw.stack[vw.frame].mode = m
167	vw.stack[vw.frame].key = ""
168	vw.stack[vw.frame].arrkey = 0
169	vw.stack[vw.frame].start = 0
170
171	vw.stack[vw.frame].mode = m
172	switch m {
173	case mDocument, mArray, mCodeWithScope:
174		vw.reserveLength()
175	}
176}
177
178func (vw *valueWriter) reserveLength() {
179	vw.stack[vw.frame].start = int32(len(vw.buf))
180	vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
181}
182
183func (vw *valueWriter) pop() {
184	switch vw.stack[vw.frame].mode {
185	case mElement, mValue:
186		vw.frame--
187	case mDocument, mArray, mCodeWithScope:
188		vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
189	}
190}
191
192// NewBSONValueWriter creates a ValueWriter that writes BSON to w.
193//
194// This ValueWriter will only write entire documents to the io.Writer and it
195// will buffer the document as it is built.
196func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
197	if w == nil {
198		return nil, errNilWriter
199	}
200	return newValueWriter(w), nil
201}
202
203func newValueWriter(w io.Writer) *valueWriter {
204	vw := new(valueWriter)
205	stack := make([]vwState, 1, 5)
206	stack[0] = vwState{mode: mTopLevel}
207	vw.w = w
208	vw.stack = stack
209
210	return vw
211}
212
213func newValueWriterFromSlice(buf []byte) *valueWriter {
214	vw := new(valueWriter)
215	stack := make([]vwState, 1, 5)
216	stack[0] = vwState{mode: mTopLevel}
217	vw.stack = stack
218	vw.buf = buf
219
220	return vw
221}
222
223func (vw *valueWriter) reset(buf []byte) {
224	if vw.stack == nil {
225		vw.stack = make([]vwState, 1, 5)
226	}
227	vw.stack = vw.stack[:1]
228	vw.stack[0] = vwState{mode: mTopLevel}
229	vw.buf = buf
230	vw.frame = 0
231	vw.w = nil
232}
233
234func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
235	te := TransitionError{
236		name:        name,
237		current:     vw.stack[vw.frame].mode,
238		destination: destination,
239		modes:       modes,
240		action:      "write",
241	}
242	if vw.frame != 0 {
243		te.parent = vw.stack[vw.frame-1].mode
244	}
245	return te
246}
247
248func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
249	switch vw.stack[vw.frame].mode {
250	case mElement:
251		key := vw.stack[vw.frame].key
252		if !isValidCString(key) {
253			return errors.New("BSON element key cannot contain null bytes")
254		}
255
256		vw.buf = bsoncore.AppendHeader(vw.buf, t, key)
257	case mValue:
258		// TODO: Do this with a cache of the first 1000 or so array keys.
259		vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
260	default:
261		modes := []mode{mElement, mValue}
262		if addmodes != nil {
263			modes = append(modes, addmodes...)
264		}
265		return vw.invalidTransitionError(destination, callerName, modes)
266	}
267
268	return nil
269}
270
271func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
272	if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
273		return err
274	}
275	vw.buf = append(vw.buf, b...)
276	vw.pop()
277	return nil
278}
279
280func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
281	if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
282		return nil, err
283	}
284
285	vw.push(mArray)
286
287	return vw, nil
288}
289
290func (vw *valueWriter) WriteBinary(b []byte) error {
291	return vw.WriteBinaryWithSubtype(b, 0x00)
292}
293
294func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
295	if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
296		return err
297	}
298
299	vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
300	vw.pop()
301	return nil
302}
303
304func (vw *valueWriter) WriteBoolean(b bool) error {
305	if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
306		return err
307	}
308
309	vw.buf = bsoncore.AppendBoolean(vw.buf, b)
310	vw.pop()
311	return nil
312}
313
314func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
315	if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
316		return nil, err
317	}
318
319	// CodeWithScope is a different than other types because we need an extra
320	// frame on the stack. In the EndDocument code, we write the document
321	// length, pop, write the code with scope length, and pop. To simplify the
322	// pop code, we push a spacer frame that we'll always jump over.
323	vw.push(mCodeWithScope)
324	vw.buf = bsoncore.AppendString(vw.buf, code)
325	vw.push(mSpacer)
326	vw.push(mDocument)
327
328	return vw, nil
329}
330
331func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
332	if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
333		return err
334	}
335
336	vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
337	vw.pop()
338	return nil
339}
340
341func (vw *valueWriter) WriteDateTime(dt int64) error {
342	if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
343		return err
344	}
345
346	vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
347	vw.pop()
348	return nil
349}
350
351func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
352	if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
353		return err
354	}
355
356	vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
357	vw.pop()
358	return nil
359}
360
361func (vw *valueWriter) WriteDouble(f float64) error {
362	if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
363		return err
364	}
365
366	vw.buf = bsoncore.AppendDouble(vw.buf, f)
367	vw.pop()
368	return nil
369}
370
371func (vw *valueWriter) WriteInt32(i32 int32) error {
372	if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
373		return err
374	}
375
376	vw.buf = bsoncore.AppendInt32(vw.buf, i32)
377	vw.pop()
378	return nil
379}
380
381func (vw *valueWriter) WriteInt64(i64 int64) error {
382	if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
383		return err
384	}
385
386	vw.buf = bsoncore.AppendInt64(vw.buf, i64)
387	vw.pop()
388	return nil
389}
390
391func (vw *valueWriter) WriteJavascript(code string) error {
392	if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
393		return err
394	}
395
396	vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
397	vw.pop()
398	return nil
399}
400
401func (vw *valueWriter) WriteMaxKey() error {
402	if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
403		return err
404	}
405
406	vw.pop()
407	return nil
408}
409
410func (vw *valueWriter) WriteMinKey() error {
411	if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
412		return err
413	}
414
415	vw.pop()
416	return nil
417}
418
419func (vw *valueWriter) WriteNull() error {
420	if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
421		return err
422	}
423
424	vw.pop()
425	return nil
426}
427
428func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
429	if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
430		return err
431	}
432
433	vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
434	vw.pop()
435	return nil
436}
437
438func (vw *valueWriter) WriteRegex(pattern string, options string) error {
439	if !isValidCString(pattern) || !isValidCString(options) {
440		return errors.New("BSON regex values cannot contain null bytes")
441	}
442	if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
443		return err
444	}
445
446	vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
447	vw.pop()
448	return nil
449}
450
451func (vw *valueWriter) WriteString(s string) error {
452	if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
453		return err
454	}
455
456	vw.buf = bsoncore.AppendString(vw.buf, s)
457	vw.pop()
458	return nil
459}
460
461func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
462	if vw.stack[vw.frame].mode == mTopLevel {
463		vw.reserveLength()
464		return vw, nil
465	}
466	if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
467		return nil, err
468	}
469
470	vw.push(mDocument)
471	return vw, nil
472}
473
474func (vw *valueWriter) WriteSymbol(symbol string) error {
475	if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
476		return err
477	}
478
479	vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
480	vw.pop()
481	return nil
482}
483
484func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
485	if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
486		return err
487	}
488
489	vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
490	vw.pop()
491	return nil
492}
493
494func (vw *valueWriter) WriteUndefined() error {
495	if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
496		return err
497	}
498
499	vw.pop()
500	return nil
501}
502
503func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
504	switch vw.stack[vw.frame].mode {
505	case mTopLevel, mDocument:
506	default:
507		return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
508	}
509
510	vw.push(mElement)
511	vw.stack[vw.frame].key = key
512
513	return vw, nil
514}
515
516func (vw *valueWriter) WriteDocumentEnd() error {
517	switch vw.stack[vw.frame].mode {
518	case mTopLevel, mDocument:
519	default:
520		return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
521	}
522
523	vw.buf = append(vw.buf, 0x00)
524
525	err := vw.writeLength()
526	if err != nil {
527		return err
528	}
529
530	if vw.stack[vw.frame].mode == mTopLevel {
531		if err = vw.Flush(); err != nil {
532			return err
533		}
534	}
535
536	vw.pop()
537
538	if vw.stack[vw.frame].mode == mCodeWithScope {
539		// We ignore the error here because of the gaurantee of writeLength.
540		// See the docs for writeLength for more info.
541		_ = vw.writeLength()
542		vw.pop()
543	}
544	return nil
545}
546
547func (vw *valueWriter) Flush() error {
548	if vw.w == nil {
549		return nil
550	}
551
552	if sw, ok := vw.w.(*SliceWriter); ok {
553		*sw = vw.buf
554		return nil
555	}
556	if _, err := vw.w.Write(vw.buf); err != nil {
557		return err
558	}
559	// reset buffer
560	vw.buf = vw.buf[:0]
561	return nil
562}
563
564func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
565	if vw.stack[vw.frame].mode != mArray {
566		return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
567	}
568
569	arrkey := vw.stack[vw.frame].arrkey
570	vw.stack[vw.frame].arrkey++
571
572	vw.push(mValue)
573	vw.stack[vw.frame].arrkey = arrkey
574
575	return vw, nil
576}
577
578func (vw *valueWriter) WriteArrayEnd() error {
579	if vw.stack[vw.frame].mode != mArray {
580		return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
581	}
582
583	vw.buf = append(vw.buf, 0x00)
584
585	err := vw.writeLength()
586	if err != nil {
587		return err
588	}
589
590	vw.pop()
591	return nil
592}
593
594// NOTE: We assume that if we call writeLength more than once the same function
595// within the same function without altering the vw.buf that this method will
596// not return an error. If this changes ensure that the following methods are
597// updated:
598//
599// - WriteDocumentEnd
600func (vw *valueWriter) writeLength() error {
601	length := len(vw.buf)
602	if length > maxSize {
603		return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
604	}
605	length = length - int(vw.stack[vw.frame].start)
606	start := vw.stack[vw.frame].start
607
608	vw.buf[start+0] = byte(length)
609	vw.buf[start+1] = byte(length >> 8)
610	vw.buf[start+2] = byte(length >> 16)
611	vw.buf[start+3] = byte(length >> 24)
612	return nil
613}
614
// isValidCString reports whether cs can be written as a null-terminated
// string, i.e. whether it contains no null byte.
func isValidCString(cs string) bool {
	return strings.IndexByte(cs, 0x00) == -1
}
618