// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package integration

import (
	"fmt"
	"io/ioutil"
	"math"
	"os"
	"path"
	"strings"
	"testing"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/internal/testutil/assert"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readconcern"
	"go.mongodb.org/mongo-driver/mongo/readpref"
	"go.mongodb.org/mongo-driver/mongo/writeconcern"
)

// Names of the environment variables that supply credentials for the "aws" KMS provider. createKmsProvidersMap
// reads them with os.Getenv and fails the test if either one is unset.
const (
	awsAccessKeyID     = "AWS_ACCESS_KEY_ID"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
)

// Helper functions to read JSON spec test files and convert JSON objects into the appropriate driver types.
// Functions in this file should take testing.TB rather than testing.T/mtest.T for generality because they
// do not do any database communication.
36 37// generate a slice of all JSON file names in a directory 38func jsonFilesInDir(t testing.TB, dir string) []string { 39 t.Helper() 40 41 files := make([]string, 0) 42 43 entries, err := ioutil.ReadDir(dir) 44 assert.Nil(t, err, "unable to read json file: %v", err) 45 46 for _, entry := range entries { 47 if entry.IsDir() || path.Ext(entry.Name()) != ".json" { 48 continue 49 } 50 51 files = append(files, entry.Name()) 52 } 53 54 return files 55} 56 57// create client options from a map 58func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions { 59 t.Helper() 60 61 clientOpts := options.Client() 62 elems, _ := opts.Elements() 63 for _, elem := range elems { 64 name := elem.Key() 65 opt := elem.Value() 66 67 switch name { 68 case "retryWrites": 69 clientOpts.SetRetryWrites(opt.Boolean()) 70 case "w": 71 switch opt.Type { 72 case bson.TypeInt32: 73 w := int(opt.Int32()) 74 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w))) 75 case bson.TypeDouble: 76 w := int(opt.Double()) 77 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w))) 78 case bson.TypeString: 79 clientOpts.SetWriteConcern(writeconcern.New(writeconcern.WMajority())) 80 default: 81 t.Fatalf("unrecognized type for w client option: %v", opt.Type) 82 } 83 case "readConcernLevel": 84 clientOpts.SetReadConcern(readconcern.New(readconcern.Level(opt.StringValue()))) 85 case "readPreference": 86 clientOpts.SetReadPreference(readPrefFromString(opt.StringValue())) 87 case "heartbeatFrequencyMS": 88 hf := convertValueToMilliseconds(t, opt) 89 clientOpts.SetHeartbeatInterval(hf) 90 case "retryReads": 91 clientOpts.SetRetryReads(opt.Boolean()) 92 case "autoEncryptOpts": 93 clientOpts.SetAutoEncryptionOptions(createAutoEncryptionOptions(t, opt.Document())) 94 case "appname": 95 clientOpts.SetAppName(opt.StringValue()) 96 case "connectTimeoutMS": 97 ct := convertValueToMilliseconds(t, opt) 98 clientOpts.SetConnectTimeout(ct) 99 case "serverSelectionTimeoutMS": 100 sst := 
convertValueToMilliseconds(t, opt) 101 clientOpts.SetServerSelectionTimeout(sst) 102 default: 103 t.Fatalf("unrecognized client option: %v", name) 104 } 105 } 106 107 return clientOpts 108} 109 110func createAutoEncryptionOptions(t testing.TB, opts bson.Raw) *options.AutoEncryptionOptions { 111 t.Helper() 112 113 aeo := options.AutoEncryption() 114 var kvnsFound bool 115 elems, _ := opts.Elements() 116 117 for _, elem := range elems { 118 name := elem.Key() 119 opt := elem.Value() 120 121 switch name { 122 case "kmsProviders": 123 aeo.SetKmsProviders(createKmsProvidersMap(t, opt.Document())) 124 case "schemaMap": 125 var schemaMap map[string]interface{} 126 err := bson.Unmarshal(opt.Document(), &schemaMap) 127 if err != nil { 128 t.Fatalf("error creating schema map: %v", err) 129 } 130 131 aeo.SetSchemaMap(schemaMap) 132 case "keyVaultNamespace": 133 kvnsFound = true 134 aeo.SetKeyVaultNamespace(opt.StringValue()) 135 case "bypassAutoEncryption": 136 aeo.SetBypassAutoEncryption(opt.Boolean()) 137 default: 138 t.Fatalf("unrecognized auto encryption option: %v", name) 139 } 140 } 141 if !kvnsFound { 142 aeo.SetKeyVaultNamespace("keyvault.datakeys") 143 } 144 145 return aeo 146} 147 148func createKmsProvidersMap(t testing.TB, opts bson.Raw) map[string]map[string]interface{} { 149 t.Helper() 150 151 // aws: value is always empty object. create new map value from access key ID and secret access key 152 // local: value is {"key": primitive.Binary}. 
transform to {"key": []byte} 153 154 kmsMap := make(map[string]map[string]interface{}) 155 elems, _ := opts.Elements() 156 157 for _, elem := range elems { 158 provider := elem.Key() 159 providerOpt := elem.Value() 160 161 switch provider { 162 case "aws": 163 keyID := os.Getenv(awsAccessKeyID) 164 if keyID == "" { 165 t.Fatalf("%s env var not set", awsAccessKeyID) 166 } 167 secretAccessKey := os.Getenv(awsSecretAccessKey) 168 if secretAccessKey == "" { 169 t.Fatalf("%s env var not set", awsSecretAccessKey) 170 } 171 172 awsMap := map[string]interface{}{ 173 "accessKeyId": keyID, 174 "secretAccessKey": secretAccessKey, 175 } 176 kmsMap["aws"] = awsMap 177 case "local": 178 _, key := providerOpt.Document().Lookup("key").Binary() 179 localMap := map[string]interface{}{ 180 "key": key, 181 } 182 kmsMap["local"] = localMap 183 default: 184 t.Fatalf("unrecognized KMS provider: %v", provider) 185 } 186 } 187 188 return kmsMap 189} 190 191// create session options from a map 192func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions { 193 t.Helper() 194 195 sessOpts := options.Session() 196 elems, _ := opts.Elements() 197 for _, elem := range elems { 198 name := elem.Key() 199 opt := elem.Value() 200 201 switch name { 202 case "causalConsistency": 203 sessOpts = sessOpts.SetCausalConsistency(opt.Boolean()) 204 case "defaultTransactionOptions": 205 txnOpts := createTransactionOptions(t, opt.Document()) 206 if txnOpts.ReadConcern != nil { 207 sessOpts.SetDefaultReadConcern(txnOpts.ReadConcern) 208 } 209 if txnOpts.ReadPreference != nil { 210 sessOpts.SetDefaultReadPreference(txnOpts.ReadPreference) 211 } 212 if txnOpts.WriteConcern != nil { 213 sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern) 214 } 215 if txnOpts.MaxCommitTime != nil { 216 sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime) 217 } 218 default: 219 t.Fatalf("unrecognized session option: %v", name) 220 } 221 } 222 223 return sessOpts 224} 225 226// create database options from a 
BSON document. 227func createDatabaseOptions(t testing.TB, opts bson.Raw) *options.DatabaseOptions { 228 t.Helper() 229 230 do := options.Database() 231 elems, _ := opts.Elements() 232 for _, elem := range elems { 233 name := elem.Key() 234 opt := elem.Value() 235 236 switch name { 237 case "readConcern": 238 do.SetReadConcern(createReadConcern(opt)) 239 case "writeConcern": 240 do.SetWriteConcern(createWriteConcern(t, opt)) 241 default: 242 t.Fatalf("unrecognized database option: %v", name) 243 } 244 } 245 246 return do 247} 248 249// create collection options from a map 250func createCollectionOptions(t testing.TB, opts bson.Raw) *options.CollectionOptions { 251 t.Helper() 252 253 co := options.Collection() 254 elems, _ := opts.Elements() 255 for _, elem := range elems { 256 name := elem.Key() 257 opt := elem.Value() 258 259 switch name { 260 case "readConcern": 261 co.SetReadConcern(createReadConcern(opt)) 262 case "writeConcern": 263 co.SetWriteConcern(createWriteConcern(t, opt)) 264 case "readPreference": 265 co.SetReadPreference(createReadPref(opt)) 266 default: 267 t.Fatalf("unrecognized collection option: %v", name) 268 } 269 } 270 271 return co 272} 273 274// create transaction options from a map 275func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionOptions { 276 t.Helper() 277 278 txnOpts := options.Transaction() 279 elems, _ := opts.Elements() 280 for _, elem := range elems { 281 name := elem.Key() 282 opt := elem.Value() 283 284 switch name { 285 case "writeConcern": 286 txnOpts.SetWriteConcern(createWriteConcern(t, opt)) 287 case "readPreference": 288 txnOpts.SetReadPreference(createReadPref(opt)) 289 case "readConcern": 290 txnOpts.SetReadConcern(createReadConcern(opt)) 291 case "maxCommitTimeMS": 292 t := time.Duration(opt.Int32()) * time.Millisecond 293 txnOpts.SetMaxCommitTime(&t) 294 default: 295 t.Fatalf("unrecognized transaction option: %v", opt) 296 } 297 } 298 return txnOpts 299} 300 301// create a read concern from 
a map 302func createReadConcern(opt bson.RawValue) *readconcern.ReadConcern { 303 return readconcern.New(readconcern.Level(opt.Document().Lookup("level").StringValue())) 304} 305 306// create a read concern from a map 307func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConcern { 308 wcDoc, ok := opt.DocumentOK() 309 if !ok { 310 return nil 311 } 312 313 var opts []writeconcern.Option 314 elems, _ := wcDoc.Elements() 315 for _, elem := range elems { 316 key := elem.Key() 317 val := elem.Value() 318 319 switch key { 320 case "wtimeout": 321 wtimeout := convertValueToMilliseconds(t, val) 322 opts = append(opts, writeconcern.WTimeout(wtimeout)) 323 case "j": 324 opts = append(opts, writeconcern.J(val.Boolean())) 325 case "w": 326 switch val.Type { 327 case bson.TypeString: 328 if val.StringValue() != "majority" { 329 break 330 } 331 opts = append(opts, writeconcern.WMajority()) 332 case bson.TypeInt32: 333 w := int(val.Int32()) 334 opts = append(opts, writeconcern.W(w)) 335 default: 336 t.Fatalf("unrecognized type for w: %v", val.Type) 337 } 338 default: 339 t.Fatalf("unrecognized write concern option: %v", key) 340 } 341 } 342 return writeconcern.New(opts...) 343} 344 345// create a read preference from a string. 346// returns readpref.Primary() if the string doesn't match any known read preference modes. 347func readPrefFromString(s string) *readpref.ReadPref { 348 switch strings.ToLower(s) { 349 case "primary": 350 return readpref.Primary() 351 case "primarypreferred": 352 return readpref.PrimaryPreferred() 353 case "secondary": 354 return readpref.Secondary() 355 case "secondarypreferred": 356 return readpref.SecondaryPreferred() 357 case "nearest": 358 return readpref.Nearest() 359 } 360 return readpref.Primary() 361} 362 363// create a read preference from a map. 
364func createReadPref(opt bson.RawValue) *readpref.ReadPref { 365 mode := opt.Document().Lookup("mode").StringValue() 366 return readPrefFromString(mode) 367} 368 369// transform a slice of BSON documents to a slice of interface{}. 370func rawSliceToInterfaceSlice(docs []bson.Raw) []interface{} { 371 out := make([]interface{}, len(docs)) 372 373 for i, doc := range docs { 374 out[i] = doc 375 } 376 377 return out 378} 379 380// transform a BSON raw array to a slice of interface{}. 381func rawArrayToInterfaceSlice(docs bson.Raw) []interface{} { 382 vals, _ := docs.Values() 383 384 out := make([]interface{}, len(vals)) 385 for i, val := range vals { 386 out[i] = val.Document() 387 } 388 389 return out 390} 391 392// retrieve the error associated with a result. 393func errorFromResult(t testing.TB, result interface{}) *operationError { 394 t.Helper() 395 396 // embedded doc will be unmarshalled as Raw 397 raw, ok := result.(bson.Raw) 398 if !ok { 399 return nil 400 } 401 402 var expected operationError 403 err := bson.Unmarshal(raw, &expected) 404 if err != nil { 405 return nil 406 } 407 if expected.ErrorCodeName == nil && expected.ErrorContains == nil && len(expected.ErrorLabelsOmit) == 0 && 408 len(expected.ErrorLabelsContain) == 0 { 409 return nil 410 } 411 412 return &expected 413} 414 415// errorDetails is a helper type that holds information that can be returned by driver functions in different error 416// types. 417type errorDetails struct { 418 name string 419 labels []string 420} 421 422// extractErrorDetails creates an errorDetails instance based on the provided error. It returns the details and an "ok" 423// value which is true if the provided error is of a known type that can be processed. 
424func extractErrorDetails(err error) (errorDetails, bool) { 425 var details errorDetails 426 427 switch converted := err.(type) { 428 case mongo.CommandError: 429 details.name = converted.Name 430 details.labels = converted.Labels 431 case mongo.WriteException: 432 if converted.WriteConcernError != nil { 433 details.name = converted.WriteConcernError.Name 434 } 435 details.labels = converted.Labels 436 case mongo.BulkWriteException: 437 if converted.WriteConcernError != nil { 438 details.name = converted.WriteConcernError.Name 439 } 440 details.labels = converted.Labels 441 default: 442 return errorDetails{}, false 443 } 444 445 return details, true 446} 447 448// verify that an error returned by an operation matches the expected error. 449func verifyError(expected *operationError, actual error) error { 450 // The spec test format doesn't treat ErrNoDocuments or ErrUnacknowledgedWrite as errors, so set actual to nil 451 // to indicate that no error occurred. 452 if actual == mongo.ErrNoDocuments || actual == mongo.ErrUnacknowledgedWrite { 453 actual = nil 454 } 455 456 if expected == nil && actual != nil { 457 return fmt.Errorf("did not expect error but got %v", actual) 458 } 459 if expected != nil && actual == nil { 460 return fmt.Errorf("expected error but got nil") 461 } 462 if expected == nil { 463 return nil 464 } 465 466 // check ErrorContains for all error types 467 if expected.ErrorContains != nil { 468 emsg := strings.ToLower(*expected.ErrorContains) 469 amsg := strings.ToLower(actual.Error()) 470 if !strings.Contains(amsg, emsg) { 471 return fmt.Errorf("expected error message %q to contain %q", amsg, emsg) 472 } 473 } 474 475 // Get an errorDetails instance for the error. If this fails but the test has expectations about the error name or 476 // labels, fail because we can't verify them. 
477 details, ok := extractErrorDetails(actual) 478 if !ok { 479 if expected.ErrorCodeName != nil || len(expected.ErrorLabelsContain) > 0 || len(expected.ErrorLabelsOmit) > 0 { 480 return fmt.Errorf("failed to extract details from error %v of type %T", actual, actual) 481 } 482 return nil 483 } 484 485 if expected.ErrorCodeName != nil { 486 if *expected.ErrorCodeName != details.name { 487 return fmt.Errorf("expected error name %v, got %v", *expected.ErrorCodeName, details.name) 488 } 489 } 490 for _, label := range expected.ErrorLabelsContain { 491 if !stringSliceContains(details.labels, label) { 492 return fmt.Errorf("expected error %v to contain label %q", actual, label) 493 } 494 } 495 for _, label := range expected.ErrorLabelsOmit { 496 if stringSliceContains(details.labels, label) { 497 return fmt.Errorf("expected error %v to not contain label %q", actual, label) 498 } 499 } 500 return nil 501} 502 503// get the underlying value of i as an int64. returns nil if i is not an int, int32, or int64 type. 
504func getIntFromInterface(i interface{}) *int64 { 505 var out int64 506 507 switch v := i.(type) { 508 case int: 509 out = int64(v) 510 case int32: 511 out = int64(v) 512 case int64: 513 out = v 514 case float32: 515 f := float64(v) 516 if math.Floor(f) != f || f > float64(math.MaxInt64) { 517 break 518 } 519 520 out = int64(f) 521 case float64: 522 if math.Floor(v) != v || v > float64(math.MaxInt64) { 523 break 524 } 525 526 out = int64(v) 527 default: 528 return nil 529 } 530 531 return &out 532} 533 534func createCollation(t testing.TB, m bson.Raw) *options.Collation { 535 var collation options.Collation 536 elems, _ := m.Elements() 537 538 for _, elem := range elems { 539 switch elem.Key() { 540 case "locale": 541 collation.Locale = elem.Value().StringValue() 542 case "caseLevel": 543 collation.CaseLevel = elem.Value().Boolean() 544 case "caseFirst": 545 collation.CaseFirst = elem.Value().StringValue() 546 case "strength": 547 collation.Strength = int(elem.Value().Int32()) 548 case "numericOrdering": 549 collation.NumericOrdering = elem.Value().Boolean() 550 case "alternate": 551 collation.Alternate = elem.Value().StringValue() 552 case "maxVariable": 553 collation.MaxVariable = elem.Value().StringValue() 554 case "normalization": 555 collation.Normalization = elem.Value().Boolean() 556 case "backwards": 557 collation.Backwards = elem.Value().Boolean() 558 default: 559 t.Fatalf("unrecognized collation option: %v", elem.Key()) 560 } 561 } 562 return &collation 563} 564 565func createChangeStreamOptions(t testing.TB, opts bson.Raw) *options.ChangeStreamOptions { 566 t.Helper() 567 568 csOpts := options.ChangeStream() 569 elems, _ := opts.Elements() 570 for _, elem := range elems { 571 key := elem.Key() 572 opt := elem.Value() 573 574 switch key { 575 case "batchSize": 576 csOpts.SetBatchSize(opt.Int32()) 577 default: 578 t.Fatalf("unrecognized change stream option: %v", key) 579 } 580 } 581 return csOpts 582} 583 584func convertValueToMilliseconds(t testing.TB, 
val bson.RawValue) time.Duration { 585 t.Helper() 586 587 int32Val, ok := val.Int32OK() 588 if !ok { 589 t.Fatalf("failed to convert value of type %s to int32", val.Type) 590 } 591 return time.Duration(int32Val) * time.Millisecond 592} 593 594func stringSliceContains(stringSlice []string, target string) bool { 595 for _, str := range stringSlice { 596 if str == target { 597 return true 598 } 599 } 600 return false 601} 602