1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package command // import "go.mongodb.org/mongo-driver/x/network/command" 8 9import ( 10 "errors" 11 12 "context" 13 14 "fmt" 15 16 "go.mongodb.org/mongo-driver/bson" 17 "go.mongodb.org/mongo-driver/bson/bsontype" 18 "go.mongodb.org/mongo-driver/bson/primitive" 19 "go.mongodb.org/mongo-driver/mongo/readconcern" 20 "go.mongodb.org/mongo-driver/mongo/writeconcern" 21 "go.mongodb.org/mongo-driver/x/bsonx" 22 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 23 "go.mongodb.org/mongo-driver/x/mongo/driver/session" 24 "go.mongodb.org/mongo-driver/x/network/description" 25 "go.mongodb.org/mongo-driver/x/network/result" 26 "go.mongodb.org/mongo-driver/x/network/wiremessage" 27) 28 29// WriteBatch represents a single batch for a write operation. 30type WriteBatch struct { 31 *Write 32 numDocs int 33} 34 35// DecodeError attempts to decode the wiremessage as an error 36func DecodeError(wm wiremessage.WireMessage) error { 37 var rdr bson.Raw 38 switch msg := wm.(type) { 39 case wiremessage.Msg: 40 for _, section := range msg.Sections { 41 switch converted := section.(type) { 42 case wiremessage.SectionBody: 43 rdr = converted.Document 44 } 45 } 46 case wiremessage.Reply: 47 if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure { 48 return nil 49 } 50 rdr = msg.Documents[0] 51 } 52 53 err := rdr.Validate() 54 if err != nil { 55 return nil 56 } 57 58 extractedError := extractError(rdr) 59 60 // If parsed successfully return the error 61 if _, ok := extractedError.(Error); ok { 62 return err 63 } 64 65 return nil 66} 67 68// helper method to extract an error from a reader if there is one; first returned item is the 69// error if it exists, the second holds parsing errors 70func extractError(rdr bson.Raw) error { 71 var errmsg, codeName string 72 var code int32 73 var labels []string 74 elems, err := rdr.Elements() 75 if err != nil { 76 return err 77 } 78 79 for _, elem := range elems { 80 switch elem.Key() { 81 case "ok": 82 switch elem.Value().Type { 83 case bson.TypeInt32: 84 if elem.Value().Int32() == 1 { 85 return nil 86 } 87 case bson.TypeInt64: 88 if elem.Value().Int64() == 1 { 89 return nil 90 } 91 case bson.TypeDouble: 92 if elem.Value().Double() == 1 { 93 return nil 94 } 95 } 96 case "errmsg": 97 if str, okay := elem.Value().StringValueOK(); okay { 98 errmsg = str 99 } 100 case "codeName": 101 if str, okay := elem.Value().StringValueOK(); okay { 102 codeName = str 103 } 104 case "code": 105 if c, okay := elem.Value().Int32OK(); okay { 106 code = c 107 } 108 case "errorLabels": 109 if arr, okay := elem.Value().ArrayOK(); okay { 110 elems, err := arr.Elements() 111 if err != nil { 112 continue 113 } 114 for _, elem := range elems { 115 if str, ok := elem.Value().StringValueOK(); ok { 116 labels = append(labels, str) 117 } 118 } 119 120 } 121 } 122 } 123 124 if errmsg == "" { 125 errmsg = "command failed" 126 } 127 128 return Error{ 129 Code: code, 130 Message: errmsg, 131 Name: codeName, 132 Labels: labels, 133 } 134} 135 136func responseClusterTime(response bson.Raw) bson.Raw { 137 clusterTime, err := response.LookupErr("$clusterTime") 138 if err != nil { 139 // $clusterTime not included by the server 140 return nil 141 } 142 idx, doc := bsoncore.AppendDocumentStart(nil) 143 doc = bsoncore.AppendHeader(doc, clusterTime.Type, "$clusterTime") 144 doc = append(doc, clusterTime.Value...) 145 doc, _ = bsoncore.AppendDocumentEnd(doc, idx) 146 return doc 147} 148 149func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error { 150 clusterTime := responseClusterTime(response) 151 if clusterTime == nil { 152 return nil 153 } 154 155 if sess != nil { 156 err := sess.AdvanceClusterTime(clusterTime) 157 if err != nil { 158 return err 159 } 160 } 161 162 if clock != nil { 163 clock.AdvanceClusterTime(clusterTime) 164 } 165 166 return nil 167} 168 169func updateOperationTime(sess *session.Client, response bson.Raw) error { 170 if sess == nil { 171 return nil 172 } 173 174 opTimeElem, err := response.LookupErr("operationTime") 175 if err != nil { 176 // operationTime not included by the server 177 return nil 178 } 179 180 t, i := opTimeElem.Timestamp() 181 return sess.AdvanceOperationTime(&primitive.Timestamp{ 182 T: t, 183 I: i, 184 }) 185} 186 187func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) { 188 if cmd == nil { 189 return bson.Raw{5, 0, 0, 0, 0}, nil 190 } 191 192 return cmd.MarshalBSON() 193} 194 195// adds session related fields to a BSON doc representing a command 196func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) { 197 if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 { 198 return cmd, nil 199 } 200 201 if client.Terminated { 202 return cmd, session.ErrSessionEnded 203 } 204 205 if _, err := cmd.LookupElementErr("lsid"); err != nil { 206 cmd = cmd.Delete("lsid") 207 } 208 209 cmd = append(cmd, bsonx.Elem{"lsid", bsonx.Document(client.SessionID)}) 210 211 if client.TransactionRunning() || 212 client.RetryingCommit { 213 cmd = addTransaction(cmd, client) 214 } 215 216 client.ApplyCommand() // advance the state machine based on a command executing 217 218 return cmd, nil 219} 220 221// if in a transaction, add the transaction fields 222func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc { 223 cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(client.TxnNumber)}) 224 if client.TransactionStarting() { 225 // When starting transaction, always transition to the next state, even on error 226 cmd = append(cmd, bsonx.Elem{"startTransaction", bsonx.Boolean(true)}) 227 } 228 return append(cmd, bsonx.Elem{"autocommit", bsonx.Boolean(false)}) 229} 230 231func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc { 232 if (clock == nil && sess == nil) || !description.SessionsSupported(desc.WireVersion) { 233 return cmd 234 } 235 236 var clusterTime bson.Raw 237 if clock != nil { 238 clusterTime = clock.GetClusterTime() 239 } 240 241 if sess != nil { 242 if clusterTime == nil { 243 clusterTime = sess.ClusterTime 244 } else { 245 clusterTime = session.MaxClusterTime(clusterTime, sess.ClusterTime) 246 } 247 } 248 249 if clusterTime == nil { 250 return cmd 251 } 252 253 d, err := bsonx.ReadDoc(clusterTime) 254 if err != nil { 255 return cmd // broken clusterTime 256 } 257 258 cmd = cmd.Delete("$clusterTime") 259 260 return append(cmd, d...) 261} 262 263// add a read concern to a BSON doc representing a command 264func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) { 265 // Starting transaction's read concern overrides all others 266 if sess != nil && sess.TransactionStarting() && sess.CurrentRc != nil { 267 rc = sess.CurrentRc 268 } 269 270 // start transaction must append afterclustertime IF causally consistent and operation time exists 271 if rc == nil && sess != nil && sess.TransactionStarting() && sess.Consistent && sess.OperationTime != nil { 272 rc = readconcern.New() 273 } 274 275 if rc == nil { 276 return cmd, nil 277 } 278 279 t, data, err := rc.MarshalBSONValue() 280 if err != nil { 281 return cmd, err 282 } 283 284 var rcDoc bsonx.Doc 285 err = rcDoc.UnmarshalBSONValue(t, data) 286 if err != nil { 287 return cmd, err 288 } 289 if description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil { 290 rcDoc = append(rcDoc, bsonx.Elem{"afterClusterTime", bsonx.Timestamp(sess.OperationTime.T, sess.OperationTime.I)}) 291 } 292 293 cmd = cmd.Delete("readConcern") 294 295 if len(rcDoc) != 0 { 296 cmd = append(cmd, bsonx.Elem{"readConcern", bsonx.Document(rcDoc)}) 297 } 298 return cmd, nil 299} 300 301// add a write concern to a BSON doc representing a command 302func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) { 303 if wc == nil { 304 return cmd, nil 305 } 306 307 t, data, err := wc.MarshalBSONValue() 308 if err != nil { 309 if err == writeconcern.ErrEmptyWriteConcern { 310 return cmd, nil 311 } 312 return cmd, err 313 } 314 315 var xval bsonx.Val 316 err = xval.UnmarshalBSONValue(t, data) 317 if err != nil { 318 return cmd, err 319 } 320 321 // delete if doc already has write concern 322 cmd = cmd.Delete("writeConcern") 323 324 return append(cmd, bsonx.Elem{Key: "writeConcern", Value: xval}), nil 325} 326 327// Get the error labels from a command response 328func getErrorLabels(rdr *bson.Raw) ([]string, error) { 329 var labels []string 330 labelsElem, err := rdr.LookupErr("errorLabels") 331 if err != bsoncore.ErrElementNotFound { 332 return nil, err 333 } 334 if labelsElem.Type == bsontype.Array { 335 labelsIt, err := labelsElem.Array().Elements() 336 if err != nil { 337 return nil, err 338 } 339 for _, elem := range labelsIt { 340 labels = append(labels, elem.Value().StringValue()) 341 } 342 } 343 return labels, nil 344} 345 346// Remove command arguments for insert, update, and delete commands from the BSON document so they can be encoded 347// as a Section 1 payload in OP_MSG 348func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) { 349 var array bsonx.Arr 350 var id string 351 352 keys := []string{"documents", "updates", "deletes"} 353 354 for _, key := range keys { 355 val, err := cmd.LookupErr(key) 356 if err != nil { 357 continue 358 } 359 360 array = val.Array() 361 cmd = cmd.Delete(key) 362 id = key 363 break 364 } 365 366 return cmd, array, id 367} 368 369// Add the $db and $readPreference keys to the command 370// If the command has no read preference, pass nil for rpDoc 371func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) { 372 cmd = append(cmd, bsonx.Elem{"$db", bsonx.String(dbName)}) 373 if rpDoc != nil { 374 cmd = append(cmd, bsonx.Elem{"$readPreference", bsonx.Document(rpDoc)}) 375 } 376 377 return cmd.MarshalBSON() // bsonx.Doc.MarshalBSON never returns an error. 378} 379 380func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) { 381 docSequence := wiremessage.SectionDocumentSequence{ 382 PayloadType: wiremessage.DocumentSequence, 383 Identifier: identifier, 384 Documents: make([]bson.Raw, 0, len(arr)), 385 } 386 387 for _, val := range arr { 388 d, _ := val.Document().MarshalBSON() 389 docSequence.Documents = append(docSequence.Documents, d) 390 } 391 392 docSequence.Size = int32(docSequence.PayloadLen()) 393 return docSequence, nil 394} 395 396func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) { 397 batches := [][]bsonx.Doc{} 398 399 if targetBatchSize > reservedCommandBufferBytes { 400 targetBatchSize -= reservedCommandBufferBytes 401 } 402 403 if maxCount <= 0 { 404 maxCount = 1 405 } 406 407 startAt := 0 408splitInserts: 409 for { 410 size := 0 411 batch := []bsonx.Doc{} 412 assembleBatch: 413 for idx := startAt; idx < len(docs); idx++ { 414 raw, _ := docs[idx].MarshalBSON() 415 416 if len(raw) > targetBatchSize { 417 return nil, ErrDocumentTooLarge 418 } 419 if size+len(raw) > targetBatchSize { 420 break assembleBatch 421 } 422 423 size += len(raw) 424 batch = append(batch, docs[idx]) 425 startAt++ 426 if len(batch) == maxCount { 427 break assembleBatch 428 } 429 } 430 batches = append(batches, batch) 431 if startAt == len(docs) { 432 break splitInserts 433 } 434 } 435 436 return batches, nil 437} 438 439func encodeBatch( 440 docs []bsonx.Doc, 441 opts []bsonx.Elem, 442 cmdKind WriteCommandKind, 443 collName string, 444) (bsonx.Doc, error) { 445 var cmdName string 446 var docString string 447 448 switch cmdKind { 449 case InsertCommand: 450 cmdName = "insert" 451 docString = "documents" 452 case UpdateCommand: 453 cmdName = "update" 454 docString = "updates" 455 case DeleteCommand: 456 cmdName = "delete" 457 docString = "deletes" 458 } 459 460 cmd := bsonx.Doc{{cmdName, bsonx.String(collName)}} 461 462 vals := make(bsonx.Arr, 0, len(docs)) 463 for _, doc := range docs { 464 vals = append(vals, bsonx.Document(doc)) 465 } 466 cmd = append(cmd, bsonx.Elem{docString, bsonx.Array(vals)}) 467 cmd = append(cmd, opts...) 468 469 return cmd, nil 470} 471 472// converts batches of Write Commands to wire messages 473func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) { 474 wms := make([]wiremessage.WireMessage, len(batches)) 475 for _, cmd := range batches { 476 wm, err := cmd.Encode(desc) 477 if err != nil { 478 return nil, err 479 } 480 481 wms = append(wms, wm) 482 } 483 484 return wms, nil 485} 486 487// Roundtrips the write batches, returning the result structs (as interface), 488// the write batches that weren't round tripped and any errors 489func roundTripBatches( 490 ctx context.Context, 491 desc description.SelectedServer, 492 rw wiremessage.ReadWriter, 493 batches []*WriteBatch, 494 continueOnError bool, 495 sess *session.Client, 496 cmdKind WriteCommandKind, 497) (interface{}, []*WriteBatch, error) { 498 var res interface{} 499 var upsertIndex int64 // the operation index for the upserted IDs map 500 501 // hold onto txnNumber, reset it when loop exits to ensure reuse of same 502 // transaction number if retry is needed 503 var txnNumber int64 504 if sess != nil && sess.RetryWrite { 505 txnNumber = sess.TxnNumber 506 } 507 for j, cmd := range batches { 508 rdr, err := cmd.RoundTrip(ctx, desc, rw) 509 if err != nil { 510 if sess != nil && sess.RetryWrite { 511 sess.TxnNumber = txnNumber + int64(j) 512 } 513 return res, batches, err 514 } 515 516 // TODO can probably DRY up this code 517 switch cmdKind { 518 case InsertCommand: 519 if res == nil { 520 res = result.Insert{} 521 } 522 523 conv, _ := res.(result.Insert) 524 insertCmd := &Insert{} 525 r, err := insertCmd.decode(desc, rdr).Result() 526 if err != nil { 527 return res, batches, err 528 } 529 530 conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...) 531 532 if r.WriteConcernError != nil { 533 conv.WriteConcernError = r.WriteConcernError 534 if sess != nil && sess.RetryWrite { 535 sess.TxnNumber = txnNumber 536 return conv, batches, nil // report writeconcernerror for retry 537 } 538 } 539 540 conv.N += r.N 541 542 if !continueOnError && len(conv.WriteErrors) > 0 { 543 return conv, batches, nil 544 } 545 546 res = conv 547 case UpdateCommand: 548 if res == nil { 549 res = result.Update{} 550 } 551 552 conv, _ := res.(result.Update) 553 updateCmd := &Update{} 554 r, err := updateCmd.decode(desc, rdr).Result() 555 if err != nil { 556 return conv, batches, err 557 } 558 559 conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...) 560 561 if r.WriteConcernError != nil { 562 conv.WriteConcernError = r.WriteConcernError 563 if sess != nil && sess.RetryWrite { 564 sess.TxnNumber = txnNumber 565 return conv, batches, nil // report writeconcernerror for retry 566 } 567 } 568 569 conv.MatchedCount += r.MatchedCount 570 conv.ModifiedCount += r.ModifiedCount 571 for _, upsert := range r.Upserted { 572 conv.Upserted = append(conv.Upserted, result.Upsert{ 573 Index: upsert.Index + upsertIndex, 574 ID: upsert.ID, 575 }) 576 } 577 578 if !continueOnError && len(conv.WriteErrors) > 0 { 579 return conv, batches, nil 580 } 581 582 res = conv 583 upsertIndex += int64(cmd.numDocs) 584 case DeleteCommand: 585 if res == nil { 586 res = result.Delete{} 587 } 588 589 conv, _ := res.(result.Delete) 590 deleteCmd := &Delete{} 591 r, err := deleteCmd.decode(desc, rdr).Result() 592 if err != nil { 593 return conv, batches, err 594 } 595 596 conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...) 597 598 if r.WriteConcernError != nil { 599 conv.WriteConcernError = r.WriteConcernError 600 if sess != nil && sess.RetryWrite { 601 sess.TxnNumber = txnNumber 602 return conv, batches, nil // report writeconcernerror for retry 603 } 604 } 605 606 conv.N += r.N 607 608 if !continueOnError && len(conv.WriteErrors) > 0 { 609 return conv, batches, nil 610 } 611 612 res = conv 613 } 614 615 // Increment txnNumber for each batch 616 if sess != nil && sess.RetryWrite { 617 sess.IncrementTxnNumber() 618 batches = batches[1:] // if batch encoded successfully, remove it from the slice 619 } 620 } 621 622 if sess != nil && sess.RetryWrite { 623 // if retryable write succeeded, transaction number will be incremented one extra time, 624 // so we decrement it here 625 sess.TxnNumber-- 626 } 627 628 return res, batches, nil 629} 630 631// get the firstBatch, cursor ID, and namespace from a bson.Raw 632func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) { 633 cur, err := result.LookupErr("cursor") 634 if err != nil { 635 return nil, Namespace{}, 0, err 636 } 637 if cur.Type != bson.TypeEmbeddedDocument { 638 return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type) 639 } 640 641 elems, err := cur.Document().Elements() 642 if err != nil { 643 return nil, Namespace{}, 0, err 644 } 645 646 var ok bool 647 var arr bson.Raw 648 var namespace Namespace 649 var cursorID int64 650 651 for _, elem := range elems { 652 switch elem.Key() { 653 case "firstBatch": 654 arr, ok = elem.Value().ArrayOK() 655 if !ok { 656 return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type) 657 } 658 if err != nil { 659 return nil, Namespace{}, 0, err 660 } 661 case "ns": 662 if elem.Value().Type != bson.TypeString { 663 return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type) 664 } 665 namespace = ParseNamespace(elem.Value().StringValue()) 666 err = namespace.Validate() 667 if err != nil { 668 return nil, Namespace{}, 0, err 669 } 670 case "id": 671 cursorID, ok = elem.Value().Int64OK() 672 if !ok { 673 return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type) 674 } 675 } 676 } 677 678 vals, err := arr.Values() 679 if err != nil { 680 return nil, Namespace{}, 0, err 681 } 682 683 return vals, namespace, cursorID, nil 684} 685 686func getBatchSize(opts []bsonx.Elem) int32 { 687 for _, opt := range opts { 688 if opt.Key == "batchSize" { 689 return opt.Value.Int32() 690 } 691 } 692 693 return 0 694} 695 696// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged 697// write concern. 698var ErrUnacknowledgedWrite = errors.New("unacknowledged write") 699 700// WriteCommandKind is the type of command represented by a Write 701type WriteCommandKind int8 702 703// These constants represent the valid types of write commands. 704const ( 705 InsertCommand WriteCommandKind = iota 706 UpdateCommand 707 DeleteCommand 708) 709