1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package integration
8
9import (
10	"fmt"
11	"io/ioutil"
12	"math"
13	"os"
14	"path"
15	"strings"
16	"testing"
17	"time"
18
19	"go.mongodb.org/mongo-driver/bson"
20	"go.mongodb.org/mongo-driver/internal/testutil/assert"
21	"go.mongodb.org/mongo-driver/mongo"
22	"go.mongodb.org/mongo-driver/mongo/options"
23	"go.mongodb.org/mongo-driver/mongo/readconcern"
24	"go.mongodb.org/mongo-driver/mongo/readpref"
25	"go.mongodb.org/mongo-driver/mongo/writeconcern"
26)
27
// Credentials for the KMS providers configured by createKmsProvidersMap. They are read from the environment when the
// package is initialized; an unset variable yields the empty string.
var (
	// Long-lived AWS credentials, used for the "aws" KMS provider.
	awsAccessKeyID     = os.Getenv("AWS_ACCESS_KEY_ID")
	awsSecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY")

	// Temporary AWS credentials, used for the "awsTemporary" and "awsTemporaryNoSessionToken" KMS providers.
	awsTempAccessKeyID     = os.Getenv("CSFLE_AWS_TEMP_ACCESS_KEY_ID")
	awsTempSecretAccessKey = os.Getenv("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY")
	awsTempSessionToken    = os.Getenv("CSFLE_AWS_TEMP_SESSION_TOKEN")

	// Azure credentials, used for the "azure" KMS provider.
	azureTenantID     = os.Getenv("AZURE_TENANT_ID")
	azureClientID     = os.Getenv("AZURE_CLIENT_ID")
	azureClientSecret = os.Getenv("AZURE_CLIENT_SECRET")

	// GCP credentials, used for the "gcp" KMS provider.
	gcpEmail      = os.Getenv("GCP_EMAIL")
	gcpPrivateKey = os.Getenv("GCP_PRIVATE_KEY")
)
40
// Helper functions to read JSON spec test files and convert JSON objects into the appropriate driver types.
// Functions in this file take testing.TB rather than testing.T/mtest.T for generality, because they do not
// perform any database communication.
44
45// generate a slice of all JSON file names in a directory
46func jsonFilesInDir(t testing.TB, dir string) []string {
47	t.Helper()
48
49	files := make([]string, 0)
50
51	entries, err := ioutil.ReadDir(dir)
52	assert.Nil(t, err, "unable to read json file: %v", err)
53
54	for _, entry := range entries {
55		if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
56			continue
57		}
58
59		files = append(files, entry.Name())
60	}
61
62	return files
63}
64
65// create client options from a map
66func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions {
67	t.Helper()
68
69	clientOpts := options.Client()
70	elems, _ := opts.Elements()
71	for _, elem := range elems {
72		name := elem.Key()
73		opt := elem.Value()
74
75		switch name {
76		case "retryWrites":
77			clientOpts.SetRetryWrites(opt.Boolean())
78		case "w":
79			switch opt.Type {
80			case bson.TypeInt32:
81				w := int(opt.Int32())
82				clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w)))
83			case bson.TypeDouble:
84				w := int(opt.Double())
85				clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w)))
86			case bson.TypeString:
87				clientOpts.SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
88			default:
89				t.Fatalf("unrecognized type for w client option: %v", opt.Type)
90			}
91		case "readConcernLevel":
92			clientOpts.SetReadConcern(readconcern.New(readconcern.Level(opt.StringValue())))
93		case "readPreference":
94			clientOpts.SetReadPreference(readPrefFromString(opt.StringValue()))
95		case "heartbeatFrequencyMS":
96			hf := convertValueToMilliseconds(t, opt)
97			clientOpts.SetHeartbeatInterval(hf)
98		case "retryReads":
99			clientOpts.SetRetryReads(opt.Boolean())
100		case "autoEncryptOpts":
101			clientOpts.SetAutoEncryptionOptions(createAutoEncryptionOptions(t, opt.Document()))
102		case "appname":
103			clientOpts.SetAppName(opt.StringValue())
104		case "connectTimeoutMS":
105			ct := convertValueToMilliseconds(t, opt)
106			clientOpts.SetConnectTimeout(ct)
107		case "serverSelectionTimeoutMS":
108			sst := convertValueToMilliseconds(t, opt)
109			clientOpts.SetServerSelectionTimeout(sst)
110		default:
111			t.Fatalf("unrecognized client option: %v", name)
112		}
113	}
114
115	return clientOpts
116}
117
118func createAutoEncryptionOptions(t testing.TB, opts bson.Raw) *options.AutoEncryptionOptions {
119	t.Helper()
120
121	aeo := options.AutoEncryption()
122	var kvnsFound bool
123	elems, _ := opts.Elements()
124
125	for _, elem := range elems {
126		name := elem.Key()
127		opt := elem.Value()
128
129		switch name {
130		case "kmsProviders":
131			aeo.SetKmsProviders(createKmsProvidersMap(t, opt.Document()))
132		case "schemaMap":
133			var schemaMap map[string]interface{}
134			err := bson.Unmarshal(opt.Document(), &schemaMap)
135			if err != nil {
136				t.Fatalf("error creating schema map: %v", err)
137			}
138
139			aeo.SetSchemaMap(schemaMap)
140		case "keyVaultNamespace":
141			kvnsFound = true
142			aeo.SetKeyVaultNamespace(opt.StringValue())
143		case "bypassAutoEncryption":
144			aeo.SetBypassAutoEncryption(opt.Boolean())
145		default:
146			t.Fatalf("unrecognized auto encryption option: %v", name)
147		}
148	}
149	if !kvnsFound {
150		aeo.SetKeyVaultNamespace("keyvault.datakeys")
151	}
152
153	return aeo
154}
155
156func createKmsProvidersMap(t testing.TB, opts bson.Raw) map[string]map[string]interface{} {
157	t.Helper()
158
159	kmsMap := make(map[string]map[string]interface{})
160	elems, _ := opts.Elements()
161
162	for _, elem := range elems {
163		provider := elem.Key()
164		providerOpt := elem.Value()
165
166		switch provider {
167		case "aws":
168			awsMap := map[string]interface{}{
169				"accessKeyId":     awsAccessKeyID,
170				"secretAccessKey": awsSecretAccessKey,
171			}
172			kmsMap["aws"] = awsMap
173		case "azure":
174			kmsMap["azure"] = map[string]interface{}{
175				"tenantId":     azureTenantID,
176				"clientId":     azureClientID,
177				"clientSecret": azureClientSecret,
178			}
179		case "gcp":
180			kmsMap["gcp"] = map[string]interface{}{
181				"email":      gcpEmail,
182				"privateKey": gcpPrivateKey,
183			}
184		case "local":
185			_, key := providerOpt.Document().Lookup("key").Binary()
186			localMap := map[string]interface{}{
187				"key": key,
188			}
189			kmsMap["local"] = localMap
190		case "awsTemporary":
191			if awsTempAccessKeyID == "" {
192				t.Fatal("AWS temp access key ID not set")
193			}
194			if awsTempSecretAccessKey == "" {
195				t.Fatal("AWS temp secret access key not set")
196			}
197			if awsTempSessionToken == "" {
198				t.Fatal("AWS temp session token not set")
199			}
200			awsMap := map[string]interface{}{
201				"accessKeyId":     awsTempAccessKeyID,
202				"secretAccessKey": awsTempSecretAccessKey,
203				"sessionToken":    awsTempSessionToken,
204			}
205			kmsMap["aws"] = awsMap
206		case "awsTemporaryNoSessionToken":
207			if awsTempAccessKeyID == "" {
208				t.Fatal("AWS temp access key ID not set")
209			}
210			if awsTempSecretAccessKey == "" {
211				t.Fatal("AWS temp secret access key not set")
212			}
213			awsMap := map[string]interface{}{
214				"accessKeyId":     awsTempAccessKeyID,
215				"secretAccessKey": awsTempSecretAccessKey,
216			}
217			kmsMap["aws"] = awsMap
218		default:
219			t.Fatalf("unrecognized KMS provider: %v", provider)
220		}
221	}
222
223	return kmsMap
224}
225
226// create session options from a map
227func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions {
228	t.Helper()
229
230	sessOpts := options.Session()
231	elems, _ := opts.Elements()
232	for _, elem := range elems {
233		name := elem.Key()
234		opt := elem.Value()
235
236		switch name {
237		case "causalConsistency":
238			sessOpts = sessOpts.SetCausalConsistency(opt.Boolean())
239		case "defaultTransactionOptions":
240			txnOpts := createTransactionOptions(t, opt.Document())
241			if txnOpts.ReadConcern != nil {
242				sessOpts.SetDefaultReadConcern(txnOpts.ReadConcern)
243			}
244			if txnOpts.ReadPreference != nil {
245				sessOpts.SetDefaultReadPreference(txnOpts.ReadPreference)
246			}
247			if txnOpts.WriteConcern != nil {
248				sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern)
249			}
250			if txnOpts.MaxCommitTime != nil {
251				sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime)
252			}
253		default:
254			t.Fatalf("unrecognized session option: %v", name)
255		}
256	}
257
258	return sessOpts
259}
260
261// create database options from a BSON document.
262func createDatabaseOptions(t testing.TB, opts bson.Raw) *options.DatabaseOptions {
263	t.Helper()
264
265	do := options.Database()
266	elems, _ := opts.Elements()
267	for _, elem := range elems {
268		name := elem.Key()
269		opt := elem.Value()
270
271		switch name {
272		case "readConcern":
273			do.SetReadConcern(createReadConcern(opt))
274		case "writeConcern":
275			do.SetWriteConcern(createWriteConcern(t, opt))
276		default:
277			t.Fatalf("unrecognized database option: %v", name)
278		}
279	}
280
281	return do
282}
283
284// create collection options from a map
285func createCollectionOptions(t testing.TB, opts bson.Raw) *options.CollectionOptions {
286	t.Helper()
287
288	co := options.Collection()
289	elems, _ := opts.Elements()
290	for _, elem := range elems {
291		name := elem.Key()
292		opt := elem.Value()
293
294		switch name {
295		case "readConcern":
296			co.SetReadConcern(createReadConcern(opt))
297		case "writeConcern":
298			co.SetWriteConcern(createWriteConcern(t, opt))
299		case "readPreference":
300			co.SetReadPreference(createReadPref(opt))
301		default:
302			t.Fatalf("unrecognized collection option: %v", name)
303		}
304	}
305
306	return co
307}
308
309// create transaction options from a map
310func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionOptions {
311	t.Helper()
312
313	txnOpts := options.Transaction()
314	elems, _ := opts.Elements()
315	for _, elem := range elems {
316		name := elem.Key()
317		opt := elem.Value()
318
319		switch name {
320		case "writeConcern":
321			txnOpts.SetWriteConcern(createWriteConcern(t, opt))
322		case "readPreference":
323			txnOpts.SetReadPreference(createReadPref(opt))
324		case "readConcern":
325			txnOpts.SetReadConcern(createReadConcern(opt))
326		case "maxCommitTimeMS":
327			t := time.Duration(opt.Int32()) * time.Millisecond
328			txnOpts.SetMaxCommitTime(&t)
329		default:
330			t.Fatalf("unrecognized transaction option: %v", opt)
331		}
332	}
333	return txnOpts
334}
335
336// create a read concern from a map
337func createReadConcern(opt bson.RawValue) *readconcern.ReadConcern {
338	return readconcern.New(readconcern.Level(opt.Document().Lookup("level").StringValue()))
339}
340
341// create a read concern from a map
342func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConcern {
343	wcDoc, ok := opt.DocumentOK()
344	if !ok {
345		return nil
346	}
347
348	var opts []writeconcern.Option
349	elems, _ := wcDoc.Elements()
350	for _, elem := range elems {
351		key := elem.Key()
352		val := elem.Value()
353
354		switch key {
355		case "wtimeout":
356			wtimeout := convertValueToMilliseconds(t, val)
357			opts = append(opts, writeconcern.WTimeout(wtimeout))
358		case "j":
359			opts = append(opts, writeconcern.J(val.Boolean()))
360		case "w":
361			switch val.Type {
362			case bson.TypeString:
363				if val.StringValue() != "majority" {
364					break
365				}
366				opts = append(opts, writeconcern.WMajority())
367			case bson.TypeInt32:
368				w := int(val.Int32())
369				opts = append(opts, writeconcern.W(w))
370			default:
371				t.Fatalf("unrecognized type for w: %v", val.Type)
372			}
373		default:
374			t.Fatalf("unrecognized write concern option: %v", key)
375		}
376	}
377	return writeconcern.New(opts...)
378}
379
380// create a read preference from a string.
381// returns readpref.Primary() if the string doesn't match any known read preference modes.
382func readPrefFromString(s string) *readpref.ReadPref {
383	switch strings.ToLower(s) {
384	case "primary":
385		return readpref.Primary()
386	case "primarypreferred":
387		return readpref.PrimaryPreferred()
388	case "secondary":
389		return readpref.Secondary()
390	case "secondarypreferred":
391		return readpref.SecondaryPreferred()
392	case "nearest":
393		return readpref.Nearest()
394	}
395	return readpref.Primary()
396}
397
398// create a read preference from a map.
399func createReadPref(opt bson.RawValue) *readpref.ReadPref {
400	mode := opt.Document().Lookup("mode").StringValue()
401	return readPrefFromString(mode)
402}
403
404// transform a slice of BSON documents to a slice of interface{}.
405func rawSliceToInterfaceSlice(docs []bson.Raw) []interface{} {
406	out := make([]interface{}, len(docs))
407
408	for i, doc := range docs {
409		out[i] = doc
410	}
411
412	return out
413}
414
415// transform a BSON raw array to a slice of interface{}.
416func rawArrayToInterfaceSlice(docs bson.Raw) []interface{} {
417	vals, _ := docs.Values()
418
419	out := make([]interface{}, len(vals))
420	for i, val := range vals {
421		out[i] = val.Document()
422	}
423
424	return out
425}
426
427// retrieve the error associated with a result.
428func errorFromResult(t testing.TB, result interface{}) *operationError {
429	t.Helper()
430
431	// embedded doc will be unmarshalled as Raw
432	raw, ok := result.(bson.Raw)
433	if !ok {
434		return nil
435	}
436
437	var expected operationError
438	err := bson.Unmarshal(raw, &expected)
439	if err != nil {
440		return nil
441	}
442	if expected.ErrorCodeName == nil && expected.ErrorContains == nil && len(expected.ErrorLabelsOmit) == 0 &&
443		len(expected.ErrorLabelsContain) == 0 {
444		return nil
445	}
446
447	return &expected
448}
449
// errorDetails holds the parts of a driver error that the spec tests make assertions about. It normalizes the
// different error types returned by driver functions; extractErrorDetails populates it.
type errorDetails struct {
	// name is the error code name: the CommandError's Name, or the WriteConcernError's Name for write exceptions
	// (empty when there is no write concern error).
	name   string
	// labels are the error labels attached to the error.
	labels []string
}
456
457// extractErrorDetails creates an errorDetails instance based on the provided error. It returns the details and an "ok"
458// value which is true if the provided error is of a known type that can be processed.
459func extractErrorDetails(err error) (errorDetails, bool) {
460	var details errorDetails
461
462	switch converted := err.(type) {
463	case mongo.CommandError:
464		details.name = converted.Name
465		details.labels = converted.Labels
466	case mongo.WriteException:
467		if converted.WriteConcernError != nil {
468			details.name = converted.WriteConcernError.Name
469		}
470		details.labels = converted.Labels
471	case mongo.BulkWriteException:
472		if converted.WriteConcernError != nil {
473			details.name = converted.WriteConcernError.Name
474		}
475		details.labels = converted.Labels
476	default:
477		return errorDetails{}, false
478	}
479
480	return details, true
481}
482
483// verify that an error returned by an operation matches the expected error.
484func verifyError(expected *operationError, actual error) error {
485	// The spec test format doesn't treat ErrNoDocuments or ErrUnacknowledgedWrite as errors, so set actual to nil
486	// to indicate that no error occurred.
487	if actual == mongo.ErrNoDocuments || actual == mongo.ErrUnacknowledgedWrite {
488		actual = nil
489	}
490
491	if expected == nil && actual != nil {
492		return fmt.Errorf("did not expect error but got %v", actual)
493	}
494	if expected != nil && actual == nil {
495		return fmt.Errorf("expected error but got nil")
496	}
497	if expected == nil {
498		return nil
499	}
500
501	// check ErrorContains for all error types
502	if expected.ErrorContains != nil {
503		emsg := strings.ToLower(*expected.ErrorContains)
504		amsg := strings.ToLower(actual.Error())
505		if !strings.Contains(amsg, emsg) {
506			return fmt.Errorf("expected error message %q to contain %q", amsg, emsg)
507		}
508	}
509
510	// Get an errorDetails instance for the error. If this fails but the test has expectations about the error name or
511	// labels, fail because we can't verify them.
512	details, ok := extractErrorDetails(actual)
513	if !ok {
514		if expected.ErrorCodeName != nil || len(expected.ErrorLabelsContain) > 0 || len(expected.ErrorLabelsOmit) > 0 {
515			return fmt.Errorf("failed to extract details from error %v of type %T", actual, actual)
516		}
517		return nil
518	}
519
520	if expected.ErrorCodeName != nil {
521		if *expected.ErrorCodeName != details.name {
522			return fmt.Errorf("expected error name %v, got %v", *expected.ErrorCodeName, details.name)
523		}
524	}
525	for _, label := range expected.ErrorLabelsContain {
526		if !stringSliceContains(details.labels, label) {
527			return fmt.Errorf("expected error %v to contain label %q", actual, label)
528		}
529	}
530	for _, label := range expected.ErrorLabelsOmit {
531		if stringSliceContains(details.labels, label) {
532			return fmt.Errorf("expected error %v to not contain label %q", actual, label)
533		}
534	}
535	return nil
536}
537
// getIntFromInterface returns the value of i as an *int64.
//
// int, int32 and int64 values are always accepted. float32 and float64 values are accepted only when they hold a whole
// number in the int64 range [-2^63, 2^63). Every other input returns nil, including fractional, NaN, infinite and
// out-of-range floats, so callers can tell "not an integer" apart from zero.
func getIntFromInterface(i interface{}) *int64 {
	// wholeFloat converts f when it is an integral value that fits in an int64, and returns nil otherwise.
	wholeFloat := func(f float64) *int64 {
		// 2^63 is exactly representable as a float64, but math.MaxInt64 is not (it rounds up to 2^63), so the upper
		// bound must be exclusive. Go leaves out-of-range float-to-integer conversions implementation-defined, so
		// such values are rejected before converting.
		const twoTo63 = 1 << 63
		// NaN fails the Floor comparison; ±Inf fail the range checks.
		if math.Floor(f) != f || f < -twoTo63 || f >= twoTo63 {
			return nil
		}
		n := int64(f)
		return &n
	}

	var out int64
	switch v := i.(type) {
	case int:
		out = int64(v)
	case int32:
		out = int64(v)
	case int64:
		out = v
	case float32:
		return wholeFloat(float64(v))
	case float64:
		return wholeFloat(v)
	default:
		return nil
	}

	return &out
}
568
569func createCollation(t testing.TB, m bson.Raw) *options.Collation {
570	var collation options.Collation
571	elems, _ := m.Elements()
572
573	for _, elem := range elems {
574		switch elem.Key() {
575		case "locale":
576			collation.Locale = elem.Value().StringValue()
577		case "caseLevel":
578			collation.CaseLevel = elem.Value().Boolean()
579		case "caseFirst":
580			collation.CaseFirst = elem.Value().StringValue()
581		case "strength":
582			collation.Strength = int(elem.Value().Int32())
583		case "numericOrdering":
584			collation.NumericOrdering = elem.Value().Boolean()
585		case "alternate":
586			collation.Alternate = elem.Value().StringValue()
587		case "maxVariable":
588			collation.MaxVariable = elem.Value().StringValue()
589		case "normalization":
590			collation.Normalization = elem.Value().Boolean()
591		case "backwards":
592			collation.Backwards = elem.Value().Boolean()
593		default:
594			t.Fatalf("unrecognized collation option: %v", elem.Key())
595		}
596	}
597	return &collation
598}
599
600func createChangeStreamOptions(t testing.TB, opts bson.Raw) *options.ChangeStreamOptions {
601	t.Helper()
602
603	csOpts := options.ChangeStream()
604	elems, _ := opts.Elements()
605	for _, elem := range elems {
606		key := elem.Key()
607		opt := elem.Value()
608
609		switch key {
610		case "batchSize":
611			csOpts.SetBatchSize(opt.Int32())
612		default:
613			t.Fatalf("unrecognized change stream option: %v", key)
614		}
615	}
616	return csOpts
617}
618
619func convertValueToMilliseconds(t testing.TB, val bson.RawValue) time.Duration {
620	t.Helper()
621
622	int32Val, ok := val.Int32OK()
623	if !ok {
624		t.Fatalf("failed to convert value of type %s to int32", val.Type)
625	}
626	return time.Duration(int32Val) * time.Millisecond
627}
628
// stringSliceContains reports whether target equals any element of stringSlice. A nil or empty slice contains
// nothing.
func stringSliceContains(stringSlice []string, target string) bool {
	found := false
	for idx := 0; idx < len(stringSlice) && !found; idx++ {
		found = stringSlice[idx] == target
	}
	return found
}
637