1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7// +build cse
8
9package integration
10
11import (
12	"context"
13	"encoding/base64"
14	"fmt"
15	"io/ioutil"
16	"os"
17	"path/filepath"
18	"runtime"
19	"strings"
20	"testing"
21	"time"
22
23	"go.mongodb.org/mongo-driver/bson"
24	"go.mongodb.org/mongo-driver/bson/primitive"
25	"go.mongodb.org/mongo-driver/event"
26	"go.mongodb.org/mongo-driver/internal/testutil/assert"
27	"go.mongodb.org/mongo-driver/mongo"
28	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
29	"go.mongodb.org/mongo-driver/mongo/options"
30	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
31)
32
// localMasterKey is the raw 96-byte master key supplied to the "local" KMS
// provider in every KMS provider map built by these tests.
var localMasterKey = []byte("2x44+xduTaBBkY16Er5DuADaghvS4vwdkg8tpPp3tz6gV01A1CwbD9itQ2HFDgPWOp8eMaC1Oi766JzXZBdBdbdMurdonJ1d")
36
// Fixture location, key vault namespace, and algorithm names shared by the
// client-side encryption prose tests.
const (
	// clientEncryptionProseDir is the relative path to the prose-test data
	// files (presumably the directory readJSONFile loads fixtures from — verify).
	clientEncryptionProseDir = "../../data/client-side-encryption-prose"

	// deterministicAlgorithm and randomAlgorithm are the algorithm names passed
	// to options.Encrypt().SetAlgorithm.
	deterministicAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
	randomAlgorithm        = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"

	// kvNamespace is the default "db.collection" namespace of the key vault.
	kvNamespace = "keyvault.datakeys"
)

// BSON binary subtypes the tests expect to see.
const (
	keySubtype            = byte(4) // subtype of data key IDs returned by CreateDataKey
	encryptedValueSubtype = byte(6) // subtype of values produced by Encrypt
)

