1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mongo 8 9import ( 10 "bytes" 11 "context" 12 "encoding/json" 13 "math" 14 "strings" 15 "testing" 16 17 "github.com/stretchr/testify/require" 18 "go.mongodb.org/mongo-driver/bson" 19 "go.mongodb.org/mongo-driver/internal/testutil/helpers" 20 "go.mongodb.org/mongo-driver/mongo/options" 21 "go.mongodb.org/mongo-driver/mongo/readconcern" 22 "go.mongodb.org/mongo-driver/mongo/writeconcern" 23 "go.mongodb.org/mongo-driver/x/bsonx" 24) 25 26// Various helper functions for crud related operations 27 28// Mutates the client to add options 29func addClientOptions(c *Client, opts map[string]interface{}) { 30 for name, opt := range opts { 31 switch name { 32 case "retryWrites": 33 c.retryWrites = opt.(bool) 34 case "w": 35 switch opt.(type) { 36 case float64: 37 c.writeConcern = writeconcern.New(writeconcern.W(int(opt.(float64)))) 38 case string: 39 c.writeConcern = writeconcern.New(writeconcern.WMajority()) 40 } 41 case "readConcernLevel": 42 c.readConcern = readconcern.New(readconcern.Level(opt.(string))) 43 case "readPreference": 44 c.readPreference = readPrefFromString(opt.(string)) 45 } 46 } 47} 48 49// Mutates the collection to add options 50func addCollectionOptions(c *Collection, opts map[string]interface{}) { 51 for name, opt := range opts { 52 switch name { 53 case "readConcern": 54 c.readConcern = getReadConcern(opt) 55 case "writeConcern": 56 c.writeConcern = getWriteConcern(opt) 57 case "readPreference": 58 c.readPreference = readPrefFromString(opt.(map[string]interface{})["mode"].(string)) 59 } 60 } 61} 62 63func executeCount(sess *sessionImpl, coll *Collection, args map[string]interface{}) (int64, error) { 64 var filter map[string]interface{} 65 opts := options.Count() 66 for name, opt := range args { 67 switch name { 68 case "filter": 69 filter = opt.(map[string]interface{}) 70 case "skip": 71 opts = opts.SetSkip(int64(opt.(float64))) 72 case "limit": 73 opts = opts.SetLimit(int64(opt.(float64))) 74 case "collation": 75 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 76 } 77 } 78 79 if sess != nil { 80 // EXAMPLE: 81 sessCtx := sessionContext{ 82 Context: context.WithValue(ctx, sessionKey{}, sess), 83 Session: sess, 84 } 85 return coll.CountDocuments(sessCtx, filter, opts) 86 } 87 return coll.CountDocuments(ctx, filter, opts) 88} 89 90func executeCountDocuments(sess *sessionImpl, coll *Collection, args map[string]interface{}) (int64, error) { 91 var filter map[string]interface{} 92 opts := options.Count() 93 for name, opt := range args { 94 switch name { 95 case "filter": 96 filter = opt.(map[string]interface{}) 97 case "skip": 98 opts = opts.SetSkip(int64(opt.(float64))) 99 case "limit": 100 opts = opts.SetLimit(int64(opt.(float64))) 101 case "collation": 102 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 103 } 104 } 105 106 if sess != nil { 107 // EXAMPLE: 108 sessCtx := sessionContext{ 109 Context: context.WithValue(ctx, sessionKey{}, sess), 110 Session: sess, 111 } 112 return coll.CountDocuments(sessCtx, filter, opts) 113 } 114 return coll.CountDocuments(ctx, filter, opts) 115} 116 117func executeDistinct(sess *sessionImpl, coll *Collection, args map[string]interface{}) ([]interface{}, error) { 118 var fieldName string 119 var filter map[string]interface{} 120 opts := options.Distinct() 121 for name, opt := range args { 122 switch name { 123 case "filter": 124 filter = opt.(map[string]interface{}) 125 case "fieldName": 126 fieldName = opt.(string) 127 case "collation": 128 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 129 } 130 } 131 132 if sess != nil { 133 sessCtx := sessionContext{ 134 Context: context.WithValue(ctx, sessionKey{}, sess), 135 Session: sess, 136 } 137 return coll.Distinct(sessCtx, fieldName, filter, opts) 138 } 139 return coll.Distinct(ctx, fieldName, filter, opts) 140} 141 142func executeInsertOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*InsertOneResult, error) { 143 document := args["document"].(map[string]interface{}) 144 145 // For some reason, the insertion document is unmarshaled with a float rather than integer, 146 // but the documents that are used to initially populate the collection are unmarshaled 147 // correctly with integers. To ensure that the tests can correctly compare them, we iterate 148 // through the insertion document and change any valid integers stored as floats to actual 149 // integers. 150 replaceFloatsWithInts(document) 151 152 if sess != nil { 153 sessCtx := sessionContext{ 154 Context: context.WithValue(ctx, sessionKey{}, sess), 155 Session: sess, 156 } 157 return coll.InsertOne(sessCtx, document) 158 } 159 return coll.InsertOne(context.Background(), document) 160} 161 162func executeInsertMany(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*InsertManyResult, error) { 163 documents := args["documents"].([]interface{}) 164 165 // For some reason, the insertion documents are unmarshaled with a float rather than 166 // integer, but the documents that are used to initially populate the collection are 167 // unmarshaled correctly with integers. To ensure that the tests can correctly compare 168 // them, we iterate through the insertion documents and change any valid integers stored 169 // as floats to actual integers. 170 for i, doc := range documents { 171 docM := doc.(map[string]interface{}) 172 replaceFloatsWithInts(docM) 173 174 documents[i] = docM 175 } 176 177 if sess != nil { 178 sessCtx := sessionContext{ 179 Context: context.WithValue(ctx, sessionKey{}, sess), 180 Session: sess, 181 } 182 return coll.InsertMany(sessCtx, documents) 183 } 184 return coll.InsertMany(context.Background(), documents) 185} 186 187func executeFind(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*Cursor, error) { 188 opts := options.Find() 189 var filter map[string]interface{} 190 for name, opt := range args { 191 switch name { 192 case "filter": 193 filter = opt.(map[string]interface{}) 194 case "sort": 195 opts = opts.SetSort(opt.(map[string]interface{})) 196 case "skip": 197 opts = opts.SetSkip(int64(opt.(float64))) 198 case "limit": 199 opts = opts.SetLimit(int64(opt.(float64))) 200 case "batchSize": 201 opts = opts.SetBatchSize(int32(opt.(float64))) 202 case "collation": 203 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 204 } 205 } 206 207 if sess != nil { 208 sessCtx := sessionContext{ 209 Context: context.WithValue(ctx, sessionKey{}, sess), 210 Session: sess, 211 } 212 return coll.Find(sessCtx, filter, opts) 213 } 214 return coll.Find(ctx, filter, opts) 215} 216 217func executeFindOneAndDelete(sess *sessionImpl, coll *Collection, args map[string]interface{}) *SingleResult { 218 opts := options.FindOneAndDelete() 219 var filter map[string]interface{} 220 for name, opt := range args { 221 switch name { 222 case "filter": 223 filter = opt.(map[string]interface{}) 224 case "sort": 225 opts = opts.SetSort(opt.(map[string]interface{})) 226 case "projection": 227 opts = opts.SetProjection(opt.(map[string]interface{})) 228 case "collation": 229 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 230 } 231 } 232 233 if sess != nil { 234 sessCtx := sessionContext{ 235 Context: context.WithValue(ctx, sessionKey{}, sess), 236 Session: sess, 237 } 238 return coll.FindOneAndDelete(sessCtx, filter, opts) 239 } 240 return coll.FindOneAndDelete(ctx, filter, opts) 241} 242 243func executeFindOneAndUpdate(sess *sessionImpl, coll *Collection, args map[string]interface{}) *SingleResult { 244 opts := options.FindOneAndUpdate() 245 var filter map[string]interface{} 246 var update map[string]interface{} 247 for name, opt := range args { 248 switch name { 249 case "filter": 250 filter = opt.(map[string]interface{}) 251 case "update": 252 update = opt.(map[string]interface{}) 253 case "arrayFilters": 254 opts = opts.SetArrayFilters(options.ArrayFilters{ 255 Filters: opt.([]interface{}), 256 }) 257 case "sort": 258 opts = opts.SetSort(opt.(map[string]interface{})) 259 case "projection": 260 opts = opts.SetProjection(opt.(map[string]interface{})) 261 case "upsert": 262 opts = opts.SetUpsert(opt.(bool)) 263 case "returnDocument": 264 switch opt.(string) { 265 case "After": 266 opts = opts.SetReturnDocument(options.After) 267 case "Before": 268 opts = opts.SetReturnDocument(options.Before) 269 } 270 case "collation": 271 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 272 } 273 } 274 275 // For some reason, the filter and update documents are unmarshaled with floats 276 // rather than integers, but the documents that are used to initially populate the 277 // collection are unmarshaled correctly with integers. To ensure that the tests can 278 // correctly compare them, we iterate through the filter and replacement documents and 279 // change any valid integers stored as floats to actual integers. 280 replaceFloatsWithInts(filter) 281 replaceFloatsWithInts(update) 282 283 if sess != nil { 284 sessCtx := sessionContext{ 285 Context: context.WithValue(ctx, sessionKey{}, sess), 286 Session: sess, 287 } 288 return coll.FindOneAndUpdate(sessCtx, filter, update, opts) 289 } 290 return coll.FindOneAndUpdate(ctx, filter, update, opts) 291} 292 293func executeFindOneAndReplace(sess *sessionImpl, coll *Collection, args map[string]interface{}) *SingleResult { 294 opts := options.FindOneAndReplace() 295 var filter map[string]interface{} 296 var replacement map[string]interface{} 297 for name, opt := range args { 298 switch name { 299 case "filter": 300 filter = opt.(map[string]interface{}) 301 case "replacement": 302 replacement = opt.(map[string]interface{}) 303 case "sort": 304 opts = opts.SetSort(opt.(map[string]interface{})) 305 case "projection": 306 opts = opts.SetProjection(opt.(map[string]interface{})) 307 case "upsert": 308 opts = opts.SetUpsert(opt.(bool)) 309 case "returnDocument": 310 switch opt.(string) { 311 case "After": 312 opts = opts.SetReturnDocument(options.After) 313 case "Before": 314 opts = opts.SetReturnDocument(options.Before) 315 } 316 case "collation": 317 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 318 } 319 } 320 321 // For some reason, the filter and replacement documents are unmarshaled with floats 322 // rather than integers, but the documents that are used to initially populate the 323 // collection are unmarshaled correctly with integers. To ensure that the tests can 324 // correctly compare them, we iterate through the filter and replacement documents and 325 // change any valid integers stored as floats to actual integers. 326 replaceFloatsWithInts(filter) 327 replaceFloatsWithInts(replacement) 328 329 if sess != nil { 330 sessCtx := sessionContext{ 331 Context: context.WithValue(ctx, sessionKey{}, sess), 332 Session: sess, 333 } 334 return coll.FindOneAndReplace(sessCtx, filter, replacement, opts) 335 } 336 return coll.FindOneAndReplace(ctx, filter, replacement, opts) 337} 338 339func executeDeleteOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*DeleteResult, error) { 340 opts := options.Delete() 341 var filter map[string]interface{} 342 for name, opt := range args { 343 switch name { 344 case "filter": 345 filter = opt.(map[string]interface{}) 346 case "collation": 347 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 348 } 349 } 350 351 // For some reason, the filter document is unmarshaled with floats 352 // rather than integers, but the documents that are used to initially populate the 353 // collection are unmarshaled correctly with integers. To ensure that the tests can 354 // correctly compare them, we iterate through the filter and replacement documents and 355 // change any valid integers stored as floats to actual integers. 356 replaceFloatsWithInts(filter) 357 358 if sess != nil { 359 sessCtx := sessionContext{ 360 Context: context.WithValue(ctx, sessionKey{}, sess), 361 Session: sess, 362 } 363 return coll.DeleteOne(sessCtx, filter, opts) 364 } 365 return coll.DeleteOne(ctx, filter, opts) 366} 367 368func executeDeleteMany(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*DeleteResult, error) { 369 opts := options.Delete() 370 var filter map[string]interface{} 371 for name, opt := range args { 372 switch name { 373 case "filter": 374 filter = opt.(map[string]interface{}) 375 case "collation": 376 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 377 } 378 } 379 380 // For some reason, the filter document is unmarshaled with floats 381 // rather than integers, but the documents that are used to initially populate the 382 // collection are unmarshaled correctly with integers. To ensure that the tests can 383 // correctly compare them, we iterate through the filter and replacement documents and 384 // change any valid integers stored as floats to actual integers. 385 replaceFloatsWithInts(filter) 386 387 if sess != nil { 388 sessCtx := sessionContext{ 389 Context: context.WithValue(ctx, sessionKey{}, sess), 390 Session: sess, 391 } 392 return coll.DeleteMany(sessCtx, filter, opts) 393 } 394 return coll.DeleteMany(ctx, filter, opts) 395} 396 397func executeReplaceOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*UpdateResult, error) { 398 opts := options.Replace() 399 var filter map[string]interface{} 400 var replacement map[string]interface{} 401 for name, opt := range args { 402 switch name { 403 case "filter": 404 filter = opt.(map[string]interface{}) 405 case "replacement": 406 replacement = opt.(map[string]interface{}) 407 case "upsert": 408 opts = opts.SetUpsert(opt.(bool)) 409 case "collation": 410 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 411 } 412 } 413 414 // For some reason, the filter and replacement documents are unmarshaled with floats 415 // rather than integers, but the documents that are used to initially populate the 416 // collection are unmarshaled correctly with integers. To ensure that the tests can 417 // correctly compare them, we iterate through the filter and replacement documents and 418 // change any valid integers stored as floats to actual integers. 419 replaceFloatsWithInts(filter) 420 replaceFloatsWithInts(replacement) 421 422 // TODO temporarily default upsert to false explicitly to make test pass 423 // because we do not send upsert=false by default 424 //opts = opts.SetUpsert(false) 425 if opts.Upsert == nil { 426 opts = opts.SetUpsert(false) 427 } 428 if sess != nil { 429 sessCtx := sessionContext{ 430 Context: context.WithValue(ctx, sessionKey{}, sess), 431 Session: sess, 432 } 433 return coll.ReplaceOne(sessCtx, filter, replacement, opts) 434 } 435 return coll.ReplaceOne(ctx, filter, replacement, opts) 436} 437 438func executeUpdateOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*UpdateResult, error) { 439 opts := options.Update() 440 var filter map[string]interface{} 441 var update map[string]interface{} 442 for name, opt := range args { 443 switch name { 444 case "filter": 445 filter = opt.(map[string]interface{}) 446 case "update": 447 update = opt.(map[string]interface{}) 448 case "arrayFilters": 449 opts = opts.SetArrayFilters(options.ArrayFilters{Filters: opt.([]interface{})}) 450 case "upsert": 451 opts = opts.SetUpsert(opt.(bool)) 452 case "collation": 453 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 454 } 455 } 456 457 // For some reason, the filter and update documents are unmarshaled with floats 458 // rather than integers, but the documents that are used to initially populate the 459 // collection are unmarshaled correctly with integers. To ensure that the tests can 460 // correctly compare them, we iterate through the filter and replacement documents and 461 // change any valid integers stored as floats to actual integers. 462 replaceFloatsWithInts(filter) 463 replaceFloatsWithInts(update) 464 465 if opts.Upsert == nil { 466 opts = opts.SetUpsert(false) 467 } 468 if sess != nil { 469 sessCtx := sessionContext{ 470 Context: context.WithValue(ctx, sessionKey{}, sess), 471 Session: sess, 472 } 473 return coll.UpdateOne(sessCtx, filter, update, opts) 474 } 475 return coll.UpdateOne(ctx, filter, update, opts) 476} 477 478func executeUpdateMany(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*UpdateResult, error) { 479 opts := options.Update() 480 var filter map[string]interface{} 481 var update map[string]interface{} 482 for name, opt := range args { 483 switch name { 484 case "filter": 485 filter = opt.(map[string]interface{}) 486 case "update": 487 update = opt.(map[string]interface{}) 488 case "arrayFilters": 489 opts = opts.SetArrayFilters(options.ArrayFilters{Filters: opt.([]interface{})}) 490 case "upsert": 491 opts = opts.SetUpsert(opt.(bool)) 492 case "collation": 493 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 494 } 495 } 496 497 // For some reason, the filter and update documents are unmarshaled with floats 498 // rather than integers, but the documents that are used to initially populate the 499 // collection are unmarshaled correctly with integers. To ensure that the tests can 500 // correctly compare them, we iterate through the filter and replacement documents and 501 // change any valid integers stored as floats to actual integers. 502 replaceFloatsWithInts(filter) 503 replaceFloatsWithInts(update) 504 505 if opts.Upsert == nil { 506 opts = opts.SetUpsert(false) 507 } 508 if sess != nil { 509 sessCtx := sessionContext{ 510 Context: context.WithValue(ctx, sessionKey{}, sess), 511 Session: sess, 512 } 513 return coll.UpdateMany(sessCtx, filter, update, opts) 514 } 515 return coll.UpdateMany(ctx, filter, update, opts) 516} 517 518func executeAggregate(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*Cursor, error) { 519 var pipeline []interface{} 520 opts := options.Aggregate() 521 for name, opt := range args { 522 switch name { 523 case "pipeline": 524 pipeline = opt.([]interface{}) 525 case "batchSize": 526 opts = opts.SetBatchSize(int32(opt.(float64))) 527 case "collation": 528 opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{}))) 529 } 530 } 531 532 if sess != nil { 533 sessCtx := sessionContext{ 534 Context: context.WithValue(ctx, sessionKey{}, sess), 535 Session: sess, 536 } 537 return coll.Aggregate(sessCtx, pipeline, opts) 538 } 539 return coll.Aggregate(ctx, pipeline, opts) 540} 541 542func executeRunCommand(sess Session, db *Database, argmap map[string]interface{}, args json.RawMessage) *SingleResult { 543 var cmd bsonx.Doc 544 opts := options.RunCmd() 545 for name, opt := range argmap { 546 switch name { 547 case "command": 548 argBytes, err := args.MarshalJSON() 549 if err != nil { 550 return &SingleResult{err: err} 551 } 552 553 var argCmdStruct struct { 554 Cmd json.RawMessage `json:"command"` 555 } 556 err = json.NewDecoder(bytes.NewBuffer(argBytes)).Decode(&argCmdStruct) 557 if err != nil { 558 return &SingleResult{err: err} 559 } 560 561 err = bson.UnmarshalExtJSON(argCmdStruct.Cmd, true, &cmd) 562 if err != nil { 563 return &SingleResult{err: err} 564 } 565 case "readPreference": 566 opts = opts.SetReadPreference(getReadPref(opt)) 567 } 568 } 569 570 if sess != nil { 571 sessCtx := sessionContext{ 572 Context: context.WithValue(ctx, sessionKey{}, sess), 573 Session: sess, 574 } 575 return db.RunCommand(sessCtx, cmd, opts) 576 } 577 return db.RunCommand(ctx, cmd, opts) 578} 579 580func verifyBulkWriteResult(t *testing.T, res *BulkWriteResult, result json.RawMessage) { 581 expectedBytes, err := result.MarshalJSON() 582 require.NoError(t, err) 583 584 var expected BulkWriteResult 585 err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected) 586 require.NoError(t, err) 587 588 require.Equal(t, expected.DeletedCount, res.DeletedCount) 589 require.Equal(t, expected.InsertedCount, res.InsertedCount) 590 require.Equal(t, expected.MatchedCount, res.MatchedCount) 591 require.Equal(t, expected.ModifiedCount, res.ModifiedCount) 592 require.Equal(t, expected.UpsertedCount, res.UpsertedCount) 593 594 // replace floats with ints 595 for opID, upsertID := range expected.UpsertedIDs { 596 if floatID, ok := upsertID.(float64); ok { 597 expected.UpsertedIDs[opID] = int32(floatID) 598 } 599 } 600 601 for operationID, upsertID := range expected.UpsertedIDs { 602 require.Equal(t, upsertID, res.UpsertedIDs[operationID]) 603 } 604} 605 606func verifyInsertOneResult(t *testing.T, res *InsertOneResult, result json.RawMessage) { 607 expectedBytes, err := result.MarshalJSON() 608 require.NoError(t, err) 609 610 var expected InsertOneResult 611 err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected) 612 require.NoError(t, err) 613 614 expectedID := expected.InsertedID 615 if f, ok := expectedID.(float64); ok && f == math.Floor(f) { 616 expectedID = int32(f) 617 } 618 619 if expectedID != nil { 620 require.NotNil(t, res) 621 require.Equal(t, expectedID, res.InsertedID) 622 } 623} 624 625func verifyInsertManyResult(t *testing.T, res *InsertManyResult, result json.RawMessage) { 626 expectedBytes, err := result.MarshalJSON() 627 require.NoError(t, err) 628 629 var expected struct{ InsertedIds map[string]interface{} } 630 err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected) 631 require.NoError(t, err) 632 633 if expected.InsertedIds != nil { 634 replaceFloatsWithInts(expected.InsertedIds) 635 636 for _, val := range expected.InsertedIds { 637 require.Contains(t, res.InsertedIDs, val) 638 } 639 } 640} 641 642func verifyCursorResult2(t *testing.T, cur *Cursor, result json.RawMessage) { 643 for _, expected := range docSliceFromRaw(t, result) { 644 require.NotNil(t, cur) 645 require.True(t, cur.Next(context.Background())) 646 647 var actual bsonx.Doc 648 require.NoError(t, cur.Decode(&actual)) 649 650 compareDocs(t, expected, actual) 651 } 652 653 require.False(t, cur.Next(ctx)) 654 require.NoError(t, cur.Err()) 655} 656 657func verifyCursorResult(t *testing.T, cur *Cursor, result json.RawMessage) { 658 for _, expected := range docSliceFromRaw(t, result) { 659 require.NotNil(t, cur) 660 require.True(t, cur.Next(context.Background())) 661 662 var actual bsonx.Doc 663 require.NoError(t, cur.Decode(&actual)) 664 665 compareDocs(t, expected, actual) 666 } 667 668 require.False(t, cur.Next(ctx)) 669 require.NoError(t, cur.Err()) 670} 671 672func verifySingleResult(t *testing.T, res *SingleResult, result json.RawMessage) { 673 jsonBytes, err := result.MarshalJSON() 674 require.NoError(t, err) 675 676 var actual bsonx.Doc 677 err = res.Decode(&actual) 678 if err == ErrNoDocuments { 679 var expected map[string]interface{} 680 err := json.NewDecoder(bytes.NewBuffer(jsonBytes)).Decode(&expected) 681 require.NoError(t, err) 682 683 require.Nil(t, expected) 684 return 685 } 686 687 require.NoError(t, err) 688 689 doc := bsonx.Doc{} 690 err = bson.UnmarshalExtJSON(jsonBytes, true, &doc) 691 require.NoError(t, err) 692 693 require.True(t, doc.Equal(actual)) 694} 695 696func verifyDistinctResult(t *testing.T, res []interface{}, result json.RawMessage) { 697 resultBytes, err := result.MarshalJSON() 698 require.NoError(t, err) 699 700 var expected []interface{} 701 require.NoError(t, json.NewDecoder(bytes.NewBuffer(resultBytes)).Decode(&expected)) 702 703 require.Equal(t, len(expected), len(res)) 704 705 for i := range expected { 706 expectedElem := expected[i] 707 actualElem := res[i] 708 709 iExpected := testhelpers.GetIntFromInterface(expectedElem) 710 iActual := testhelpers.GetIntFromInterface(actualElem) 711 712 require.Equal(t, iExpected == nil, iActual == nil) 713 if iExpected != nil { 714 require.Equal(t, *iExpected, *iActual) 715 continue 716 } 717 718 require.Equal(t, expected[i], res[i]) 719 } 720} 721 722func verifyDeleteResult(t *testing.T, res *DeleteResult, result json.RawMessage) { 723 expectedBytes, err := result.MarshalJSON() 724 require.NoError(t, err) 725 726 var expected DeleteResult 727 err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected) 728 require.NoError(t, err) 729 730 require.Equal(t, expected.DeletedCount, res.DeletedCount) 731} 732 733func verifyUpdateResult(t *testing.T, res *UpdateResult, result json.RawMessage) { 734 expectedBytes, err := result.MarshalJSON() 735 require.NoError(t, err) 736 737 var expected struct { 738 MatchedCount int64 739 ModifiedCount int64 740 UpsertedCount int64 741 } 742 err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected) 743 744 require.Equal(t, expected.MatchedCount, res.MatchedCount) 745 require.Equal(t, expected.ModifiedCount, res.ModifiedCount) 746 747 actualUpsertedCount := int64(0) 748 if res.UpsertedID != nil { 749 actualUpsertedCount = 1 750 } 751 752 require.Equal(t, expected.UpsertedCount, actualUpsertedCount) 753} 754 755func verifyRunCommandResult(t *testing.T, res bson.Raw, result json.RawMessage) { 756 if len(result) == 0 { 757 return 758 } 759 jsonBytes, err := result.MarshalJSON() 760 require.NoError(t, err) 761 762 expected := bsonx.Doc{} 763 err = bson.UnmarshalExtJSON(jsonBytes, true, &expected) 764 require.NoError(t, err) 765 766 require.NotNil(t, res) 767 actual, err := bsonx.ReadDoc(res) 768 require.NoError(t, err) 769 770 // All runcommand results in tests are for key "n" only 771 compareElements(t, expected.LookupElement("n"), actual.LookupElement("n")) 772} 773 774func verifyCollectionContents(t *testing.T, coll *Collection, result json.RawMessage) { 775 cursor, err := coll.Find(context.Background(), bsonx.Doc{}) 776 require.NoError(t, err) 777 778 verifyCursorResult(t, cursor, result) 779} 780 781func sanitizeCollectionName(kind string, name string) string { 782 // Collections can't have "$" in their names, so we substitute it with "%". 783 name = strings.Replace(name, "$", "%", -1) 784 785 // Namespaces can only have 120 bytes max. 786 if len(kind+"."+name) >= 119 { 787 name = name[:119-len(kind+".")] 788 } 789 790 return name 791} 792 793func compareElements(t *testing.T, expected bsonx.Elem, actual bsonx.Elem) { 794 if expected.Value.IsNumber() { 795 if expectedNum, ok := expected.Value.Int64OK(); ok { 796 switch actual.Value.Type() { 797 case bson.TypeInt32: 798 require.Equal(t, expectedNum, int64(actual.Value.Int32()), "For key %v", expected.Key) 799 case bson.TypeInt64: 800 require.Equal(t, expectedNum, actual.Value.Int64(), "For key %v\n", expected.Key) 801 case bson.TypeDouble: 802 require.Equal(t, expectedNum, int64(actual.Value.Double()), "For key %v\n", expected.Key) 803 } 804 } else { 805 expectedNum := expected.Value.Int32() 806 switch actual.Value.Type() { 807 case bson.TypeInt32: 808 require.Equal(t, expectedNum, actual.Value.Int32(), "For key %v", expected.Key) 809 case bson.TypeInt64: 810 require.Equal(t, expectedNum, int32(actual.Value.Int64()), "For key %v\n", expected.Key) 811 case bson.TypeDouble: 812 require.Equal(t, expectedNum, int32(actual.Value.Double()), "For key %v\n", expected.Key) 813 } 814 } 815 } else if conv, ok := expected.Value.DocumentOK(); ok { 816 actualConv, actualOk := actual.Value.DocumentOK() 817 require.True(t, actualOk) 818 compareDocs(t, conv, actualConv) 819 } else if conv, ok := expected.Value.ArrayOK(); ok { 820 actualConv, actualOk := actual.Value.ArrayOK() 821 require.True(t, actualOk) 822 compareArrays(t, conv, actualConv) 823 } else { 824 require.True(t, actual.Equal(expected), "For key %s, expected %v\nactual: %v", expected.Key, expected, actual) 825 } 826} 827 828func compareArrays(t *testing.T, expected bsonx.Arr, actual bsonx.Arr) { 829 if len(expected) != len(actual) { 830 t.Errorf("array length mismatch. expected %d got %d", len(expected), len(actual)) 831 t.FailNow() 832 } 833 834 for idx := range expected { 835 expectedDoc := expected[idx].Document() 836 actualDoc := actual[idx].Document() 837 compareDocs(t, expectedDoc, actualDoc) 838 } 839} 840 841func collationFromMap(m map[string]interface{}) *options.Collation { 842 var collation options.Collation 843 844 if locale, found := m["locale"]; found { 845 collation.Locale = locale.(string) 846 } 847 848 if caseLevel, found := m["caseLevel"]; found { 849 collation.CaseLevel = caseLevel.(bool) 850 } 851 852 if caseFirst, found := m["caseFirst"]; found { 853 collation.CaseFirst = caseFirst.(string) 854 } 855 856 if strength, found := m["strength"]; found { 857 collation.Strength = int(strength.(float64)) 858 } 859 860 if numericOrdering, found := m["numericOrdering"]; found { 861 collation.NumericOrdering = numericOrdering.(bool) 862 } 863 864 if alternate, found := m["alternate"]; found { 865 collation.Alternate = alternate.(string) 866 } 867 868 if maxVariable, found := m["maxVariable"]; found { 869 collation.MaxVariable = maxVariable.(string) 870 } 871 872 if normalization, found := m["normalization"]; found { 873 collation.Normalization = normalization.(bool) 874 } 875 876 if backwards, found := m["backwards"]; found { 877 collation.Backwards = backwards.(bool) 878 } 879 880 return &collation 881} 882 883func docSliceFromRaw(t *testing.T, raw json.RawMessage) []bsonx.Doc { 884 jsonBytes, err := raw.MarshalJSON() 885 require.NoError(t, err) 886 887 array := bsonx.Arr{} 888 err = bson.UnmarshalExtJSON(jsonBytes, true, &array) 889 require.NoError(t, err) 890 891 docs := make([]bsonx.Doc, 0) 892 893 for _, val := range array { 894 docs = append(docs, val.Document()) 895 } 896 897 return docs 898} 899 900func docSliceToInterfaceSlice(docs []bsonx.Doc) []interface{} { 901 out := make([]interface{}, 0, len(docs)) 902 903 for _, doc := range docs { 904 out = append(out, doc) 905 } 906 907 return out 908} 909 910func replaceFloatsWithInts(m map[string]interface{}) { 911 for key, val := range m { 912 if f, ok := val.(float64); ok && f == math.Floor(f) { 913 m[key] = int32(f) 914 continue 915 } 916 917 if innerM, ok := val.(map[string]interface{}); ok { 918 replaceFloatsWithInts(innerM) 919 m[key] = innerM 920 } 921 } 922} 923