1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package bsonrw 8 9import ( 10 "errors" 11 "fmt" 12 "io" 13 "math" 14 "strconv" 15 "strings" 16 "sync" 17 18 "go.mongodb.org/mongo-driver/bson/bsontype" 19 "go.mongodb.org/mongo-driver/bson/primitive" 20 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 21) 22 23var _ ValueWriter = (*valueWriter)(nil) 24 25var vwPool = sync.Pool{ 26 New: func() interface{} { 27 return new(valueWriter) 28 }, 29} 30 31// BSONValueWriterPool is a pool for BSON ValueWriters. 32type BSONValueWriterPool struct { 33 pool sync.Pool 34} 35 36// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON. 37func NewBSONValueWriterPool() *BSONValueWriterPool { 38 return &BSONValueWriterPool{ 39 pool: sync.Pool{ 40 New: func() interface{} { 41 return new(valueWriter) 42 }, 43 }, 44 } 45} 46 47// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination. 48func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter { 49 vw := bvwp.pool.Get().(*valueWriter) 50 if writer, ok := w.(*SliceWriter); ok { 51 vw.reset(*writer) 52 vw.w = writer 53 return vw 54 } 55 vw.buf = vw.buf[:0] 56 vw.w = w 57 return vw 58} 59 60// GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination. 61func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher { 62 vw := bvwp.Get(w).(*valueWriter) 63 vw.push(mElement) 64 return vw 65} 66 67// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing 68// happens and ok will be false. 69func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) { 70 bvw, ok := vw.(*valueWriter) 71 if !ok { 72 return false 73 } 74 75 if _, ok := bvw.w.(*SliceWriter); ok { 76 bvw.buf = nil 77 } 78 bvw.w = nil 79 80 bvwp.pool.Put(bvw) 81 return true 82} 83 84// This is here so that during testing we can change it and not require 85// allocating a 4GB slice. 86var maxSize = math.MaxInt32 87 88var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer") 89 90type errMaxDocumentSizeExceeded struct { 91 size int64 92} 93 94func (mdse errMaxDocumentSizeExceeded) Error() string { 95 return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size) 96} 97 98type vwMode int 99 100const ( 101 _ vwMode = iota 102 vwTopLevel 103 vwDocument 104 vwArray 105 vwValue 106 vwElement 107 vwCodeWithScope 108) 109 110func (vm vwMode) String() string { 111 var str string 112 113 switch vm { 114 case vwTopLevel: 115 str = "TopLevel" 116 case vwDocument: 117 str = "DocumentMode" 118 case vwArray: 119 str = "ArrayMode" 120 case vwValue: 121 str = "ValueMode" 122 case vwElement: 123 str = "ElementMode" 124 case vwCodeWithScope: 125 str = "CodeWithScopeMode" 126 default: 127 str = "UnknownMode" 128 } 129 130 return str 131} 132 133type vwState struct { 134 mode mode 135 key string 136 arrkey int 137 start int32 138} 139 140type valueWriter struct { 141 w io.Writer 142 buf []byte 143 144 stack []vwState 145 frame int64 146} 147 148func (vw *valueWriter) advanceFrame() { 149 if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack 150 length := len(vw.stack) 151 if length+1 >= cap(vw.stack) { 152 // double it 153 buf := make([]vwState, 2*cap(vw.stack)+1) 154 copy(buf, vw.stack) 155 vw.stack = buf 156 } 157 vw.stack = vw.stack[:length+1] 158 } 159 vw.frame++ 160} 161 162func (vw *valueWriter) push(m mode) { 163 vw.advanceFrame() 164 165 // Clean the stack 166 vw.stack[vw.frame].mode = m 167 vw.stack[vw.frame].key = "" 168 vw.stack[vw.frame].arrkey = 0 169 vw.stack[vw.frame].start = 0 170 171 vw.stack[vw.frame].mode = m 172 switch m { 173 case mDocument, mArray, mCodeWithScope: 174 vw.reserveLength() 175 } 176} 177 178func (vw *valueWriter) reserveLength() { 179 vw.stack[vw.frame].start = int32(len(vw.buf)) 180 vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00) 181} 182 183func (vw *valueWriter) pop() { 184 switch vw.stack[vw.frame].mode { 185 case mElement, mValue: 186 vw.frame-- 187 case mDocument, mArray, mCodeWithScope: 188 vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc... 189 } 190} 191 192// NewBSONValueWriter creates a ValueWriter that writes BSON to w. 193// 194// This ValueWriter will only write entire documents to the io.Writer and it 195// will buffer the document as it is built. 196func NewBSONValueWriter(w io.Writer) (ValueWriter, error) { 197 if w == nil { 198 return nil, errNilWriter 199 } 200 return newValueWriter(w), nil 201} 202 203func newValueWriter(w io.Writer) *valueWriter { 204 vw := new(valueWriter) 205 stack := make([]vwState, 1, 5) 206 stack[0] = vwState{mode: mTopLevel} 207 vw.w = w 208 vw.stack = stack 209 210 return vw 211} 212 213func newValueWriterFromSlice(buf []byte) *valueWriter { 214 vw := new(valueWriter) 215 stack := make([]vwState, 1, 5) 216 stack[0] = vwState{mode: mTopLevel} 217 vw.stack = stack 218 vw.buf = buf 219 220 return vw 221} 222 223func (vw *valueWriter) reset(buf []byte) { 224 if vw.stack == nil { 225 vw.stack = make([]vwState, 1, 5) 226 } 227 vw.stack = vw.stack[:1] 228 vw.stack[0] = vwState{mode: mTopLevel} 229 vw.buf = buf 230 vw.frame = 0 231 vw.w = nil 232} 233 234func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error { 235 te := TransitionError{ 236 name: name, 237 current: vw.stack[vw.frame].mode, 238 destination: destination, 239 modes: modes, 240 action: "write", 241 } 242 if vw.frame != 0 { 243 te.parent = vw.stack[vw.frame-1].mode 244 } 245 return te 246} 247 248func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { 249 switch vw.stack[vw.frame].mode { 250 case mElement: 251 key := vw.stack[vw.frame].key 252 if !isValidCString(key) { 253 return errors.New("BSON element key cannot contain null bytes") 254 } 255 256 vw.buf = bsoncore.AppendHeader(vw.buf, t, key) 257 case mValue: 258 // TODO: Do this with a cache of the first 1000 or so array keys. 259 vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) 260 default: 261 modes := []mode{mElement, mValue} 262 if addmodes != nil { 263 modes = append(modes, addmodes...) 264 } 265 return vw.invalidTransitionError(destination, callerName, modes) 266 } 267 268 return nil 269} 270 271func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error { 272 if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil { 273 return err 274 } 275 vw.buf = append(vw.buf, b...) 276 vw.pop() 277 return nil 278} 279 280func (vw *valueWriter) WriteArray() (ArrayWriter, error) { 281 if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil { 282 return nil, err 283 } 284 285 vw.push(mArray) 286 287 return vw, nil 288} 289 290func (vw *valueWriter) WriteBinary(b []byte) error { 291 return vw.WriteBinaryWithSubtype(b, 0x00) 292} 293 294func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { 295 if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil { 296 return err 297 } 298 299 vw.buf = bsoncore.AppendBinary(vw.buf, btype, b) 300 vw.pop() 301 return nil 302} 303 304func (vw *valueWriter) WriteBoolean(b bool) error { 305 if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil { 306 return err 307 } 308 309 vw.buf = bsoncore.AppendBoolean(vw.buf, b) 310 vw.pop() 311 return nil 312} 313 314func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) { 315 if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil { 316 return nil, err 317 } 318 319 // CodeWithScope is a different than other types because we need an extra 320 // frame on the stack. In the EndDocument code, we write the document 321 // length, pop, write the code with scope length, and pop. To simplify the 322 // pop code, we push a spacer frame that we'll always jump over. 323 vw.push(mCodeWithScope) 324 vw.buf = bsoncore.AppendString(vw.buf, code) 325 vw.push(mSpacer) 326 vw.push(mDocument) 327 328 return vw, nil 329} 330 331func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { 332 if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil { 333 return err 334 } 335 336 vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid) 337 vw.pop() 338 return nil 339} 340 341func (vw *valueWriter) WriteDateTime(dt int64) error { 342 if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil { 343 return err 344 } 345 346 vw.buf = bsoncore.AppendDateTime(vw.buf, dt) 347 vw.pop() 348 return nil 349} 350 351func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error { 352 if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil { 353 return err 354 } 355 356 vw.buf = bsoncore.AppendDecimal128(vw.buf, d128) 357 vw.pop() 358 return nil 359} 360 361func (vw *valueWriter) WriteDouble(f float64) error { 362 if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil { 363 return err 364 } 365 366 vw.buf = bsoncore.AppendDouble(vw.buf, f) 367 vw.pop() 368 return nil 369} 370 371func (vw *valueWriter) WriteInt32(i32 int32) error { 372 if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil { 373 return err 374 } 375 376 vw.buf = bsoncore.AppendInt32(vw.buf, i32) 377 vw.pop() 378 return nil 379} 380 381func (vw *valueWriter) WriteInt64(i64 int64) error { 382 if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil { 383 return err 384 } 385 386 vw.buf = bsoncore.AppendInt64(vw.buf, i64) 387 vw.pop() 388 return nil 389} 390 391func (vw *valueWriter) WriteJavascript(code string) error { 392 if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil { 393 return err 394 } 395 396 vw.buf = bsoncore.AppendJavaScript(vw.buf, code) 397 vw.pop() 398 return nil 399} 400 401func (vw *valueWriter) WriteMaxKey() error { 402 if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil { 403 return err 404 } 405 406 vw.pop() 407 return nil 408} 409 410func (vw *valueWriter) WriteMinKey() error { 411 if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil { 412 return err 413 } 414 415 vw.pop() 416 return nil 417} 418 419func (vw *valueWriter) WriteNull() error { 420 if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil { 421 return err 422 } 423 424 vw.pop() 425 return nil 426} 427 428func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error { 429 if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil { 430 return err 431 } 432 433 vw.buf = bsoncore.AppendObjectID(vw.buf, oid) 434 vw.pop() 435 return nil 436} 437 438func (vw *valueWriter) WriteRegex(pattern string, options string) error { 439 if !isValidCString(pattern) || !isValidCString(options) { 440 return errors.New("BSON regex values cannot contain null bytes") 441 } 442 if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil { 443 return err 444 } 445 446 vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options)) 447 vw.pop() 448 return nil 449} 450 451func (vw *valueWriter) WriteString(s string) error { 452 if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil { 453 return err 454 } 455 456 vw.buf = bsoncore.AppendString(vw.buf, s) 457 vw.pop() 458 return nil 459} 460 461func (vw *valueWriter) WriteDocument() (DocumentWriter, error) { 462 if vw.stack[vw.frame].mode == mTopLevel { 463 vw.reserveLength() 464 return vw, nil 465 } 466 if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil { 467 return nil, err 468 } 469 470 vw.push(mDocument) 471 return vw, nil 472} 473 474func (vw *valueWriter) WriteSymbol(symbol string) error { 475 if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil { 476 return err 477 } 478 479 vw.buf = bsoncore.AppendSymbol(vw.buf, symbol) 480 vw.pop() 481 return nil 482} 483 484func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error { 485 if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil { 486 return err 487 } 488 489 vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i) 490 vw.pop() 491 return nil 492} 493 494func (vw *valueWriter) WriteUndefined() error { 495 if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil { 496 return err 497 } 498 499 vw.pop() 500 return nil 501} 502 503func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) { 504 switch vw.stack[vw.frame].mode { 505 case mTopLevel, mDocument: 506 default: 507 return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument}) 508 } 509 510 vw.push(mElement) 511 vw.stack[vw.frame].key = key 512 513 return vw, nil 514} 515 516func (vw *valueWriter) WriteDocumentEnd() error { 517 switch vw.stack[vw.frame].mode { 518 case mTopLevel, mDocument: 519 default: 520 return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode) 521 } 522 523 vw.buf = append(vw.buf, 0x00) 524 525 err := vw.writeLength() 526 if err != nil { 527 return err 528 } 529 530 if vw.stack[vw.frame].mode == mTopLevel { 531 if err = vw.Flush(); err != nil { 532 return err 533 } 534 } 535 536 vw.pop() 537 538 if vw.stack[vw.frame].mode == mCodeWithScope { 539 // We ignore the error here because of the gaurantee of writeLength. 540 // See the docs for writeLength for more info. 541 _ = vw.writeLength() 542 vw.pop() 543 } 544 return nil 545} 546 547func (vw *valueWriter) Flush() error { 548 if vw.w == nil { 549 return nil 550 } 551 552 if sw, ok := vw.w.(*SliceWriter); ok { 553 *sw = vw.buf 554 return nil 555 } 556 if _, err := vw.w.Write(vw.buf); err != nil { 557 return err 558 } 559 // reset buffer 560 vw.buf = vw.buf[:0] 561 return nil 562} 563 564func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) { 565 if vw.stack[vw.frame].mode != mArray { 566 return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray}) 567 } 568 569 arrkey := vw.stack[vw.frame].arrkey 570 vw.stack[vw.frame].arrkey++ 571 572 vw.push(mValue) 573 vw.stack[vw.frame].arrkey = arrkey 574 575 return vw, nil 576} 577 578func (vw *valueWriter) WriteArrayEnd() error { 579 if vw.stack[vw.frame].mode != mArray { 580 return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode) 581 } 582 583 vw.buf = append(vw.buf, 0x00) 584 585 err := vw.writeLength() 586 if err != nil { 587 return err 588 } 589 590 vw.pop() 591 return nil 592} 593 594// NOTE: We assume that if we call writeLength more than once the same function 595// within the same function without altering the vw.buf that this method will 596// not return an error. If this changes ensure that the following methods are 597// updated: 598// 599// - WriteDocumentEnd 600func (vw *valueWriter) writeLength() error { 601 length := len(vw.buf) 602 if length > maxSize { 603 return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))} 604 } 605 length = length - int(vw.stack[vw.frame].start) 606 start := vw.stack[vw.frame].start 607 608 vw.buf[start+0] = byte(length) 609 vw.buf[start+1] = byte(length >> 8) 610 vw.buf[start+2] = byte(length >> 16) 611 vw.buf[start+3] = byte(length >> 24) 612 return nil 613} 614 615func isValidCString(cs string) bool { 616 return !strings.ContainsRune(cs, '\x00') 617} 618