1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package integration 8 9import ( 10 "fmt" 11 "io/ioutil" 12 "math" 13 "os" 14 "path" 15 "strings" 16 "testing" 17 "time" 18 19 "go.mongodb.org/mongo-driver/bson" 20 "go.mongodb.org/mongo-driver/internal/testutil/assert" 21 "go.mongodb.org/mongo-driver/mongo" 22 "go.mongodb.org/mongo-driver/mongo/options" 23 "go.mongodb.org/mongo-driver/mongo/readconcern" 24 "go.mongodb.org/mongo-driver/mongo/readpref" 25 "go.mongodb.org/mongo-driver/mongo/writeconcern" 26) 27 28var ( 29 awsAccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID") 30 awsSecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY") 31 awsTempAccessKeyID = os.Getenv("CSFLE_AWS_TEMP_ACCESS_KEY_ID") 32 awsTempSecretAccessKey = os.Getenv("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY") 33 awsTempSessionToken = os.Getenv("CSFLE_AWS_TEMP_SESSION_TOKEN") 34 azureTenantID = os.Getenv("AZURE_TENANT_ID") 35 azureClientID = os.Getenv("AZURE_CLIENT_ID") 36 azureClientSecret = os.Getenv("AZURE_CLIENT_SECRET") 37 gcpEmail = os.Getenv("GCP_EMAIL") 38 gcpPrivateKey = os.Getenv("GCP_PRIVATE_KEY") 39) 40 41// Helper functions to do read JSON spec test files and convert JSON objects into the appropriate driver types. 42// Functions in this file should take testing.TB rather than testing.T/mtest.T for generality because they 43// do not do any database communication. 44 45// generate a slice of all JSON file names in a directory 46func jsonFilesInDir(t testing.TB, dir string) []string { 47 t.Helper() 48 49 files := make([]string, 0) 50 51 entries, err := ioutil.ReadDir(dir) 52 assert.Nil(t, err, "unable to read json file: %v", err) 53 54 for _, entry := range entries { 55 if entry.IsDir() || path.Ext(entry.Name()) != ".json" { 56 continue 57 } 58 59 files = append(files, entry.Name()) 60 } 61 62 return files 63} 64 65// create client options from a map 66func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions { 67 t.Helper() 68 69 clientOpts := options.Client() 70 elems, _ := opts.Elements() 71 for _, elem := range elems { 72 name := elem.Key() 73 opt := elem.Value() 74 75 switch name { 76 case "retryWrites": 77 clientOpts.SetRetryWrites(opt.Boolean()) 78 case "w": 79 switch opt.Type { 80 case bson.TypeInt32: 81 w := int(opt.Int32()) 82 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w))) 83 case bson.TypeDouble: 84 w := int(opt.Double()) 85 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w))) 86 case bson.TypeString: 87 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.WMajority())) 88 default: 89 t.Fatalf("unrecognized type for w client option: %v", opt.Type) 90 } 91 case "readConcernLevel": 92 clientOpts.SetReadConcern(readconcern.New(readconcern.Level(opt.StringValue()))) 93 case "readPreference": 94 clientOpts.SetReadPreference(readPrefFromString(opt.StringValue())) 95 case "heartbeatFrequencyMS": 96 hf := convertValueToMilliseconds(t, opt) 97 clientOpts.SetHeartbeatInterval(hf) 98 case "retryReads": 99 clientOpts.SetRetryReads(opt.Boolean()) 100 case "autoEncryptOpts": 101 clientOpts.SetAutoEncryptionOptions(createAutoEncryptionOptions(t, opt.Document())) 102 case "appname": 103 clientOpts.SetAppName(opt.StringValue()) 104 case "connectTimeoutMS": 105 ct := convertValueToMilliseconds(t, opt) 106 clientOpts.SetConnectTimeout(ct) 107 case "serverSelectionTimeoutMS": 108 sst := convertValueToMilliseconds(t, opt) 109 clientOpts.SetServerSelectionTimeout(sst) 110 default: 111 t.Fatalf("unrecognized client option: %v", name) 112 } 113 } 114 115 return clientOpts 116} 117 118func createAutoEncryptionOptions(t testing.TB, opts bson.Raw) *options.AutoEncryptionOptions { 119 t.Helper() 120 121 aeo := options.AutoEncryption() 122 var kvnsFound bool 123 elems, _ := opts.Elements() 124 125 for _, elem := range elems { 126 name := elem.Key() 127 opt := elem.Value() 128 129 switch name { 130 case "kmsProviders": 131 aeo.SetKmsProviders(createKmsProvidersMap(t, opt.Document())) 132 case "schemaMap": 133 var schemaMap map[string]interface{} 134 err := bson.Unmarshal(opt.Document(), &schemaMap) 135 if err != nil { 136 t.Fatalf("error creating schema map: %v", err) 137 } 138 139 aeo.SetSchemaMap(schemaMap) 140 case "keyVaultNamespace": 141 kvnsFound = true 142 aeo.SetKeyVaultNamespace(opt.StringValue()) 143 case "bypassAutoEncryption": 144 aeo.SetBypassAutoEncryption(opt.Boolean()) 145 default: 146 t.Fatalf("unrecognized auto encryption option: %v", name) 147 } 148 } 149 if !kvnsFound { 150 aeo.SetKeyVaultNamespace("keyvault.datakeys") 151 } 152 153 return aeo 154} 155 156func createKmsProvidersMap(t testing.TB, opts bson.Raw) map[string]map[string]interface{} { 157 t.Helper() 158 159 kmsMap := make(map[string]map[string]interface{}) 160 elems, _ := opts.Elements() 161 162 for _, elem := range elems { 163 provider := elem.Key() 164 providerOpt := elem.Value() 165 166 switch provider { 167 case "aws": 168 awsMap := map[string]interface{}{ 169 "accessKeyId": awsAccessKeyID, 170 "secretAccessKey": awsSecretAccessKey, 171 } 172 kmsMap["aws"] = awsMap 173 case "azure": 174 kmsMap["azure"] = map[string]interface{}{ 175 "tenantId": azureTenantID, 176 "clientId": azureClientID, 177 "clientSecret": azureClientSecret, 178 } 179 case "gcp": 180 kmsMap["gcp"] = map[string]interface{}{ 181 "email": gcpEmail, 182 "privateKey": gcpPrivateKey, 183 } 184 case "local": 185 _, key := providerOpt.Document().Lookup("key").Binary() 186 localMap := map[string]interface{}{ 187 "key": key, 188 } 189 kmsMap["local"] = localMap 190 case "awsTemporary": 191 if awsTempAccessKeyID == "" { 192 t.Fatal("AWS temp access key ID not set") 193 } 194 if awsTempSecretAccessKey == "" { 195 t.Fatal("AWS temp secret access key not set") 196 } 197 if awsTempSessionToken == "" { 198 t.Fatal("AWS temp session token not set") 199 } 200 awsMap := map[string]interface{}{ 201 "accessKeyId": awsTempAccessKeyID, 202 "secretAccessKey": awsTempSecretAccessKey, 203 "sessionToken": awsTempSessionToken, 204 } 205 kmsMap["aws"] = awsMap 206 case "awsTemporaryNoSessionToken": 207 if awsTempAccessKeyID == "" { 208 t.Fatal("AWS temp access key ID not set") 209 } 210 if awsTempSecretAccessKey == "" { 211 t.Fatal("AWS temp secret access key not set") 212 } 213 awsMap := map[string]interface{}{ 214 "accessKeyId": awsTempAccessKeyID, 215 "secretAccessKey": awsTempSecretAccessKey, 216 } 217 kmsMap["aws"] = awsMap 218 default: 219 t.Fatalf("unrecognized KMS provider: %v", provider) 220 } 221 } 222 223 return kmsMap 224} 225 226// create session options from a map 227func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions { 228 t.Helper() 229 230 sessOpts := options.Session() 231 elems, _ := opts.Elements() 232 for _, elem := range elems { 233 name := elem.Key() 234 opt := elem.Value() 235 236 switch name { 237 case "causalConsistency": 238 sessOpts = sessOpts.SetCausalConsistency(opt.Boolean()) 239 case "defaultTransactionOptions": 240 txnOpts := createTransactionOptions(t, opt.Document()) 241 if txnOpts.ReadConcern != nil { 242 sessOpts.SetDefaultReadConcern(txnOpts.ReadConcern) 243 } 244 if txnOpts.ReadPreference != nil { 245 sessOpts.SetDefaultReadPreference(txnOpts.ReadPreference) 246 } 247 if txnOpts.WriteConcern != nil { 248 sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern) 249 } 250 if txnOpts.MaxCommitTime != nil { 251 sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime) 252 } 253 default: 254 t.Fatalf("unrecognized session option: %v", name) 255 } 256 } 257 258 return sessOpts 259} 260 261// create database options from a BSON document. 262func createDatabaseOptions(t testing.TB, opts bson.Raw) *options.DatabaseOptions { 263 t.Helper() 264 265 do := options.Database() 266 elems, _ := opts.Elements() 267 for _, elem := range elems { 268 name := elem.Key() 269 opt := elem.Value() 270 271 switch name { 272 case "readConcern": 273 do.SetReadConcern(createReadConcern(opt)) 274 case "writeConcern": 275 do.SetWriteConcern(createWriteConcern(t, opt)) 276 default: 277 t.Fatalf("unrecognized database option: %v", name) 278 } 279 } 280 281 return do 282} 283 284// create collection options from a map 285func createCollectionOptions(t testing.TB, opts bson.Raw) *options.CollectionOptions { 286 t.Helper() 287 288 co := options.Collection() 289 elems, _ := opts.Elements() 290 for _, elem := range elems { 291 name := elem.Key() 292 opt := elem.Value() 293 294 switch name { 295 case "readConcern": 296 co.SetReadConcern(createReadConcern(opt)) 297 case "writeConcern": 298 co.SetWriteConcern(createWriteConcern(t, opt)) 299 case "readPreference": 300 co.SetReadPreference(createReadPref(opt)) 301 default: 302 t.Fatalf("unrecognized collection option: %v", name) 303 } 304 } 305 306 return co 307} 308 309// create transaction options from a map 310func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionOptions { 311 t.Helper() 312 313 txnOpts := options.Transaction() 314 elems, _ := opts.Elements() 315 for _, elem := range elems { 316 name := elem.Key() 317 opt := elem.Value() 318 319 switch name { 320 case "writeConcern": 321 txnOpts.SetWriteConcern(createWriteConcern(t, opt)) 322 case "readPreference": 323 txnOpts.SetReadPreference(createReadPref(opt)) 324 case "readConcern": 325 txnOpts.SetReadConcern(createReadConcern(opt)) 326 case "maxCommitTimeMS": 327 t := time.Duration(opt.Int32()) * time.Millisecond 328 txnOpts.SetMaxCommitTime(&t) 329 default: 330 t.Fatalf("unrecognized transaction option: %v", opt) 331 } 332 } 333 return txnOpts 334} 335 336// create a read concern from a map 337func createReadConcern(opt bson.RawValue) *readconcern.ReadConcern { 338 return readconcern.New(readconcern.Level(opt.Document().Lookup("level").StringValue())) 339} 340 341// create a read concern from a map 342func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConcern { 343 wcDoc, ok := opt.DocumentOK() 344 if !ok { 345 return nil 346 } 347 348 var opts []writeconcern.Option 349 elems, _ := wcDoc.Elements() 350 for _, elem := range elems { 351 key := elem.Key() 352 val := elem.Value() 353 354 switch key { 355 case "wtimeout": 356 wtimeout := convertValueToMilliseconds(t, val) 357 opts = append(opts, writeconcern.WTimeout(wtimeout)) 358 case "j": 359 opts = append(opts, writeconcern.J(val.Boolean())) 360 case "w": 361 switch val.Type { 362 case bson.TypeString: 363 if val.StringValue() != "majority" { 364 break 365 } 366 opts = append(opts, writeconcern.WMajority()) 367 case bson.TypeInt32: 368 w := int(val.Int32()) 369 opts = append(opts, writeconcern.W(w)) 370 default: 371 t.Fatalf("unrecognized type for w: %v", val.Type) 372 } 373 default: 374 t.Fatalf("unrecognized write concern option: %v", key) 375 } 376 } 377 return writeconcern.New(opts...) 378} 379 380// create a read preference from a string. 381// returns readpref.Primary() if the string doesn't match any known read preference modes. 382func readPrefFromString(s string) *readpref.ReadPref { 383 switch strings.ToLower(s) { 384 case "primary": 385 return readpref.Primary() 386 case "primarypreferred": 387 return readpref.PrimaryPreferred() 388 case "secondary": 389 return readpref.Secondary() 390 case "secondarypreferred": 391 return readpref.SecondaryPreferred() 392 case "nearest": 393 return readpref.Nearest() 394 } 395 return readpref.Primary() 396} 397 398// create a read preference from a map. 399func createReadPref(opt bson.RawValue) *readpref.ReadPref { 400 mode := opt.Document().Lookup("mode").StringValue() 401 return readPrefFromString(mode) 402} 403 404// transform a slice of BSON documents to a slice of interface{}. 405func rawSliceToInterfaceSlice(docs []bson.Raw) []interface{} { 406 out := make([]interface{}, len(docs)) 407 408 for i, doc := range docs { 409 out[i] = doc 410 } 411 412 return out 413} 414 415// transform a BSON raw array to a slice of interface{}. 416func rawArrayToInterfaceSlice(docs bson.Raw) []interface{} { 417 vals, _ := docs.Values() 418 419 out := make([]interface{}, len(vals)) 420 for i, val := range vals { 421 out[i] = val.Document() 422 } 423 424 return out 425} 426 427// retrieve the error associated with a result. 428func errorFromResult(t testing.TB, result interface{}) *operationError { 429 t.Helper() 430 431 // embedded doc will be unmarshalled as Raw 432 raw, ok := result.(bson.Raw) 433 if !ok { 434 return nil 435 } 436 437 var expected operationError 438 err := bson.Unmarshal(raw, &expected) 439 if err != nil { 440 return nil 441 } 442 if expected.ErrorCodeName == nil && expected.ErrorContains == nil && len(expected.ErrorLabelsOmit) == 0 && 443 len(expected.ErrorLabelsContain) == 0 { 444 return nil 445 } 446 447 return &expected 448} 449 450// errorDetails is a helper type that holds information that can be returned by driver functions in different error 451// types. 452type errorDetails struct { 453 name string 454 labels []string 455} 456 457// extractErrorDetails creates an errorDetails instance based on the provided error. It returns the details and an "ok" 458// value which is true if the provided error is of a known type that can be processed. 459func extractErrorDetails(err error) (errorDetails, bool) { 460 var details errorDetails 461 462 switch converted := err.(type) { 463 case mongo.CommandError: 464 details.name = converted.Name 465 details.labels = converted.Labels 466 case mongo.WriteException: 467 if converted.WriteConcernError != nil { 468 details.name = converted.WriteConcernError.Name 469 } 470 details.labels = converted.Labels 471 case mongo.BulkWriteException: 472 if converted.WriteConcernError != nil { 473 details.name = converted.WriteConcernError.Name 474 } 475 details.labels = converted.Labels 476 default: 477 return errorDetails{}, false 478 } 479 480 return details, true 481} 482 483// verify that an error returned by an operation matches the expected error. 484func verifyError(expected *operationError, actual error) error { 485 // The spec test format doesn't treat ErrNoDocuments or ErrUnacknowledgedWrite as errors, so set actual to nil 486 // to indicate that no error occurred. 487 if actual == mongo.ErrNoDocuments || actual == mongo.ErrUnacknowledgedWrite { 488 actual = nil 489 } 490 491 if expected == nil && actual != nil { 492 return fmt.Errorf("did not expect error but got %v", actual) 493 } 494 if expected != nil && actual == nil { 495 return fmt.Errorf("expected error but got nil") 496 } 497 if expected == nil { 498 return nil 499 } 500 501 // check ErrorContains for all error types 502 if expected.ErrorContains != nil { 503 emsg := strings.ToLower(*expected.ErrorContains) 504 amsg := strings.ToLower(actual.Error()) 505 if !strings.Contains(amsg, emsg) { 506 return fmt.Errorf("expected error message %q to contain %q", amsg, emsg) 507 } 508 } 509 510 // Get an errorDetails instance for the error. If this fails but the test has expectations about the error name or 511 // labels, fail because we can't verify them. 512 details, ok := extractErrorDetails(actual) 513 if !ok { 514 if expected.ErrorCodeName != nil || len(expected.ErrorLabelsContain) > 0 || len(expected.ErrorLabelsOmit) > 0 { 515 return fmt.Errorf("failed to extract details from error %v of type %T", actual, actual) 516 } 517 return nil 518 } 519 520 if expected.ErrorCodeName != nil { 521 if *expected.ErrorCodeName != details.name { 522 return fmt.Errorf("expected error name %v, got %v", *expected.ErrorCodeName, details.name) 523 } 524 } 525 for _, label := range expected.ErrorLabelsContain { 526 if !stringSliceContains(details.labels, label) { 527 return fmt.Errorf("expected error %v to contain label %q", actual, label) 528 } 529 } 530 for _, label := range expected.ErrorLabelsOmit { 531 if stringSliceContains(details.labels, label) { 532 return fmt.Errorf("expected error %v to not contain label %q", actual, label) 533 } 534 } 535 return nil 536} 537 538// get the underlying value of i as an int64. returns nil if i is not an int, int32, or int64 type. 539func getIntFromInterface(i interface{}) *int64 { 540 var out int64 541 542 switch v := i.(type) { 543 case int: 544 out = int64(v) 545 case int32: 546 out = int64(v) 547 case int64: 548 out = v 549 case float32: 550 f := float64(v) 551 if math.Floor(f) != f || f > float64(math.MaxInt64) { 552 break 553 } 554 555 out = int64(f) 556 case float64: 557 if math.Floor(v) != v || v > float64(math.MaxInt64) { 558 break 559 } 560 561 out = int64(v) 562 default: 563 return nil 564 } 565 566 return &out 567} 568 569func createCollation(t testing.TB, m bson.Raw) *options.Collation { 570 var collation options.Collation 571 elems, _ := m.Elements() 572 573 for _, elem := range elems { 574 switch elem.Key() { 575 case "locale": 576 collation.Locale = elem.Value().StringValue() 577 case "caseLevel": 578 collation.CaseLevel = elem.Value().Boolean() 579 case "caseFirst": 580 collation.CaseFirst = elem.Value().StringValue() 581 case "strength": 582 collation.Strength = int(elem.Value().Int32()) 583 case "numericOrdering": 584 collation.NumericOrdering = elem.Value().Boolean() 585 case "alternate": 586 collation.Alternate = elem.Value().StringValue() 587 case "maxVariable": 588 collation.MaxVariable = elem.Value().StringValue() 589 case "normalization": 590 collation.Normalization = elem.Value().Boolean() 591 case "backwards": 592 collation.Backwards = elem.Value().Boolean() 593 default: 594 t.Fatalf("unrecognized collation option: %v", elem.Key()) 595 } 596 } 597 return &collation 598} 599 600func createChangeStreamOptions(t testing.TB, opts bson.Raw) *options.ChangeStreamOptions { 601 t.Helper() 602 603 csOpts := options.ChangeStream() 604 elems, _ := opts.Elements() 605 for _, elem := range elems { 606 key := elem.Key() 607 opt := elem.Value() 608 609 switch key { 610 case "batchSize": 611 csOpts.SetBatchSize(opt.Int32()) 612 default: 613 t.Fatalf("unrecognized change stream option: %v", key) 614 } 615 } 616 return csOpts 617} 618 619func convertValueToMilliseconds(t testing.TB, val bson.RawValue) time.Duration { 620 t.Helper() 621 622 int32Val, ok := val.Int32OK() 623 if !ok { 624 t.Fatalf("failed to convert value of type %s to int32", val.Type) 625 } 626 return time.Duration(int32Val) * time.Millisecond 627} 628 629func stringSliceContains(stringSlice []string, target string) bool { 630 for _, str := range stringSlice { 631 if str == target { 632 return true 633 } 634 } 635 return false 636} 637