1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package integration
8
9import (
10	"fmt"
11	"io/ioutil"
12	"math"
13	"os"
14	"path"
15	"strings"
16	"testing"
17	"time"
18
19	"go.mongodb.org/mongo-driver/bson"
20	"go.mongodb.org/mongo-driver/internal/testutil/assert"
21	"go.mongodb.org/mongo-driver/mongo"
22	"go.mongodb.org/mongo-driver/mongo/options"
23	"go.mongodb.org/mongo-driver/mongo/readconcern"
24	"go.mongodb.org/mongo-driver/mongo/readpref"
25	"go.mongodb.org/mongo-driver/mongo/writeconcern"
26)
27
// awsAccessKeyID is the name of the environment variable that is read to obtain the AWS access key ID for the
// "aws" KMS provider.
const awsAccessKeyID = "AWS_ACCESS_KEY_ID"

// awsSecretAccessKey is the name of the environment variable that is read to obtain the AWS secret access key
// for the "aws" KMS provider.
const awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
32
// Helper functions to read JSON spec test files and convert JSON objects into the appropriate driver types.
// Functions in this file should take testing.TB rather than testing.T/mtest.T for generality because they
// do not do any database communication.
36
37// generate a slice of all JSON file names in a directory
38func jsonFilesInDir(t testing.TB, dir string) []string {
39	t.Helper()
40
41	files := make([]string, 0)
42
43	entries, err := ioutil.ReadDir(dir)
44	assert.Nil(t, err, "unable to read json file: %v", err)
45
46	for _, entry := range entries {
47		if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
48			continue
49		}
50
51		files = append(files, entry.Name())
52	}
53
54	return files
55}
56
57// create client options from a map
58func createClientOptions(t testing.TB, opts bson.Raw) *options.ClientOptions {
59	t.Helper()
60
61	clientOpts := options.Client()
62	elems, _ := opts.Elements()
63	for _, elem := range elems {
64		name := elem.Key()
65		opt := elem.Value()
66
67		switch name {
68		case "retryWrites":
69			clientOpts.SetRetryWrites(opt.Boolean())
70		case "w":
71			switch opt.Type {
72			case bson.TypeInt32:
73				w := int(opt.Int32())
74				clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w)))
75			case bson.TypeDouble:
76				w := int(opt.Double())
77				clientOpts.SetWriteConcern(writeconcern.New(writeconcern.W(w)))
78			case bson.TypeString:
79				clientOpts.SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
80			default:
81				t.Fatalf("unrecognized type for w client option: %v", opt.Type)
82			}
83		case "readConcernLevel":
84			clientOpts.SetReadConcern(readconcern.New(readconcern.Level(opt.StringValue())))
85		case "readPreference":
86			clientOpts.SetReadPreference(readPrefFromString(opt.StringValue()))
87		case "heartbeatFrequencyMS":
88			hf := convertValueToMilliseconds(t, opt)
89			clientOpts.SetHeartbeatInterval(hf)
90		case "retryReads":
91			clientOpts.SetRetryReads(opt.Boolean())
92		case "autoEncryptOpts":
93			clientOpts.SetAutoEncryptionOptions(createAutoEncryptionOptions(t, opt.Document()))
94		case "appname":
95			clientOpts.SetAppName(opt.StringValue())
96		case "connectTimeoutMS":
97			ct := convertValueToMilliseconds(t, opt)
98			clientOpts.SetConnectTimeout(ct)
99		case "serverSelectionTimeoutMS":
100			sst := convertValueToMilliseconds(t, opt)
101			clientOpts.SetServerSelectionTimeout(sst)
102		default:
103			t.Fatalf("unrecognized client option: %v", name)
104		}
105	}
106
107	return clientOpts
108}
109
110func createAutoEncryptionOptions(t testing.TB, opts bson.Raw) *options.AutoEncryptionOptions {
111	t.Helper()
112
113	aeo := options.AutoEncryption()
114	var kvnsFound bool
115	elems, _ := opts.Elements()
116
117	for _, elem := range elems {
118		name := elem.Key()
119		opt := elem.Value()
120
121		switch name {
122		case "kmsProviders":
123			aeo.SetKmsProviders(createKmsProvidersMap(t, opt.Document()))
124		case "schemaMap":
125			var schemaMap map[string]interface{}
126			err := bson.Unmarshal(opt.Document(), &schemaMap)
127			if err != nil {
128				t.Fatalf("error creating schema map: %v", err)
129			}
130
131			aeo.SetSchemaMap(schemaMap)
132		case "keyVaultNamespace":
133			kvnsFound = true
134			aeo.SetKeyVaultNamespace(opt.StringValue())
135		case "bypassAutoEncryption":
136			aeo.SetBypassAutoEncryption(opt.Boolean())
137		default:
138			t.Fatalf("unrecognized auto encryption option: %v", name)
139		}
140	}
141	if !kvnsFound {
142		aeo.SetKeyVaultNamespace("keyvault.datakeys")
143	}
144
145	return aeo
146}
147
148func createKmsProvidersMap(t testing.TB, opts bson.Raw) map[string]map[string]interface{} {
149	t.Helper()
150
151	// aws: value is always empty object. create new map value from access key ID and secret access key
152	// local: value is {"key": primitive.Binary}. transform to {"key": []byte}
153
154	kmsMap := make(map[string]map[string]interface{})
155	elems, _ := opts.Elements()
156
157	for _, elem := range elems {
158		provider := elem.Key()
159		providerOpt := elem.Value()
160
161		switch provider {
162		case "aws":
163			keyID := os.Getenv(awsAccessKeyID)
164			if keyID == "" {
165				t.Fatalf("%s env var not set", awsAccessKeyID)
166			}
167			secretAccessKey := os.Getenv(awsSecretAccessKey)
168			if secretAccessKey == "" {
169				t.Fatalf("%s env var not set", awsSecretAccessKey)
170			}
171
172			awsMap := map[string]interface{}{
173				"accessKeyId":     keyID,
174				"secretAccessKey": secretAccessKey,
175			}
176			kmsMap["aws"] = awsMap
177		case "local":
178			_, key := providerOpt.Document().Lookup("key").Binary()
179			localMap := map[string]interface{}{
180				"key": key,
181			}
182			kmsMap["local"] = localMap
183		default:
184			t.Fatalf("unrecognized KMS provider: %v", provider)
185		}
186	}
187
188	return kmsMap
189}
190
191// create session options from a map
192func createSessionOptions(t testing.TB, opts bson.Raw) *options.SessionOptions {
193	t.Helper()
194
195	sessOpts := options.Session()
196	elems, _ := opts.Elements()
197	for _, elem := range elems {
198		name := elem.Key()
199		opt := elem.Value()
200
201		switch name {
202		case "causalConsistency":
203			sessOpts = sessOpts.SetCausalConsistency(opt.Boolean())
204		case "defaultTransactionOptions":
205			txnOpts := createTransactionOptions(t, opt.Document())
206			if txnOpts.ReadConcern != nil {
207				sessOpts.SetDefaultReadConcern(txnOpts.ReadConcern)
208			}
209			if txnOpts.ReadPreference != nil {
210				sessOpts.SetDefaultReadPreference(txnOpts.ReadPreference)
211			}
212			if txnOpts.WriteConcern != nil {
213				sessOpts.SetDefaultWriteConcern(txnOpts.WriteConcern)
214			}
215			if txnOpts.MaxCommitTime != nil {
216				sessOpts.SetDefaultMaxCommitTime(txnOpts.MaxCommitTime)
217			}
218		default:
219			t.Fatalf("unrecognized session option: %v", name)
220		}
221	}
222
223	return sessOpts
224}
225
226// create database options from a BSON document.
227func createDatabaseOptions(t testing.TB, opts bson.Raw) *options.DatabaseOptions {
228	t.Helper()
229
230	do := options.Database()
231	elems, _ := opts.Elements()
232	for _, elem := range elems {
233		name := elem.Key()
234		opt := elem.Value()
235
236		switch name {
237		case "readConcern":
238			do.SetReadConcern(createReadConcern(opt))
239		case "writeConcern":
240			do.SetWriteConcern(createWriteConcern(t, opt))
241		default:
242			t.Fatalf("unrecognized database option: %v", name)
243		}
244	}
245
246	return do
247}
248
249// create collection options from a map
250func createCollectionOptions(t testing.TB, opts bson.Raw) *options.CollectionOptions {
251	t.Helper()
252
253	co := options.Collection()
254	elems, _ := opts.Elements()
255	for _, elem := range elems {
256		name := elem.Key()
257		opt := elem.Value()
258
259		switch name {
260		case "readConcern":
261			co.SetReadConcern(createReadConcern(opt))
262		case "writeConcern":
263			co.SetWriteConcern(createWriteConcern(t, opt))
264		case "readPreference":
265			co.SetReadPreference(createReadPref(opt))
266		default:
267			t.Fatalf("unrecognized collection option: %v", name)
268		}
269	}
270
271	return co
272}
273
274// create transaction options from a map
275func createTransactionOptions(t testing.TB, opts bson.Raw) *options.TransactionOptions {
276	t.Helper()
277
278	txnOpts := options.Transaction()
279	elems, _ := opts.Elements()
280	for _, elem := range elems {
281		name := elem.Key()
282		opt := elem.Value()
283
284		switch name {
285		case "writeConcern":
286			txnOpts.SetWriteConcern(createWriteConcern(t, opt))
287		case "readPreference":
288			txnOpts.SetReadPreference(createReadPref(opt))
289		case "readConcern":
290			txnOpts.SetReadConcern(createReadConcern(opt))
291		case "maxCommitTimeMS":
292			t := time.Duration(opt.Int32()) * time.Millisecond
293			txnOpts.SetMaxCommitTime(&t)
294		default:
295			t.Fatalf("unrecognized transaction option: %v", opt)
296		}
297	}
298	return txnOpts
299}
300
301// create a read concern from a map
302func createReadConcern(opt bson.RawValue) *readconcern.ReadConcern {
303	return readconcern.New(readconcern.Level(opt.Document().Lookup("level").StringValue()))
304}
305
306// create a read concern from a map
307func createWriteConcern(t testing.TB, opt bson.RawValue) *writeconcern.WriteConcern {
308	wcDoc, ok := opt.DocumentOK()
309	if !ok {
310		return nil
311	}
312
313	var opts []writeconcern.Option
314	elems, _ := wcDoc.Elements()
315	for _, elem := range elems {
316		key := elem.Key()
317		val := elem.Value()
318
319		switch key {
320		case "wtimeout":
321			wtimeout := convertValueToMilliseconds(t, val)
322			opts = append(opts, writeconcern.WTimeout(wtimeout))
323		case "j":
324			opts = append(opts, writeconcern.J(val.Boolean()))
325		case "w":
326			switch val.Type {
327			case bson.TypeString:
328				if val.StringValue() != "majority" {
329					break
330				}
331				opts = append(opts, writeconcern.WMajority())
332			case bson.TypeInt32:
333				w := int(val.Int32())
334				opts = append(opts, writeconcern.W(w))
335			default:
336				t.Fatalf("unrecognized type for w: %v", val.Type)
337			}
338		default:
339			t.Fatalf("unrecognized write concern option: %v", key)
340		}
341	}
342	return writeconcern.New(opts...)
343}
344
345// create a read preference from a string.
346// returns readpref.Primary() if the string doesn't match any known read preference modes.
347func readPrefFromString(s string) *readpref.ReadPref {
348	switch strings.ToLower(s) {
349	case "primary":
350		return readpref.Primary()
351	case "primarypreferred":
352		return readpref.PrimaryPreferred()
353	case "secondary":
354		return readpref.Secondary()
355	case "secondarypreferred":
356		return readpref.SecondaryPreferred()
357	case "nearest":
358		return readpref.Nearest()
359	}
360	return readpref.Primary()
361}
362
363// create a read preference from a map.
364func createReadPref(opt bson.RawValue) *readpref.ReadPref {
365	mode := opt.Document().Lookup("mode").StringValue()
366	return readPrefFromString(mode)
367}
368
369// transform a slice of BSON documents to a slice of interface{}.
370func rawSliceToInterfaceSlice(docs []bson.Raw) []interface{} {
371	out := make([]interface{}, len(docs))
372
373	for i, doc := range docs {
374		out[i] = doc
375	}
376
377	return out
378}
379
380// transform a BSON raw array to a slice of interface{}.
381func rawArrayToInterfaceSlice(docs bson.Raw) []interface{} {
382	vals, _ := docs.Values()
383
384	out := make([]interface{}, len(vals))
385	for i, val := range vals {
386		out[i] = val.Document()
387	}
388
389	return out
390}
391
392// retrieve the error associated with a result.
393func errorFromResult(t testing.TB, result interface{}) *operationError {
394	t.Helper()
395
396	// embedded doc will be unmarshalled as Raw
397	raw, ok := result.(bson.Raw)
398	if !ok {
399		return nil
400	}
401
402	var expected operationError
403	err := bson.Unmarshal(raw, &expected)
404	if err != nil {
405		return nil
406	}
407	if expected.ErrorCodeName == nil && expected.ErrorContains == nil && len(expected.ErrorLabelsOmit) == 0 &&
408		len(expected.ErrorLabelsContain) == 0 {
409		return nil
410	}
411
412	return &expected
413}
414
// errorDetails gathers the pieces of information that the driver exposes through several distinct error types,
// so that they can be checked uniformly.
type errorDetails struct {
	// name is the error name taken from a command error or from a write concern error.
	name string
	// labels holds the error labels attached to the error.
	labels []string
}
421
422// extractErrorDetails creates an errorDetails instance based on the provided error. It returns the details and an "ok"
423// value which is true if the provided error is of a known type that can be processed.
424func extractErrorDetails(err error) (errorDetails, bool) {
425	var details errorDetails
426
427	switch converted := err.(type) {
428	case mongo.CommandError:
429		details.name = converted.Name
430		details.labels = converted.Labels
431	case mongo.WriteException:
432		if converted.WriteConcernError != nil {
433			details.name = converted.WriteConcernError.Name
434		}
435		details.labels = converted.Labels
436	case mongo.BulkWriteException:
437		if converted.WriteConcernError != nil {
438			details.name = converted.WriteConcernError.Name
439		}
440		details.labels = converted.Labels
441	default:
442		return errorDetails{}, false
443	}
444
445	return details, true
446}
447
448// verify that an error returned by an operation matches the expected error.
449func verifyError(expected *operationError, actual error) error {
450	// The spec test format doesn't treat ErrNoDocuments or ErrUnacknowledgedWrite as errors, so set actual to nil
451	// to indicate that no error occurred.
452	if actual == mongo.ErrNoDocuments || actual == mongo.ErrUnacknowledgedWrite {
453		actual = nil
454	}
455
456	if expected == nil && actual != nil {
457		return fmt.Errorf("did not expect error but got %v", actual)
458	}
459	if expected != nil && actual == nil {
460		return fmt.Errorf("expected error but got nil")
461	}
462	if expected == nil {
463		return nil
464	}
465
466	// check ErrorContains for all error types
467	if expected.ErrorContains != nil {
468		emsg := strings.ToLower(*expected.ErrorContains)
469		amsg := strings.ToLower(actual.Error())
470		if !strings.Contains(amsg, emsg) {
471			return fmt.Errorf("expected error message %q to contain %q", amsg, emsg)
472		}
473	}
474
475	// Get an errorDetails instance for the error. If this fails but the test has expectations about the error name or
476	// labels, fail because we can't verify them.
477	details, ok := extractErrorDetails(actual)
478	if !ok {
479		if expected.ErrorCodeName != nil || len(expected.ErrorLabelsContain) > 0 || len(expected.ErrorLabelsOmit) > 0 {
480			return fmt.Errorf("failed to extract details from error %v of type %T", actual, actual)
481		}
482		return nil
483	}
484
485	if expected.ErrorCodeName != nil {
486		if *expected.ErrorCodeName != details.name {
487			return fmt.Errorf("expected error name %v, got %v", *expected.ErrorCodeName, details.name)
488		}
489	}
490	for _, label := range expected.ErrorLabelsContain {
491		if !stringSliceContains(details.labels, label) {
492			return fmt.Errorf("expected error %v to contain label %q", actual, label)
493		}
494	}
495	for _, label := range expected.ErrorLabelsOmit {
496		if stringSliceContains(details.labels, label) {
497			return fmt.Errorf("expected error %v to not contain label %q", actual, label)
498		}
499	}
500	return nil
501}
502
// getIntFromInterface returns the value of i as an int64.
//
// int, int32 and int64 values are converted directly. float32 and float64 values are accepted only when they
// hold an integral value in the range [-2^63, 2^63), i.e. one that an int64 can represent exactly. The function
// returns nil for every other input: values of any other type, non-integral floats, NaN, infinities and floats
// outside the int64 range.
func getIntFromInterface(i interface{}) *int64 {
	var out int64
	var f float64

	switch v := i.(type) {
	case int:
		out = int64(v)
		return &out
	case int32:
		out = int64(v)
		return &out
	case int64:
		out = v
		return &out
	case float32:
		f = float64(v)
	case float64:
		f = v
	default:
		return nil
	}

	// math.Floor(NaN) is NaN, which never compares equal to itself, so NaN is rejected by the first test.
	// -float64(math.MinInt64) is exactly 2^63, the first value too large for an int64, so the upper bound must
	// be exclusive. The lower bound rejects -Inf and anything below -2^63.
	if math.Floor(f) != f || f < float64(math.MinInt64) || f >= -float64(math.MinInt64) {
		return nil
	}

	out = int64(f)
	return &out
}
533
534func createCollation(t testing.TB, m bson.Raw) *options.Collation {
535	var collation options.Collation
536	elems, _ := m.Elements()
537
538	for _, elem := range elems {
539		switch elem.Key() {
540		case "locale":
541			collation.Locale = elem.Value().StringValue()
542		case "caseLevel":
543			collation.CaseLevel = elem.Value().Boolean()
544		case "caseFirst":
545			collation.CaseFirst = elem.Value().StringValue()
546		case "strength":
547			collation.Strength = int(elem.Value().Int32())
548		case "numericOrdering":
549			collation.NumericOrdering = elem.Value().Boolean()
550		case "alternate":
551			collation.Alternate = elem.Value().StringValue()
552		case "maxVariable":
553			collation.MaxVariable = elem.Value().StringValue()
554		case "normalization":
555			collation.Normalization = elem.Value().Boolean()
556		case "backwards":
557			collation.Backwards = elem.Value().Boolean()
558		default:
559			t.Fatalf("unrecognized collation option: %v", elem.Key())
560		}
561	}
562	return &collation
563}
564
565func createChangeStreamOptions(t testing.TB, opts bson.Raw) *options.ChangeStreamOptions {
566	t.Helper()
567
568	csOpts := options.ChangeStream()
569	elems, _ := opts.Elements()
570	for _, elem := range elems {
571		key := elem.Key()
572		opt := elem.Value()
573
574		switch key {
575		case "batchSize":
576			csOpts.SetBatchSize(opt.Int32())
577		default:
578			t.Fatalf("unrecognized change stream option: %v", key)
579		}
580	}
581	return csOpts
582}
583
584func convertValueToMilliseconds(t testing.TB, val bson.RawValue) time.Duration {
585	t.Helper()
586
587	int32Val, ok := val.Int32OK()
588	if !ok {
589		t.Fatalf("failed to convert value of type %s to int32", val.Type)
590	}
591	return time.Duration(int32Val) * time.Millisecond
592}
593
// stringSliceContains reports whether target is equal to at least one element of stringSlice. It returns false
// for a nil or empty slice.
func stringSliceContains(stringSlice []string, target string) bool {
	for idx := 0; idx < len(stringSlice); idx++ {
		if stringSlice[idx] != target {
			continue
		}
		return true
	}
	return false
}
602