// Size limits, in bytes, exercised by the "bson size limits" test.
const (
	cryptMaxBatchSizeBytes = 2 << 20  // 2 MiB: max write batch size when auto encryption is enabled
	maxBsonObjSize         = 16 << 20 // 16 MiB: max size of a single BSON object
)
47
48func TestClientSideEncryptionProse(t *testing.T) {
49	verifyClientSideEncryptionVarsSet(t)
50	mt := mtest.New(t, mtest.NewOptions().MinServerVersion("4.2").Enterprise(true).CreateClient(false))
51	defer mt.Close()
52
53	defaultKvClientOptions := options.Client().ApplyURI(mtest.ClusterURI())
54	fullKmsProvidersMap := map[string]map[string]interface{}{
55		"aws": {
56			"accessKeyId":     awsAccessKeyID,
57			"secretAccessKey": awsSecretAccessKey,
58		},
59		"azure": {
60			"tenantId":     azureTenantID,
61			"clientId":     azureClientID,
62			"clientSecret": azureClientSecret,
63		},
64		"gcp": {
65			"email":      gcpEmail,
66			"privateKey": gcpPrivateKey,
67		},
68		"local": {"key": localMasterKey},
69	}
70
71	mt.RunOpts("data key and double encryption", noClientOpts, func(mt *mtest.T) {
72		// set up options structs
73		schema := bson.D{
74			{"bsonType", "object"},
75			{"properties", bson.D{
76				{"encrypted_placeholder", bson.D{
77					{"encrypt", bson.D{
78						{"keyId", "/placeholder"},
79						{"bsonType", "string"},
80						{"algorithm", "AEAD_AES_256_CBC_HMAC_SHA_512-Random"},
81					}},
82				}},
83			}},
84		}
85		schemaMap := map[string]interface{}{"db.coll": schema}
86		aeo := options.AutoEncryption().
87			SetKmsProviders(fullKmsProvidersMap).
88			SetKeyVaultNamespace(kvNamespace).
89			SetSchemaMap(schemaMap)
90		ceo := options.ClientEncryption().
91			SetKmsProviders(fullKmsProvidersMap).
92			SetKeyVaultNamespace(kvNamespace)
93
94		awsMasterKey := bson.D{
95			{"region", "us-east-1"},
96			{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
97		}
98		azureMasterKey := bson.D{
99			{"keyVaultEndpoint", "key-vault-csfle.vault.azure.net"},
100			{"keyName", "key-name-csfle"},
101		}
102		gcpMasterKey := bson.D{
103			{"projectId", "devprod-drivers"},
104			{"location", "global"},
105			{"keyRing", "key-ring-csfle"},
106			{"keyName", "key-name-csfle"},
107		}
108		testCases := []struct {
109			provider  string
110			masterKey interface{}
111		}{
112			{"local", nil},
113			{"aws", awsMasterKey},
114			{"azure", azureMasterKey},
115			{"gcp", gcpMasterKey},
116		}
117		for _, tc := range testCases {
118			mt.Run(tc.provider, func(mt *mtest.T) {
119				var startedEvents []*event.CommandStartedEvent
120				monitor := &event.CommandMonitor{
121					Started: func(_ context.Context, evt *event.CommandStartedEvent) {
122						startedEvents = append(startedEvents, evt)
123					},
124				}
125				kvClientOpts := options.Client().ApplyURI(mtest.ClusterURI()).SetMonitor(monitor)
126				cpt := setup(mt, aeo, kvClientOpts, ceo)
127				defer cpt.teardown(mt)
128
129				// create data key
130				keyAltName := fmt.Sprintf("%s_altname", tc.provider)
131				dataKeyOpts := options.DataKey().SetKeyAltNames([]string{keyAltName})
132				if tc.masterKey != nil {
133					dataKeyOpts.SetMasterKey(tc.masterKey)
134				}
135				dataKeyID, err := cpt.clientEnc.CreateDataKey(mtest.Background, tc.provider, dataKeyOpts)
136				assert.Nil(mt, err, "CreateDataKey error: %v", err)
137				assert.Equal(mt, keySubtype, dataKeyID.Subtype,
138					"expected data key subtype %v, got %v", keySubtype, dataKeyID.Subtype)
139
140				// assert that the key exists in the key vault
141				cursor, err := cpt.keyVaultColl.Find(mtest.Background, bson.D{{"_id", dataKeyID}})
142				assert.Nil(mt, err, "key vault Find error: %v", err)
143				assert.True(mt, cursor.Next(mtest.Background), "no keys found in key vault")
144				provider := cursor.Current.Lookup("masterKey", "provider").StringValue()
145				assert.Equal(mt, tc.provider, provider, "expected provider %v, got %v", tc.provider, provider)
146				assert.False(mt, cursor.Next(mtest.Background), "unexpected document in key vault: %v", cursor.Current)
147
148				// verify that the key was inserted using write concern majority
149				assert.Equal(mt, 1, len(startedEvents), "expected 1 CommandStartedEvent, got %v", len(startedEvents))
150				evt := startedEvents[0]
151				assert.Equal(mt, "insert", evt.CommandName, "expected command 'insert', got '%v'", evt.CommandName)
152				writeConcernVal, err := evt.Command.LookupErr("writeConcern")
153				assert.Nil(mt, err, "expected writeConcern in command %s", evt.Command)
154				wString := writeConcernVal.Document().Lookup("w").StringValue()
155				assert.Equal(mt, "majority", wString, "expected write concern 'majority', got %v", wString)
156
157				// encrypt a value with the new key by ID
158				valueToEncrypt := fmt.Sprintf("hello %s", tc.provider)
159				rawVal := bson.RawValue{Type: bson.TypeString, Value: bsoncore.AppendString(nil, valueToEncrypt)}
160				encrypted, err := cpt.clientEnc.Encrypt(mtest.Background, rawVal,
161					options.Encrypt().SetAlgorithm(deterministicAlgorithm).SetKeyID(dataKeyID))
162				assert.Nil(mt, err, "Encrypt error while encrypting value by ID: %v", err)
163				assert.Equal(mt, encryptedValueSubtype, encrypted.Subtype,
164					"expected encrypted value subtype %v, got %v", encryptedValueSubtype, encrypted.Subtype)
165
166				// insert an encrypted value. the value shouldn't be encrypted again because it's not in the schema.
167				_, err = cpt.cseColl.InsertOne(mtest.Background, bson.D{{"_id", tc.provider}, {"value", encrypted}})
168				assert.Nil(mt, err, "InsertOne error: %v", err)
169
170				// find the inserted document. the value should be decrypted automatically
171				resBytes, err := cpt.cseColl.FindOne(mtest.Background, bson.D{{"_id", tc.provider}}).DecodeBytes()
172				assert.Nil(mt, err, "Find error: %v", err)
173				foundVal := resBytes.Lookup("value").StringValue()
174				assert.Equal(mt, valueToEncrypt, foundVal, "expected value %v, got %v", valueToEncrypt, foundVal)
175
176				// encrypt a value with an alternate name for the new key
177				altEncrypted, err := cpt.clientEnc.Encrypt(mtest.Background, rawVal,
178					options.Encrypt().SetAlgorithm(deterministicAlgorithm).SetKeyAltName(keyAltName))
179				assert.Nil(mt, err, "Encrypt error while encrypting value by alt key name: %v", err)
180				assert.Equal(mt, encryptedValueSubtype, altEncrypted.Subtype,
181					"expected encrypted value subtype %v, got %v", encryptedValueSubtype, altEncrypted.Subtype)
182				assert.Equal(mt, encrypted.Data, altEncrypted.Data,
183					"expected data %v, got %v", encrypted.Data, altEncrypted.Data)
184
185				// insert an encrypted value for an auto-encrypted field
186				_, err = cpt.cseColl.InsertOne(mtest.Background, bson.D{{"encrypted_placeholder", encrypted}})
187				assert.NotNil(mt, err, "expected InsertOne error, got nil")
188			})
189		}
190	})
191	mt.RunOpts("external key vault", noClientOpts, func(mt *mtest.T) {
192		testCases := []struct {
193			name          string
194			externalVault bool
195		}{
196			{"with external vault", true},
197			{"without external vault", false},
198		}
199
200		for _, tc := range testCases {
201			mt.Run(tc.name, func(mt *mtest.T) {
202				// setup options structs
203				kmsProviders := map[string]map[string]interface{}{
204					"local": {
205						"key": localMasterKey,
206					},
207				}
208				schemaMap := map[string]interface{}{"db.coll": readJSONFile(mt, "external-schema.json")}
209				aeo := options.AutoEncryption().SetKmsProviders(kmsProviders).SetKeyVaultNamespace(kvNamespace).SetSchemaMap(schemaMap)
210				ceo := options.ClientEncryption().SetKmsProviders(kmsProviders).SetKeyVaultNamespace(kvNamespace)
211				kvClientOpts := defaultKvClientOptions
212
213				if tc.externalVault {
214					externalKvOpts := options.Client().ApplyURI(mtest.ClusterURI()).SetAuth(options.Credential{
215						Username: "fake-user",
216						Password: "fake-password",
217					})
218					aeo.SetKeyVaultClientOptions(externalKvOpts)
219					kvClientOpts = externalKvOpts
220				}
221				cpt := setup(mt, aeo, kvClientOpts, ceo)
222				defer cpt.teardown(mt)
223
224				// manually insert data key
225				key := readJSONFile(mt, "external-key.json")
226				_, err := cpt.keyVaultColl.InsertOne(mtest.Background, key)
227				assert.Nil(mt, err, "InsertOne error for data key: %v", err)
228				subtype, data := key.Lookup("_id").Binary()
229				dataKeyID := primitive.Binary{Subtype: subtype, Data: data}
230
231				doc := bson.D{{"encrypted", "test"}}
232				_, insertErr := cpt.cseClient.Database("db").Collection("coll").InsertOne(mtest.Background, doc)
233				rawVal := bson.RawValue{Type: bson.TypeString, Value: bsoncore.AppendString(nil, "test")}
234				_, encErr := cpt.clientEnc.Encrypt(mtest.Background, rawVal,
235					options.Encrypt().SetKeyID(dataKeyID).SetAlgorithm(deterministicAlgorithm))
236
237				if tc.externalVault {
238					assert.NotNil(mt, insertErr, "expected InsertOne auth error, got nil")
239					assert.NotNil(mt, encErr, "expected Encrypt auth error, got nil")
240					assert.True(mt, strings.Contains(insertErr.Error(), "auth error"),
241						"expected InsertOne auth error, got %v", insertErr)
242					assert.True(mt, strings.Contains(encErr.Error(), "auth error"),
243						"expected Encrypt auth error, got %v", insertErr)
244					return
245				}
246				assert.Nil(mt, insertErr, "InsertOne error: %v", insertErr)
247				assert.Nil(mt, encErr, "Encrypt error: %v", err)
248			})
249		}
250	})
251	mt.Run("bson size limits", func(mt *mtest.T) {
252		kmsProviders := map[string]map[string]interface{}{
253			"local": {
254				"key": localMasterKey,
255			},
256		}
257		aeo := options.AutoEncryption().SetKmsProviders(kmsProviders).SetKeyVaultNamespace(kvNamespace)
258		cpt := setup(mt, aeo, nil, nil)
259		defer cpt.teardown(mt)
260
261		// create coll with JSON schema
262		err := mt.Client.Database("db").RunCommand(mtest.Background, bson.D{
263			{"create", "coll"},
264			{"validator", bson.D{
265				{"$jsonSchema", readJSONFile(mt, "limits-schema.json")},
266			}},
267		}).Err()
268		assert.Nil(mt, err, "create error with validator: %v", err)
269
270		// insert key
271		key := readJSONFile(mt, "limits-key.json")
272		_, err = cpt.keyVaultColl.InsertOne(mtest.Background, key)
273		assert.Nil(mt, err, "InsertOne error for key: %v", err)
274
275		var builder2mb, builder16mb strings.Builder
276		for i := 0; i < cryptMaxBatchSizeBytes; i++ {
277			builder2mb.WriteByte('a')
278		}
279		for i := 0; i < maxBsonObjSize; i++ {
280			builder16mb.WriteByte('a')
281		}
282		complete2mbStr := builder2mb.String()
283		complete16mbStr := builder16mb.String()
284
285		// insert a document over 2MiB
286		doc := bson.D{{"over_2mib_under_16mib", complete2mbStr}}
287		_, err = cpt.cseColl.InsertOne(mtest.Background, doc)
288		assert.Nil(mt, err, "InsertOne error for 2MiB document: %v", err)
289
290		str := complete2mbStr[:cryptMaxBatchSizeBytes-2000] // remove last 2000 bytes
291		limitsDoc := readJSONFile(mt, "limits-doc.json")
292
293		// insert a doc smaller than 2MiB that is bigger than 2MiB after encryption
294		var extendedLimitsDoc []byte
295		extendedLimitsDoc = append(extendedLimitsDoc, limitsDoc...)
296		extendedLimitsDoc = extendedLimitsDoc[:len(extendedLimitsDoc)-1] // remove last byte to add new fields
297		extendedLimitsDoc = bsoncore.AppendStringElement(extendedLimitsDoc, "_id", "encryption_exceeds_2mib")
298		extendedLimitsDoc = bsoncore.AppendStringElement(extendedLimitsDoc, "unencrypted", str)
299		extendedLimitsDoc, _ = bsoncore.AppendDocumentEnd(extendedLimitsDoc, 0)
300		_, err = cpt.cseColl.InsertOne(mtest.Background, extendedLimitsDoc)
301		assert.Nil(mt, err, "error inserting extended limits document: %v", err)
302
303		// bulk insert two 2MiB documents, each over 2 MiB
304		// each document should be split into its own batch because the documents are bigger than 2MiB but smaller
305		// than 16MiB
306		cpt.cseStarted = cpt.cseStarted[:0]
307		firstDoc := bson.D{{"_id", "over_2mib_1"}, {"unencrypted", complete2mbStr}}
308		secondDoc := bson.D{{"_id", "over_2mib_2"}, {"unencrypted", complete2mbStr}}
309		_, err = cpt.cseColl.InsertMany(mtest.Background, []interface{}{firstDoc, secondDoc})
310		assert.Nil(mt, err, "InsertMany error for small documents: %v", err)
311		assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d", len(cpt.cseStarted))
312
313		// bulk insert two documents
314		str = complete2mbStr[:cryptMaxBatchSizeBytes-20000]
315		firstBulkDoc := make([]byte, len(limitsDoc))
316		copy(firstBulkDoc, limitsDoc)
317		firstBulkDoc = firstBulkDoc[:len(firstBulkDoc)-1] // remove last byte to append new fields
318		firstBulkDoc = bsoncore.AppendStringElement(firstBulkDoc, "_id", "encryption_exceeds_2mib_1")
319		firstBulkDoc = bsoncore.AppendStringElement(firstBulkDoc, "unencrypted", string(str))
320		firstBulkDoc, _ = bsoncore.AppendDocumentEnd(firstBulkDoc, 0)
321
322		secondBulkDoc := make([]byte, len(limitsDoc))
323		copy(secondBulkDoc, limitsDoc)
324		secondBulkDoc = secondBulkDoc[:len(secondBulkDoc)-1] // remove last byte to append new fields
325		secondBulkDoc = bsoncore.AppendStringElement(secondBulkDoc, "_id", "encryption_exceeds_2mib_2")
326		secondBulkDoc = bsoncore.AppendStringElement(secondBulkDoc, "unencrypted", string(str))
327		secondBulkDoc, _ = bsoncore.AppendDocumentEnd(secondBulkDoc, 0)
328
329		cpt.cseStarted = cpt.cseStarted[:0]
330		_, err = cpt.cseColl.InsertMany(mtest.Background, []interface{}{firstBulkDoc, secondBulkDoc})
331		assert.Nil(mt, err, "InsertMany error for large documents: %v", err)
332		assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d", len(cpt.cseStarted))
333
334		// insert a document slightly smaller than 16MiB and expect the operation to succeed
335		doc = bson.D{{"_id", "under_16mib"}, {"unencrypted", complete16mbStr[:maxBsonObjSize-2000]}}
336		_, err = cpt.cseColl.InsertOne(mtest.Background, doc)
337		assert.Nil(mt, err, "InsertOne error: %v", err)
338
339		// insert a document over 16MiB and expect the operation to fail
340		var over16mb []byte
341		over16mb = append(over16mb, limitsDoc...)
342		over16mb = over16mb[:len(over16mb)-1] // remove last byte
343		over16mb = bsoncore.AppendStringElement(over16mb, "_id", "encryption_exceeds_16mib")
344		over16mb = bsoncore.AppendStringElement(over16mb, "unencrypted", complete16mbStr[:maxBsonObjSize-2000])
345		over16mb, _ = bsoncore.AppendDocumentEnd(over16mb, 0)
346		_, err = cpt.cseColl.InsertOne(mtest.Background, over16mb)
347		assert.NotNil(mt, err, "expected InsertOne error for document over 16MiB, got nil")
348	})
349	mt.Run("views are prohibited", func(mt *mtest.T) {
350		kmsProviders := map[string]map[string]interface{}{
351			"local": {
352				"key": localMasterKey,
353			},
354		}
355		aeo := options.AutoEncryption().SetKmsProviders(kmsProviders).SetKeyVaultNamespace(kvNamespace)
356		cpt := setup(mt, aeo, nil, nil)
357		defer cpt.teardown(mt)
358
359		// create view on db.coll
360		mt.CreateCollection(mtest.Collection{
361			Name:       "view",
362			DB:         cpt.cseColl.Database().Name(),
363			CreateOpts: bson.D{{"viewOn", "coll"}},
364		}, true)
365
366		view := cpt.cseColl.Database().Collection("view")
367		_, err := view.InsertOne(mtest.Background, bson.D{{"_id", "insert_on_view"}})
368		assert.NotNil(mt, err, "expected InsertOne error on view, got nil")
369		errStr := strings.ToLower(err.Error())
370		viewErrSubstr := "cannot auto encrypt a view"
371		assert.True(mt, strings.Contains(errStr, viewErrSubstr),
372			"expected error '%v' to contain substring '%v'", errStr, viewErrSubstr)
373	})
374	mt.RunOpts("corpus", noClientOpts, func(mt *mtest.T) {
375		corpusSchema := readJSONFile(mt, "corpus-schema.json")
376		localSchemaMap := map[string]interface{}{
377			"db.coll": corpusSchema,
378		}
379		getBaseAutoEncryptionOpts := func() *options.AutoEncryptionOptions {
380			return options.AutoEncryption().
381				SetKmsProviders(fullKmsProvidersMap).
382				SetKeyVaultNamespace(kvNamespace)
383		}
384
385		testCases := []struct {
386			name   string
387			aeo    *options.AutoEncryptionOptions
388			schema bson.Raw // the schema to create the collection. if nil, the collection won't be explicitly created
389		}{
390			{"remote schema", getBaseAutoEncryptionOpts(), corpusSchema},
391			{"local schema", getBaseAutoEncryptionOpts().SetSchemaMap(localSchemaMap), nil},
392		}
393
394		for _, tc := range testCases {
395			mt.Run(tc.name, func(mt *mtest.T) {
396				ceo := options.ClientEncryption().
397					SetKmsProviders(fullKmsProvidersMap).
398					SetKeyVaultNamespace(kvNamespace)
399				cpt := setup(mt, tc.aeo, defaultKvClientOptions, ceo)
400				defer cpt.teardown(mt)
401
402				// create collection with JSON schema
403				if tc.schema != nil {
404					db := cpt.coll.Database()
405					err := db.RunCommand(mtest.Background, bson.D{
406						{"create", "coll"},
407						{"validator", bson.D{
408							{"$jsonSchema", readJSONFile(mt, "corpus-schema.json")},
409						}},
410					}).Err()
411					assert.Nil(mt, err, "create error with validator: %v", err)
412				}
413
414				// Manually insert keys for each KMS provider into the key vault.
415				_, err := cpt.keyVaultColl.InsertMany(mtest.Background, []interface{}{
416					readJSONFile(mt, "corpus-key-local.json"),
417					readJSONFile(mt, "corpus-key-aws.json"),
418					readJSONFile(mt, "corpus-key-azure.json"),
419					readJSONFile(mt, "corpus-key-gcp.json"),
420				})
421				assert.Nil(mt, err, "InsertMany error for key vault: %v", err)
422
423				// read original corpus and recursively copy over each value to new corpus, encrypting certain values
424				// when needed
425				corpus := readJSONFile(mt, "corpus.json")
426				cidx, copied := bsoncore.AppendDocumentStart(nil)
427				elems, _ := corpus.Elements()
428
429				// Keys for top-level non-document elements that should be copied directly.
430				copiedKeys := map[string]struct{}{
431					"_id":           {},
432					"altname_aws":   {},
433					"altname_local": {},
434					"altname_azure": {},
435					"altname_gcp":   {},
436				}
437
438				for _, elem := range elems {
439					key := elem.Key()
440					val := elem.Value()
441
442					if _, ok := copiedKeys[key]; ok {
443						copied = bsoncore.AppendStringElement(copied, key, val.StringValue())
444						continue
445					}
446
447					doc := val.Document()
448					switch method := doc.Lookup("method").StringValue(); method {
449					case "auto":
450						// Copy the value directly because it will be auto-encrypted later.
451						copied = bsoncore.AppendDocumentElement(copied, key, doc)
452						continue
453					case "explicit":
454						// Handled below.
455					default:
456						mt.Fatalf("unrecognized 'method' value %q", method)
457					}
458
459					// explicitly encrypt value
460					algorithm := deterministicAlgorithm
461					if doc.Lookup("algo").StringValue() == "rand" {
462						algorithm = randomAlgorithm
463					}
464					eo := options.Encrypt().SetAlgorithm(algorithm)
465
466					identifier := doc.Lookup("identifier").StringValue()
467					kms := doc.Lookup("kms").StringValue()
468					switch identifier {
469					case "id":
470						var keyID string
471						switch kms {
472						case "local":
473							keyID = "LOCALAAAAAAAAAAAAAAAAA=="
474						case "aws":
475							keyID = "AWSAAAAAAAAAAAAAAAAAAA=="
476						case "azure":
477							keyID = "AZUREAAAAAAAAAAAAAAAAA=="
478						case "gcp":
479							keyID = "GCPAAAAAAAAAAAAAAAAAAA=="
480						default:
481							mt.Fatalf("unrecognized KMS provider %q", kms)
482						}
483
484						keyIDBytes, err := base64.StdEncoding.DecodeString(keyID)
485						assert.Nil(mt, err, "base64 DecodeString error: %v", err)
486						eo.SetKeyID(primitive.Binary{Subtype: 4, Data: keyIDBytes})
487					case "altname":
488						eo.SetKeyAltName(kms) // alt name for a key is the same as the KMS name
489					default:
490						mt.Fatalf("unrecognized identifier: %v", identifier)
491					}
492
493					// iterate over all elements in the document. copy elements directly, except for ones that need to
494					// be encrypted, which should be copied after encryption.
495					var nestedIdx int32
496					nestedIdx, copied = bsoncore.AppendDocumentElementStart(copied, key)
497					docElems, _ := doc.Elements()
498					for _, de := range docElems {
499						deKey := de.Key()
500						deVal := de.Value()
501
502						// element to encrypt has key "value"
503						if deKey != "value" {
504							copied = bsoncore.AppendValueElement(copied, deKey, rawValueToCoreValue(deVal))
505							continue
506						}
507
508						encrypted, err := cpt.clientEnc.Encrypt(mtest.Background, deVal, eo)
509						if !doc.Lookup("allowed").Boolean() {
510							// if allowed is false, encryption should error. in this case, the unencrypted value should be
511							// copied over
512							assert.NotNil(mt, err, "expected error encrypting value for key %v, got nil", key)
513							copied = bsoncore.AppendValueElement(copied, deKey, rawValueToCoreValue(deVal))
514							continue
515						}
516
517						// copy encrypted value
518						assert.Nil(mt, err, "Encrypt error for key %v: %v", key, err)
519						copied = bsoncore.AppendBinaryElement(copied, deKey, encrypted.Subtype, encrypted.Data)
520					}
521					copied, _ = bsoncore.AppendDocumentEnd(copied, nestedIdx)
522				}
523				copied, _ = bsoncore.AppendDocumentEnd(copied, cidx)
524
525				// insert document with encrypted values
526				_, err = cpt.cseColl.InsertOne(mtest.Background, copied)
527				assert.Nil(mt, err, "InsertOne error for corpus document: %v", err)
528
529				// find document using client with encryption and assert it matches original
530				decryptedDoc, err := cpt.cseColl.FindOne(mtest.Background, bson.D{}).DecodeBytes()
531				assert.Nil(mt, err, "Find error with encrypted client: %v", err)
532				assert.Equal(mt, corpus, decryptedDoc, "expected document %v, got %v", corpus, decryptedDoc)
533
534				// find document using a client without encryption enabled and assert fields remain encrypted
535				corpusEncrypted := readJSONFile(mt, "corpus-encrypted.json")
536				foundDoc, err := cpt.coll.FindOne(mtest.Background, bson.D{}).DecodeBytes()
537				assert.Nil(mt, err, "Find error with unencrypted client: %v", err)
538
539				encryptedElems, _ := corpusEncrypted.Elements()
540				for _, encryptedElem := range encryptedElems {
541					// skip non-document fields
542					encryptedDoc, ok := encryptedElem.Value().DocumentOK()
543					if !ok {
544						continue
545					}
546
547					allowed := encryptedDoc.Lookup("allowed").Boolean()
548					expectedKey := encryptedElem.Key()
549					expectedVal := encryptedDoc.Lookup("value")
550					foundVal := foundDoc.Lookup(expectedKey).Document().Lookup("value")
551
552					// for deterministic encryption, the value should be exactly equal
553					// for random encryption, the value should not be equal if allowed is true
554					algo := encryptedDoc.Lookup("algo").StringValue()
555					switch algo {
556					case "det":
557						assert.True(mt, expectedVal.Equal(foundVal),
558							"expected value %v for key %v, got %v", expectedVal, expectedKey, foundVal)
559					case "rand":
560						if allowed {
561							assert.False(mt, expectedVal.Equal(foundVal),
562								"expected values for key %v to be different but were %v", expectedKey, expectedVal)
563						}
564					}
565
566					// if allowed is true, decrypt both values with clientEnc and validate equality
567					if allowed {
568						sub, data := expectedVal.Binary()
569						expectedDecrypted, err := cpt.clientEnc.Decrypt(mtest.Background, primitive.Binary{Subtype: sub, Data: data})
570						assert.Nil(mt, err, "Decrypt error: %v", err)
571						sub, data = foundVal.Binary()
572						actualDecrypted, err := cpt.clientEnc.Decrypt(mtest.Background, primitive.Binary{Subtype: sub, Data: data})
573						assert.Nil(mt, err, "Decrypt error: %v", err)
574
575						assert.True(mt, expectedDecrypted.Equal(actualDecrypted),
576							"expected decrypted value %v for key %v, got %v", expectedDecrypted, expectedKey, actualDecrypted)
577						continue
578					}
579
580					// if allowed is false, validate found value equals the original value in corpus
581					corpusVal := corpus.Lookup(expectedKey).Document().Lookup("value")
582					assert.True(mt, corpusVal.Equal(foundVal),
583						"expected value %v for key %v, got %v", corpusVal, expectedKey, foundVal)
584				}
585			})
586		}
587	})
588	mt.Run("custom endpoint", func(mt *mtest.T) {
589		validKmsProviders := map[string]map[string]interface{}{
590			"aws": {
591				"accessKeyId":     awsAccessKeyID,
592				"secretAccessKey": awsSecretAccessKey,
593			},
594			"azure": {
595				"tenantId":                 azureTenantID,
596				"clientId":                 azureClientID,
597				"clientSecret":             azureClientSecret,
598				"identityPlatformEndpoint": "login.microsoftonline.com:443",
599			},
600			"gcp": {
601				"email":      gcpEmail,
602				"privateKey": gcpPrivateKey,
603				"endpoint":   "oauth2.googleapis.com:443",
604			},
605		}
606		validClientEncryptionOptions := options.ClientEncryption().
607			SetKmsProviders(validKmsProviders).
608			SetKeyVaultNamespace(kvNamespace)
609
610		invalidKmsProviders := map[string]map[string]interface{}{
611			"azure": {
612				"tenantId":                 azureTenantID,
613				"clientId":                 azureClientID,
614				"clientSecret":             azureClientSecret,
615				"identityPlatformEndpoint": "example.com:443",
616			},
617			"gcp": {
618				"email":      gcpEmail,
619				"privateKey": gcpPrivateKey,
620				"endpoint":   "example.com:443",
621			},
622		}
623		invalidClientEncryptionOptions := options.ClientEncryption().
624			SetKmsProviders(invalidKmsProviders).
625			SetKeyVaultNamespace(kvNamespace)
626
627		awsSuccessWithoutEndpoint := map[string]interface{}{
628			"region": "us-east-1",
629			"key":    "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
630		}
631		awsSuccessWithEndpoint := map[string]interface{}{
632			"region":   "us-east-1",
633			"key":      "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
634			"endpoint": "kms.us-east-1.amazonaws.com",
635		}
636		awsSuccessWithHTTPSEndpoint := map[string]interface{}{
637			"region":   "us-east-1",
638			"key":      "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
639			"endpoint": "kms.us-east-1.amazonaws.com:443",
640		}
641		awsFailureConnectionError := map[string]interface{}{
642			"region":   "us-east-1",
643			"key":      "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
644			"endpoint": "kms.us-east-1.amazonaws.com:12345",
645		}
646		awsFailureInvalidEndpoint := map[string]interface{}{
647			"region":   "us-east-1",
648			"key":      "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
649			"endpoint": "kms.us-east-2.amazonaws.com",
650		}
651		awsFailureParseError := map[string]interface{}{
652			"region":   "us-east-1",
653			"key":      "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
654			"endpoint": "example.com",
655		}
656		azure := map[string]interface{}{
657			"keyVaultEndpoint": "key-vault-csfle.vault.azure.net",
658			"keyName":          "key-name-csfle",
659		}
660		gcpSuccess := map[string]interface{}{
661			"projectId": "devprod-drivers",
662			"location":  "global",
663			"keyRing":   "key-ring-csfle",
664			"keyName":   "key-name-csfle",
665			"endpoint":  "cloudkms.googleapis.com:443",
666		}
667		gcpFailure := map[string]interface{}{
668			"projectId": "devprod-drivers",
669			"location":  "global",
670			"keyRing":   "key-ring-csfle",
671			"keyName":   "key-name-csfle",
672			"endpoint":  "example.com:443",
673		}
674
675		testCases := []struct {
676			name                        string
677			provider                    string
678			masterKey                   interface{}
679			errorSubstring              string
680			testInvalidClientEncryption bool
681		}{
682			{"aws success without endpoint", "aws", awsSuccessWithoutEndpoint, "", false},
683			{"aws success with endpoint", "aws", awsSuccessWithEndpoint, "", false},
684			{"aws success with https endpoint", "aws", awsSuccessWithHTTPSEndpoint, "", false},
685			{"aws failure with connection error", "aws", awsFailureConnectionError, "connection refused", false},
686			{"aws failure with wrong endpoint", "aws", awsFailureInvalidEndpoint, "us-east-1", false},
687			{"aws failure with parse error", "aws", awsFailureParseError, "parse error", false},
688			{"azure success", "azure", azure, "", true},
689			{"gcp success", "gcp", gcpSuccess, "", true},
690			{"gcp failure", "gcp", gcpFailure, "Invalid KMS response", false},
691		}
692		for _, tc := range testCases {
693			mt.Run(tc.name, func(mt *mtest.T) {
694				cpt := setup(mt, nil, defaultKvClientOptions, validClientEncryptionOptions)
695				defer cpt.teardown(mt)
696
697				dkOpts := options.DataKey().SetMasterKey(tc.masterKey)
698				createdKey, err := cpt.clientEnc.CreateDataKey(mtest.Background, tc.provider, dkOpts)
699				if tc.errorSubstring != "" {
700					assert.NotNil(mt, err, "expected error, got nil")
701					errSubstr := tc.errorSubstring
702					if runtime.GOOS == "windows" && errSubstr == "connection refused" {
703						// tls.Dial returns an error that does not contain the substring "connection refused"
704						// on Windows machines
705						errSubstr = "No connection could be made because the target machine actively refused it"
706					}
707					assert.True(mt, strings.Contains(err.Error(), errSubstr),
708						"expected error '%s' to contain '%s'", err.Error(), errSubstr)
709					return
710				}
711				assert.Nil(mt, err, "CreateDataKey error: %v", err)
712
713				encOpts := options.Encrypt().SetKeyID(createdKey).SetAlgorithm(deterministicAlgorithm)
714				testVal := bson.RawValue{
715					Type:  bson.TypeString,
716					Value: bsoncore.AppendString(nil, "test"),
717				}
718				encrypted, err := cpt.clientEnc.Encrypt(mtest.Background, testVal, encOpts)
719				assert.Nil(mt, err, "Encrypt error: %v", err)
720				decrypted, err := cpt.clientEnc.Decrypt(mtest.Background, encrypted)
721				assert.Nil(mt, err, "Decrypt error: %v", err)
722				assert.Equal(mt, testVal, decrypted, "expected value %s, got %s", testVal, decrypted)
723
724				if !tc.testInvalidClientEncryption {
725					return
726				}
727
728				invalidClientEncryption, err := mongo.NewClientEncryption(cpt.kvClient, invalidClientEncryptionOptions)
729				assert.Nil(mt, err, "error creating invalidClientEncryption object: %v", err)
730				defer invalidClientEncryption.Close(mtest.Background)
731
732				invalidKeyOpts := options.DataKey().SetMasterKey(tc.masterKey)
733				_, err = invalidClientEncryption.CreateDataKey(mtest.Background, tc.provider, invalidKeyOpts)
734				assert.NotNil(mt, err, "expected CreateDataKey error, got nil")
735				assert.True(mt, strings.Contains(err.Error(), "parse error"),
736					"expected error %v to contain substring 'parse error'", err)
737			})
738		}
739	})
740	mt.RunOpts("bypass mongocryptd spawning", noClientOpts, func(mt *mtest.T) {
741		kmsProviders := map[string]map[string]interface{}{
742			"local": {
743				"key": localMasterKey,
744			},
745		}
746		schemaMap := map[string]interface{}{
747			"db.coll": readJSONFile(mt, "external-schema.json"),
748		}
749
750		// All mongocryptd options use port 27021 instead of the default 27020 to avoid interference with mongocryptd
751		// instances spawned by previous tests.
752		mongocryptdBypassSpawnTrue := map[string]interface{}{
753			"mongocryptdBypassSpawn": true,
754			"mongocryptdURI":         "mongodb://localhost:27021/db?serverSelectionTimeoutMS=1000",
755			"mongocryptdSpawnArgs":   []string{"--pidfilepath=bypass-spawning-mongocryptd.pid", "--port=27021"},
756		}
757		mongocryptdBypassSpawnFalse := map[string]interface{}{
758			"mongocryptdBypassSpawn": false,
759			"mongocryptdSpawnArgs":   []string{"--pidfilepath=bypass-spawning-mongocryptd.pid", "--port=27021"},
760		}
761		mongocryptdBypassSpawnNotSet := map[string]interface{}{
762			"mongocryptdSpawnArgs": []string{"--pidfilepath=bypass-spawning-mongocryptd.pid", "--port=27021"},
763		}
764
765		testCases := []struct {
766			name                    string
767			mongocryptdOpts         map[string]interface{}
768			setBypassAutoEncryption bool
769			bypassAutoEncryption    bool
770		}{
771			{"mongocryptdBypassSpawn only", mongocryptdBypassSpawnTrue, false, false},
772			{"bypassAutoEncryption only", mongocryptdBypassSpawnNotSet, true, true},
773			{"mongocryptdBypassSpawn false, bypassAutoEncryption true", mongocryptdBypassSpawnFalse, true, true},
774			{"mongocryptdBypassSpawn true, bypassAutoEncryption false", mongocryptdBypassSpawnTrue, true, false},
775		}
776		for _, tc := range testCases {
777			mt.Run(tc.name, func(mt *mtest.T) {
778				aeo := options.AutoEncryption().
779					SetKmsProviders(kmsProviders).
780					SetKeyVaultNamespace(kvNamespace).
781					SetSchemaMap(schemaMap).
782					SetExtraOptions(tc.mongocryptdOpts)
783				if tc.setBypassAutoEncryption {
784					aeo.SetBypassAutoEncryption(tc.bypassAutoEncryption)
785				}
786				cpt := setup(mt, aeo, nil, nil)
787				defer cpt.teardown(mt)
788
789				_, err := cpt.cseColl.InsertOne(mtest.Background, bson.D{{"unencrypted", "test"}})
790
791				// Check for mongocryptd server selection error if auto encryption was not bypassed.
792				if !(tc.setBypassAutoEncryption && tc.bypassAutoEncryption) {
793					assert.NotNil(mt, err, "expected InsertOne error, got nil")
794					mcryptErr, ok := err.(mongo.MongocryptdError)
795					assert.True(mt, ok, "expected error type %T, got %v of type %T", mongo.MongocryptdError{}, err, err)
796					assert.True(mt, strings.Contains(mcryptErr.Error(), "server selection error"),
797						"expected mongocryptd server selection error, got %v", err)
798					return
799				}
800
801				// If auto encryption is bypassed, the command should succeed. Create a new client to connect to
802				// mongocryptd and verify it is not running.
803				assert.Nil(mt, err, "InsertOne error: %v", err)
804
805				mcryptOpts := options.Client().ApplyURI("mongodb://localhost:27021").
806					SetServerSelectionTimeout(1 * time.Second)
807				mcryptClient, err := mongo.Connect(mtest.Background, mcryptOpts)
808				assert.Nil(mt, err, "mongocryptd Connect error: %v", err)
809
810				err = mcryptClient.Database("admin").RunCommand(mtest.Background, bson.D{{"ismaster", 1}}).Err()
811				assert.NotNil(mt, err, "expected mongocryptd ismaster error, got nil")
812				assert.True(mt, strings.Contains(err.Error(), "server selection error"),
813					"expected mongocryptd server selection error, got %v", err)
814			})
815		}
816	})
817	changeStreamOpts := mtest.NewOptions().
818		CreateClient(false).
819		Topologies(mtest.ReplicaSet)
820	mt.RunOpts("change streams", changeStreamOpts, func(mt *mtest.T) {
821		// Change streams can't easily fit into the spec test format because of their tailable nature, so there are two
822		// prose tests for them instead:
823		//
824		// 1. Auto-encryption errors for Watch operations. Collection-level change streams error because the
825		// $changeStream aggregation stage is not valid for encryption. Client and database-level streams error because
826		// only collection-level operations are valid for encryption.
827		//
828		// 2. Events are automatically decrypted: If the Watch() is done with BypassAutoEncryption=true, the Watch
829		// should succeed and subsequent getMore calls should decrypt documents when necessary.
830
831		var testConfig struct {
832			JSONSchema        bson.Raw   `bson:"json_schema"`
833			KeyVaultData      []bson.Raw `bson:"key_vault_data"`
834			EncryptedDocument bson.Raw   `bson:"encrypted_document"`
835			DecryptedDocument bson.Raw   `bson:"decrytped_document"`
836		}
837		decodeJSONFile(mt, "change-streams-test.json", &testConfig)
838
839		schemaMap := map[string]interface{}{
840			"db.coll": testConfig.JSONSchema,
841		}
842		kmsProviders := map[string]map[string]interface{}{
843			"aws": {
844				"accessKeyId":     awsAccessKeyID,
845				"secretAccessKey": awsSecretAccessKey,
846			},
847		}
848
849		testCases := []struct {
850			name       string
851			streamType mongo.StreamType
852		}{
853			{"client", mongo.ClientStream},
854			{"database", mongo.DatabaseStream},
855			{"collection", mongo.CollectionStream},
856		}
857		mt.RunOpts("auto encryption errors", noClientOpts, func(mt *mtest.T) {
858			for _, tc := range testCases {
859				mt.Run(tc.name, func(mt *mtest.T) {
860					autoEncryptionOpts := options.AutoEncryption().
861						SetKmsProviders(kmsProviders).
862						SetKeyVaultNamespace(kvNamespace).
863						SetSchemaMap(schemaMap)
864					cpt := setup(mt, autoEncryptionOpts, nil, nil)
865					defer cpt.teardown(mt)
866
867					_, err := getWatcher(mt, tc.streamType, cpt).Watch(mtest.Background, mongo.Pipeline{})
868					assert.NotNil(mt, err, "expected Watch error: %v", err)
869				})
870			}
871		})
872		mt.RunOpts("events are automatically decrypted", noClientOpts, func(mt *mtest.T) {
873			for _, tc := range testCases {
874				mt.Run(tc.name, func(mt *mtest.T) {
875					autoEncryptionOpts := options.AutoEncryption().
876						SetKmsProviders(kmsProviders).
877						SetKeyVaultNamespace(kvNamespace).
878						SetSchemaMap(schemaMap).
879						SetBypassAutoEncryption(true)
880					cpt := setup(mt, autoEncryptionOpts, nil, nil)
881					defer cpt.teardown(mt)
882
883					// Insert key vault data so the key can be accessed when starting the change stream.
884					insertDocuments(mt, cpt.keyVaultColl, testConfig.KeyVaultData)
885
886					stream, err := getWatcher(mt, tc.streamType, cpt).Watch(mtest.Background, mongo.Pipeline{})
887					assert.Nil(mt, err, "Watch error: %v", err)
888					defer stream.Close(mtest.Background)
889
890					// Insert already encrypted data and verify that it is automatically decrypted by Next().
891					insertDocuments(mt, cpt.coll, []bson.Raw{testConfig.EncryptedDocument})
892					assert.True(mt, stream.Next(mtest.Background), "expected Next to return true, got false")
893					gotDocument := stream.Current.Lookup("fullDocument").Document()
894					err = compareDocs(mt, testConfig.DecryptedDocument, gotDocument)
895					assert.Nil(mt, err, "compareDocs error: %v", err)
896				})
897			}
898		})
899	})
900
901	mt.RunOpts("deadlock tests", noClientOpts, func(mt *mtest.T) {
902		testcases := []struct {
903			description                            string
904			maxPoolSize                            uint64
905			bypassAutoEncryption                   bool
906			keyVaultClientSet                      bool
907			clientEncryptedTopologyOpeningExpected int
908			clientEncryptedCommandStartedExpected  []startedEvent
909			clientKeyVaultCommandStartedExpected   []startedEvent
910		}{
911			// In the following comments, "special auto encryption options" refers to the "bypassAutoEncryption" and
912			// "keyVaultClient" options
913			{
914				// If the client has a limited maxPoolSize, and no special auto-encryption options are set, the
915				// driver should create an internal Client for metadata/keyVault operations.
916				"deadlock case 1", 1, false, false, 2,
917				[]startedEvent{{"listCollections", "db"}, {"find", "keyvault"}, {"insert", "db"}, {"find", "db"}},
918				nil,
919			},
920			{
921				// If the client has a limited maxPoolSize, and a keyVaultClient is set, the driver should create
922				// an internal Client for metadata operations.
923				"deadlock case 2", 1, false, true, 2,
924				[]startedEvent{{"listCollections", "db"}, {"insert", "db"}, {"find", "db"}},
925				[]startedEvent{{"find", "keyvault"}},
926			},
927			{
928				// If the client has a limited maxPoolSize, and a bypassAutomaticEncryption=true, the driver should
929				// create an internal Client for keyVault operations.
930				"deadlock case 3", 1, true, false, 2,
931				[]startedEvent{{"find", "db"}, {"find", "keyvault"}},
932				nil,
933			},
934			{
935				// If the client has a limited maxPoolSize, bypassAutomaticEncryption=true, and a keyVaultClient is set,
936				// the driver should not create an internal Client.
937				"deadlock case 4", 1, true, true, 1,
938				[]startedEvent{{"find", "db"}},
939				[]startedEvent{{"find", "keyvault"}},
940			},
941			{
942				// If the client has an unlimited maxPoolSize, and no special auto-encryption options are set,  the
943				// driver should reuse the client for metadata/keyVault operations
944				"deadlock case 5", 0, false, false, 1,
945				[]startedEvent{{"listCollections", "db"}, {"listCollections", "keyvault"}, {"find", "keyvault"}, {"insert", "db"}, {"find", "db"}},
946				nil,
947			},
948			{
949				// If the client has an unlimited maxPoolSize, and a keyVaultClient is set, the driver should reuse the
950				// client for metadata operations.
951				"deadlock case 6", 0, false, true, 1,
952				[]startedEvent{{"listCollections", "db"}, {"insert", "db"}, {"find", "db"}},
953				[]startedEvent{{"find", "keyvault"}},
954			},
955			{
956				// If the client has an unlimited maxPoolSize, and bypassAutomaticEncryption=true, the driver should
957				// reuse the client for keyVault operations
958				"deadlock case 7", 0, true, false, 1,
959				[]startedEvent{{"find", "db"}, {"find", "keyvault"}},
960				nil,
961			},
962			{
963				// If the client has an unlimited maxPoolSize, bypassAutomaticEncryption=true, and a keyVaultClient is
964				// set, the driver should not create an internal Client.
965				"deadlock case 8", 0, true, true, 1,
966				[]startedEvent{{"find", "db"}},
967				[]startedEvent{{"find", "keyvault"}},
968			},
969		}
970
971		for _, tc := range testcases {
972			mt.Run(tc.description, func(mt *mtest.T) {
973				var clientEncryptedEvents []startedEvent
974				var clientEncryptedTopologyOpening int
975
976				d := newDeadlockTest(mt)
977				defer d.disconnect(mt)
978
979				kmsProviders := map[string]map[string]interface{}{
980					"local": {"key": localMasterKey},
981				}
982				aeOpts := options.AutoEncryption()
983				aeOpts.SetKeyVaultNamespace("keyvault.datakeys").
984					SetKmsProviders(kmsProviders).
985					SetBypassAutoEncryption(tc.bypassAutoEncryption)
986				if tc.keyVaultClientSet {
987					aeOpts.SetKeyVaultClientOptions(d.clientKeyVaultOpts)
988				}
989
990				ceOpts := options.Client().ApplyURI(mtest.ClusterURI()).
991					SetMonitor(&event.CommandMonitor{
992						Started: func(ctx context.Context, event *event.CommandStartedEvent) {
993							clientEncryptedEvents = append(clientEncryptedEvents, startedEvent{event.CommandName, event.DatabaseName})
994						},
995					}).
996					SetServerMonitor(&event.ServerMonitor{
997						TopologyOpening: func(event *event.TopologyOpeningEvent) {
998							clientEncryptedTopologyOpening++
999						},
1000					}).
1001					SetMaxPoolSize(tc.maxPoolSize).
1002					SetAutoEncryptionOptions(aeOpts)
1003
1004				clientEncrypted, err := mongo.Connect(mtest.Background, ceOpts)
1005				defer clientEncrypted.Disconnect(mtest.Background)
1006				assert.Nil(mt, err, "Connect error: %v", err)
1007
1008				coll := clientEncrypted.Database("db").Collection("coll")
1009				if !tc.bypassAutoEncryption {
1010					_, err = coll.InsertOne(mtest.Background, bson.M{"_id": 0, "encrypted": "string0"})
1011				} else {
1012					unencryptedColl := d.clientTest.Database("db").Collection("coll")
1013					_, err = unencryptedColl.InsertOne(mtest.Background, bson.M{"_id": 0, "encrypted": d.ciphertext})
1014				}
1015				assert.Nil(mt, err, "InsertOne error: %v", err)
1016
1017				raw, err := coll.FindOne(mtest.Background, bson.M{"_id": 0}).DecodeBytes()
1018				assert.Nil(mt, err, "FindOne error: %v", err)
1019
1020				expected := bsoncore.NewDocumentBuilder().
1021					AppendInt32("_id", 0).
1022					AppendString("encrypted", "string0").
1023					Build()
1024				assert.Equal(mt, bson.Raw(expected), raw, "returned value unequal, expected: %v, got: %v", expected, raw)
1025
1026				assert.Equal(mt, clientEncryptedEvents, tc.clientEncryptedCommandStartedExpected, "mismatched events for clientEncrypted. Expected %v, got %v", clientEncryptedEvents, tc.clientEncryptedCommandStartedExpected)
1027				assert.Equal(mt, d.clientKeyVaultEvents, tc.clientKeyVaultCommandStartedExpected, "mismatched events for clientKeyVault. Expected %v, got %v", d.clientKeyVaultEvents, tc.clientKeyVaultCommandStartedExpected)
1028				assert.Equal(mt, clientEncryptedTopologyOpening, tc.clientEncryptedTopologyOpeningExpected, "wrong number of TopologyOpening events. Expected %v, got %v", tc.clientEncryptedTopologyOpeningExpected, clientEncryptedTopologyOpening)
1029			})
1030		}
1031	})
1032
1033	// These tests only run when a KMS mock server is running on localhost:8000.
1034	mt.RunOpts("kms tls tests", noClientOpts, func(mt *mtest.T) {
1035		kmsTlsTestcase := os.Getenv("KMS_TLS_TESTCASE")
1036		if kmsTlsTestcase == "" {
1037			mt.Skipf("Skipping test as KMS_TLS_TESTCASE is not set")
1038		}
1039
1040		testcases := []struct {
1041			name       string
1042			envValue   string
1043			errMessage string
1044		}{
1045			{
1046				"invalid certificate",
1047				"INVALID_CERT",
1048				"expired",
1049			},
1050			{
1051				"invalid hostname",
1052				"INVALID_HOSTNAME",
1053				"SANs",
1054			},
1055		}
1056
1057		for _, tc := range testcases {
1058			mt.Run(tc.name, func(mt *mtest.T) {
1059				// Only run test if correct KMS mock server is running.
1060				if kmsTlsTestcase != tc.envValue {
1061					mt.Skipf("Skipping test as KMS_TLS_TESTCASE is set to %q, expected %v", kmsTlsTestcase, tc.envValue)
1062				}
1063
1064				ceo := options.ClientEncryption().
1065					SetKmsProviders(fullKmsProvidersMap).
1066					SetKeyVaultNamespace(kvNamespace)
1067				cpt := setup(mt, nil, nil, ceo)
1068				defer cpt.teardown(mt)
1069
1070				_, err := cpt.clientEnc.CreateDataKey(context.Background(), "aws", options.DataKey().SetMasterKey(
1071					bson.D{
1072						{"region", "us-east-1"},
1073						{"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"},
1074						{"endpoint", "mongodb://127.0.0.1:8000"},
1075					},
1076				))
1077				assert.NotNil(mt, err, "expected CreateDataKey error, got nil")
1078				assert.True(mt, strings.Contains(err.Error(), tc.errMessage),
1079					"expected CreateDataKey error to contain %v, got %v", tc.errMessage, err.Error())
1080			})
1081		}
1082	})
1083}
1084
1085func getWatcher(mt *mtest.T, streamType mongo.StreamType, cpt *cseProseTest) watcher {
1086	mt.Helper()
1087
1088	switch streamType {
1089	case mongo.ClientStream:
1090		return cpt.cseClient
1091	case mongo.DatabaseStream:
1092		return cpt.cseColl.Database()
1093	case mongo.CollectionStream:
1094		return cpt.cseColl
1095	default:
1096		mt.Fatalf("unknown stream type %v", streamType)
1097	}
1098	return nil
1099}
1100
// cseProseTest bundles the per-test state built by setup for the client-side encryption prose tests.
//
// coll (db.coll) and keyVaultColl (keyvault.datakeys) are always set. cseClient, cseColl, and cseStarted are only
// populated when setup receives auto encryption options; cseStarted collects every CommandStartedEvent observed by
// cseClient's command monitor. kvClient and clientEnc are only set when setup receives ClientEncryption options.
// teardown disconnects cseClient and closes clientEnc when they are present.
type cseProseTest struct {
	coll         *mongo.Collection // collection db.coll
	kvClient     *mongo.Client
	keyVaultColl *mongo.Collection
	cseClient    *mongo.Client     // encrypted client
	cseColl      *mongo.Collection // db.coll with encrypted client
	clientEnc    *mongo.ClientEncryption
	cseStarted   []*event.CommandStartedEvent
}
1110
1111func setup(mt *mtest.T, aeo *options.AutoEncryptionOptions, kvClientOpts *options.ClientOptions,
1112	ceo *options.ClientEncryptionOptions) *cseProseTest {
1113
1114	mt.Helper()
1115	var cpt cseProseTest
1116	var err error
1117
1118	cpt.coll = mt.CreateCollection(mtest.Collection{
1119		Name: "coll",
1120		DB:   "db",
1121		Opts: options.Collection().SetWriteConcern(mtest.MajorityWc),
1122	}, false)
1123	cpt.keyVaultColl = mt.CreateCollection(mtest.Collection{
1124		Name: "datakeys",
1125		DB:   "keyvault",
1126		Opts: options.Collection().SetWriteConcern(mtest.MajorityWc),
1127	}, false)
1128
1129	if aeo != nil {
1130		cseMonitor := &event.CommandMonitor{
1131			Started: func(_ context.Context, evt *event.CommandStartedEvent) {
1132				cpt.cseStarted = append(cpt.cseStarted, evt)
1133			},
1134		}
1135		opts := options.Client().ApplyURI(mtest.ClusterURI()).SetWriteConcern(mtest.MajorityWc).
1136			SetReadPreference(mtest.PrimaryRp).SetAutoEncryptionOptions(aeo).SetMonitor(cseMonitor)
1137		cpt.cseClient, err = mongo.Connect(mtest.Background, opts)
1138		assert.Nil(mt, err, "Connect error for encrypted client: %v", err)
1139		cpt.cseColl = cpt.cseClient.Database("db").Collection("coll")
1140	}
1141	if ceo != nil {
1142		cpt.kvClient, err = mongo.Connect(mtest.Background, kvClientOpts)
1143		assert.Nil(mt, err, "Connect error for ClientEncryption key vault client: %v", err)
1144		cpt.clientEnc, err = mongo.NewClientEncryption(cpt.kvClient, ceo)
1145		assert.Nil(mt, err, "NewClientEncryption error: %v", err)
1146	}
1147	return &cpt
1148}
1149
1150func (cpt *cseProseTest) teardown(mt *mtest.T) {
1151	mt.Helper()
1152
1153	if cpt.cseClient != nil {
1154		_ = cpt.cseClient.Disconnect(mtest.Background)
1155	}
1156	if cpt.clientEnc != nil {
1157		_ = cpt.clientEnc.Close(mtest.Background)
1158	}
1159}
1160
1161func readJSONFile(mt *mtest.T, file string) bson.Raw {
1162	mt.Helper()
1163
1164	content, err := ioutil.ReadFile(filepath.Join(clientEncryptionProseDir, file))
1165	assert.Nil(mt, err, "ReadFile error for %v: %v", file, err)
1166
1167	var doc bson.Raw
1168	err = bson.UnmarshalExtJSON(content, true, &doc)
1169	assert.Nil(mt, err, "UnmarshalExtJSON error for file %v: %v", file, err)
1170	return doc
1171}
1172
1173func decodeJSONFile(mt *mtest.T, file string, val interface{}) bson.Raw {
1174	mt.Helper()
1175
1176	content, err := ioutil.ReadFile(filepath.Join(clientEncryptionProseDir, file))
1177	assert.Nil(mt, err, "ReadFile error for %v: %v", file, err)
1178
1179	var doc bson.Raw
1180	err = bson.UnmarshalExtJSON(content, true, val)
1181	assert.Nil(mt, err, "UnmarshalExtJSON error for file %v: %v", file, err)
1182	return doc
1183}
1184
1185func rawValueToCoreValue(rv bson.RawValue) bsoncore.Value {
1186	return bsoncore.Value{Type: rv.Type, Data: rv.Value}
1187}
1188
// deadlockTest holds the fixtures shared by the "deadlock tests" cases; it is built by newDeadlockTest and released
// with disconnect.
//
// clientTest is a client without auto encryption, used for fixture setup and for inserting pre-encrypted data.
// clientKeyVaultOpts are unconnected options for an optional key vault client; their command monitor appends to
// clientKeyVaultEvents. clientEncryption performs explicit encryption with clientTest as its key vault client, and
// ciphertext is the deterministic encryption of "string0" under the key alt name "local".
type deadlockTest struct {
	clientTest           *mongo.Client
	clientKeyVaultOpts   *options.ClientOptions
	clientKeyVaultEvents []startedEvent
	clientEncryption     *mongo.ClientEncryption
	ciphertext           primitive.Binary
}
1196
// startedEvent is a compact record of a CommandStartedEvent: the command name and the database it ran against. The
// deadlock tests compare sequences of these to check which client issued which commands.
type startedEvent struct {
	Command  string
	Database string
}
1201
1202func newDeadlockTest(mt *mtest.T) *deadlockTest {
1203	mt.Helper()
1204
1205	var d deadlockTest
1206	var err error
1207
1208	clientTestOpts := options.Client().ApplyURI(mtest.ClusterURI()).SetWriteConcern(mtest.MajorityWc)
1209	if d.clientTest, err = mongo.Connect(mtest.Background, clientTestOpts); err != nil {
1210		mt.Fatalf("Connect error: %v", err)
1211	}
1212
1213	clientKeyVaultMonitor := &event.CommandMonitor{
1214		Started: func(ctx context.Context, event *event.CommandStartedEvent) {
1215			startedEvent := startedEvent{event.CommandName, event.DatabaseName}
1216			d.clientKeyVaultEvents = append(d.clientKeyVaultEvents, startedEvent)
1217		},
1218	}
1219
1220	d.clientKeyVaultOpts = options.Client().ApplyURI(mtest.ClusterURI()).
1221		SetMaxPoolSize(1).SetMonitor(clientKeyVaultMonitor)
1222
1223	keyvaultColl := d.clientTest.Database("keyvault").Collection("datakeys")
1224	dataColl := d.clientTest.Database("db").Collection("coll")
1225	err = dataColl.Drop(mtest.Background)
1226	assert.Nil(mt, err, "Drop error for collection db.coll: %v", err)
1227
1228	err = keyvaultColl.Drop(mtest.Background)
1229	assert.Nil(mt, err, "Drop error for key vault collection: %v", err)
1230
1231	keyDoc := readJSONFile(mt, "external-key.json")
1232	_, err = keyvaultColl.InsertOne(mtest.Background, keyDoc)
1233	assert.Nil(mt, err, "InsertOne error into key vault collection: %v", err)
1234
1235	schemaDoc := readJSONFile(mt, "external-schema.json")
1236	createOpts := options.CreateCollection().SetValidator(bson.M{"$jsonSchema": schemaDoc})
1237	err = d.clientTest.Database("db").CreateCollection(mtest.Background, "coll", createOpts)
1238	assert.Nil(mt, err, "CreateCollection error: %v", err)
1239
1240	kmsProviders := map[string]map[string]interface{}{
1241		"local": {"key": localMasterKey},
1242	}
1243	ceOpts := options.ClientEncryption().SetKmsProviders(kmsProviders).SetKeyVaultNamespace("keyvault.datakeys")
1244	d.clientEncryption, err = mongo.NewClientEncryption(d.clientTest, ceOpts)
1245	assert.Nil(mt, err, "NewClientEncryption error: %v", err)
1246
1247	t, value, err := bson.MarshalValue("string0")
1248	assert.Nil(mt, err, "MarshalValue error: %v", err)
1249	in := bson.RawValue{Type: t, Value: value}
1250	eopts := options.Encrypt().SetAlgorithm("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").SetKeyAltName("local")
1251	d.ciphertext, err = d.clientEncryption.Encrypt(mtest.Background, in, eopts)
1252	assert.Nil(mt, err, "Encrypt error: %v", err)
1253
1254	return &d
1255}
1256
1257func (d *deadlockTest) disconnect(mt *mtest.T) {
1258	mt.Helper()
1259	err := d.clientEncryption.Close(mtest.Background)
1260	assert.Nil(mt, err, "clientEncryption Close error: %v", err)
1261	d.clientTest.Disconnect(mtest.Background)
1262	assert.Nil(mt, err, "clientTest Disconnect error: %v", err)
1263}
1264