1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"bytes"
11	"context"
12	"encoding/json"
13	"math"
14	"strings"
15	"testing"
16
17	"github.com/stretchr/testify/require"
18	"go.mongodb.org/mongo-driver/bson"
19	"go.mongodb.org/mongo-driver/internal/testutil/helpers"
20	"go.mongodb.org/mongo-driver/mongo/options"
21	"go.mongodb.org/mongo-driver/mongo/readconcern"
22	"go.mongodb.org/mongo-driver/mongo/writeconcern"
23	"go.mongodb.org/mongo-driver/x/bsonx"
24)
25
26// Various helper functions for crud related operations
27
28// Mutates the client to add options
29func addClientOptions(c *Client, opts map[string]interface{}) {
30	for name, opt := range opts {
31		switch name {
32		case "retryWrites":
33			c.retryWrites = opt.(bool)
34		case "w":
35			switch opt.(type) {
36			case float64:
37				c.writeConcern = writeconcern.New(writeconcern.W(int(opt.(float64))))
38			case string:
39				c.writeConcern = writeconcern.New(writeconcern.WMajority())
40			}
41		case "readConcernLevel":
42			c.readConcern = readconcern.New(readconcern.Level(opt.(string)))
43		case "readPreference":
44			c.readPreference = readPrefFromString(opt.(string))
45		}
46	}
47}
48
49// Mutates the collection to add options
50func addCollectionOptions(c *Collection, opts map[string]interface{}) {
51	for name, opt := range opts {
52		switch name {
53		case "readConcern":
54			c.readConcern = getReadConcern(opt)
55		case "writeConcern":
56			c.writeConcern = getWriteConcern(opt)
57		case "readPreference":
58			c.readPreference = readPrefFromString(opt.(map[string]interface{})["mode"].(string))
59		}
60	}
61}
62
63func executeCount(sess *sessionImpl, coll *Collection, args map[string]interface{}) (int64, error) {
64	var filter map[string]interface{}
65	opts := options.Count()
66	for name, opt := range args {
67		switch name {
68		case "filter":
69			filter = opt.(map[string]interface{})
70		case "skip":
71			opts = opts.SetSkip(int64(opt.(float64)))
72		case "limit":
73			opts = opts.SetLimit(int64(opt.(float64)))
74		case "collation":
75			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
76		}
77	}
78
79	if sess != nil {
80		// EXAMPLE:
81		sessCtx := sessionContext{
82			Context: context.WithValue(ctx, sessionKey{}, sess),
83			Session: sess,
84		}
85		return coll.CountDocuments(sessCtx, filter, opts)
86	}
87	return coll.CountDocuments(ctx, filter, opts)
88}
89
90func executeCountDocuments(sess *sessionImpl, coll *Collection, args map[string]interface{}) (int64, error) {
91	var filter map[string]interface{}
92	opts := options.Count()
93	for name, opt := range args {
94		switch name {
95		case "filter":
96			filter = opt.(map[string]interface{})
97		case "skip":
98			opts = opts.SetSkip(int64(opt.(float64)))
99		case "limit":
100			opts = opts.SetLimit(int64(opt.(float64)))
101		case "collation":
102			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
103		}
104	}
105
106	if sess != nil {
107		// EXAMPLE:
108		sessCtx := sessionContext{
109			Context: context.WithValue(ctx, sessionKey{}, sess),
110			Session: sess,
111		}
112		return coll.CountDocuments(sessCtx, filter, opts)
113	}
114	return coll.CountDocuments(ctx, filter, opts)
115}
116
117func executeDistinct(sess *sessionImpl, coll *Collection, args map[string]interface{}) ([]interface{}, error) {
118	var fieldName string
119	var filter map[string]interface{}
120	opts := options.Distinct()
121	for name, opt := range args {
122		switch name {
123		case "filter":
124			filter = opt.(map[string]interface{})
125		case "fieldName":
126			fieldName = opt.(string)
127		case "collation":
128			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
129		}
130	}
131
132	if sess != nil {
133		sessCtx := sessionContext{
134			Context: context.WithValue(ctx, sessionKey{}, sess),
135			Session: sess,
136		}
137		return coll.Distinct(sessCtx, fieldName, filter, opts)
138	}
139	return coll.Distinct(ctx, fieldName, filter, opts)
140}
141
142func executeInsertOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*InsertOneResult, error) {
143	document := args["document"].(map[string]interface{})
144
145	// For some reason, the insertion document is unmarshaled with a float rather than integer,
146	// but the documents that are used to initially populate the collection are unmarshaled
147	// correctly with integers. To ensure that the tests can correctly compare them, we iterate
148	// through the insertion document and change any valid integers stored as floats to actual
149	// integers.
150	replaceFloatsWithInts(document)
151
152	if sess != nil {
153		sessCtx := sessionContext{
154			Context: context.WithValue(ctx, sessionKey{}, sess),
155			Session: sess,
156		}
157		return coll.InsertOne(sessCtx, document)
158	}
159	return coll.InsertOne(context.Background(), document)
160}
161
162func executeInsertMany(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*InsertManyResult, error) {
163	documents := args["documents"].([]interface{})
164
165	// For some reason, the insertion documents are unmarshaled with a float rather than
166	// integer, but the documents that are used to initially populate the collection are
167	// unmarshaled correctly with integers. To ensure that the tests can correctly compare
168	// them, we iterate through the insertion documents and change any valid integers stored
169	// as floats to actual integers.
170	for i, doc := range documents {
171		docM := doc.(map[string]interface{})
172		replaceFloatsWithInts(docM)
173
174		documents[i] = docM
175	}
176
177	if sess != nil {
178		sessCtx := sessionContext{
179			Context: context.WithValue(ctx, sessionKey{}, sess),
180			Session: sess,
181		}
182		return coll.InsertMany(sessCtx, documents)
183	}
184	return coll.InsertMany(context.Background(), documents)
185}
186
187func executeFind(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*Cursor, error) {
188	opts := options.Find()
189	var filter map[string]interface{}
190	for name, opt := range args {
191		switch name {
192		case "filter":
193			filter = opt.(map[string]interface{})
194		case "sort":
195			opts = opts.SetSort(opt.(map[string]interface{}))
196		case "skip":
197			opts = opts.SetSkip(int64(opt.(float64)))
198		case "limit":
199			opts = opts.SetLimit(int64(opt.(float64)))
200		case "batchSize":
201			opts = opts.SetBatchSize(int32(opt.(float64)))
202		case "collation":
203			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
204		}
205	}
206
207	if sess != nil {
208		sessCtx := sessionContext{
209			Context: context.WithValue(ctx, sessionKey{}, sess),
210			Session: sess,
211		}
212		return coll.Find(sessCtx, filter, opts)
213	}
214	return coll.Find(ctx, filter, opts)
215}
216
217func executeFindOneAndDelete(sess *sessionImpl, coll *Collection, args map[string]interface{}) *SingleResult {
218	opts := options.FindOneAndDelete()
219	var filter map[string]interface{}
220	for name, opt := range args {
221		switch name {
222		case "filter":
223			filter = opt.(map[string]interface{})
224		case "sort":
225			opts = opts.SetSort(opt.(map[string]interface{}))
226		case "projection":
227			opts = opts.SetProjection(opt.(map[string]interface{}))
228		case "collation":
229			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
230		}
231	}
232
233	if sess != nil {
234		sessCtx := sessionContext{
235			Context: context.WithValue(ctx, sessionKey{}, sess),
236			Session: sess,
237		}
238		return coll.FindOneAndDelete(sessCtx, filter, opts)
239	}
240	return coll.FindOneAndDelete(ctx, filter, opts)
241}
242
243func executeFindOneAndUpdate(sess *sessionImpl, coll *Collection, args map[string]interface{}) *SingleResult {
244	opts := options.FindOneAndUpdate()
245	var filter map[string]interface{}
246	var update map[string]interface{}
247	for name, opt := range args {
248		switch name {
249		case "filter":
250			filter = opt.(map[string]interface{})
251		case "update":
252			update = opt.(map[string]interface{})
253		case "arrayFilters":
254			opts = opts.SetArrayFilters(options.ArrayFilters{
255				Filters: opt.([]interface{}),
256			})
257		case "sort":
258			opts = opts.SetSort(opt.(map[string]interface{}))
259		case "projection":
260			opts = opts.SetProjection(opt.(map[string]interface{}))
261		case "upsert":
262			opts = opts.SetUpsert(opt.(bool))
263		case "returnDocument":
264			switch opt.(string) {
265			case "After":
266				opts = opts.SetReturnDocument(options.After)
267			case "Before":
268				opts = opts.SetReturnDocument(options.Before)
269			}
270		case "collation":
271			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
272		}
273	}
274
275	// For some reason, the filter and update documents are unmarshaled with floats
276	// rather than integers, but the documents that are used to initially populate the
277	// collection are unmarshaled correctly with integers. To ensure that the tests can
278	// correctly compare them, we iterate through the filter and replacement documents and
279	// change any valid integers stored as floats to actual integers.
280	replaceFloatsWithInts(filter)
281	replaceFloatsWithInts(update)
282
283	if sess != nil {
284		sessCtx := sessionContext{
285			Context: context.WithValue(ctx, sessionKey{}, sess),
286			Session: sess,
287		}
288		return coll.FindOneAndUpdate(sessCtx, filter, update, opts)
289	}
290	return coll.FindOneAndUpdate(ctx, filter, update, opts)
291}
292
293func executeFindOneAndReplace(sess *sessionImpl, coll *Collection, args map[string]interface{}) *SingleResult {
294	opts := options.FindOneAndReplace()
295	var filter map[string]interface{}
296	var replacement map[string]interface{}
297	for name, opt := range args {
298		switch name {
299		case "filter":
300			filter = opt.(map[string]interface{})
301		case "replacement":
302			replacement = opt.(map[string]interface{})
303		case "sort":
304			opts = opts.SetSort(opt.(map[string]interface{}))
305		case "projection":
306			opts = opts.SetProjection(opt.(map[string]interface{}))
307		case "upsert":
308			opts = opts.SetUpsert(opt.(bool))
309		case "returnDocument":
310			switch opt.(string) {
311			case "After":
312				opts = opts.SetReturnDocument(options.After)
313			case "Before":
314				opts = opts.SetReturnDocument(options.Before)
315			}
316		case "collation":
317			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
318		}
319	}
320
321	// For some reason, the filter and replacement documents are unmarshaled with floats
322	// rather than integers, but the documents that are used to initially populate the
323	// collection are unmarshaled correctly with integers. To ensure that the tests can
324	// correctly compare them, we iterate through the filter and replacement documents and
325	// change any valid integers stored as floats to actual integers.
326	replaceFloatsWithInts(filter)
327	replaceFloatsWithInts(replacement)
328
329	if sess != nil {
330		sessCtx := sessionContext{
331			Context: context.WithValue(ctx, sessionKey{}, sess),
332			Session: sess,
333		}
334		return coll.FindOneAndReplace(sessCtx, filter, replacement, opts)
335	}
336	return coll.FindOneAndReplace(ctx, filter, replacement, opts)
337}
338
339func executeDeleteOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*DeleteResult, error) {
340	opts := options.Delete()
341	var filter map[string]interface{}
342	for name, opt := range args {
343		switch name {
344		case "filter":
345			filter = opt.(map[string]interface{})
346		case "collation":
347			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
348		}
349	}
350
351	// For some reason, the filter document is unmarshaled with floats
352	// rather than integers, but the documents that are used to initially populate the
353	// collection are unmarshaled correctly with integers. To ensure that the tests can
354	// correctly compare them, we iterate through the filter and replacement documents and
355	// change any valid integers stored as floats to actual integers.
356	replaceFloatsWithInts(filter)
357
358	if sess != nil {
359		sessCtx := sessionContext{
360			Context: context.WithValue(ctx, sessionKey{}, sess),
361			Session: sess,
362		}
363		return coll.DeleteOne(sessCtx, filter, opts)
364	}
365	return coll.DeleteOne(ctx, filter, opts)
366}
367
368func executeDeleteMany(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*DeleteResult, error) {
369	opts := options.Delete()
370	var filter map[string]interface{}
371	for name, opt := range args {
372		switch name {
373		case "filter":
374			filter = opt.(map[string]interface{})
375		case "collation":
376			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
377		}
378	}
379
380	// For some reason, the filter document is unmarshaled with floats
381	// rather than integers, but the documents that are used to initially populate the
382	// collection are unmarshaled correctly with integers. To ensure that the tests can
383	// correctly compare them, we iterate through the filter and replacement documents and
384	// change any valid integers stored as floats to actual integers.
385	replaceFloatsWithInts(filter)
386
387	if sess != nil {
388		sessCtx := sessionContext{
389			Context: context.WithValue(ctx, sessionKey{}, sess),
390			Session: sess,
391		}
392		return coll.DeleteMany(sessCtx, filter, opts)
393	}
394	return coll.DeleteMany(ctx, filter, opts)
395}
396
397func executeReplaceOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*UpdateResult, error) {
398	opts := options.Replace()
399	var filter map[string]interface{}
400	var replacement map[string]interface{}
401	for name, opt := range args {
402		switch name {
403		case "filter":
404			filter = opt.(map[string]interface{})
405		case "replacement":
406			replacement = opt.(map[string]interface{})
407		case "upsert":
408			opts = opts.SetUpsert(opt.(bool))
409		case "collation":
410			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
411		}
412	}
413
414	// For some reason, the filter and replacement documents are unmarshaled with floats
415	// rather than integers, but the documents that are used to initially populate the
416	// collection are unmarshaled correctly with integers. To ensure that the tests can
417	// correctly compare them, we iterate through the filter and replacement documents and
418	// change any valid integers stored as floats to actual integers.
419	replaceFloatsWithInts(filter)
420	replaceFloatsWithInts(replacement)
421
422	// TODO temporarily default upsert to false explicitly to make test pass
423	// because we do not send upsert=false by default
424	//opts = opts.SetUpsert(false)
425	if opts.Upsert == nil {
426		opts = opts.SetUpsert(false)
427	}
428	if sess != nil {
429		sessCtx := sessionContext{
430			Context: context.WithValue(ctx, sessionKey{}, sess),
431			Session: sess,
432		}
433		return coll.ReplaceOne(sessCtx, filter, replacement, opts)
434	}
435	return coll.ReplaceOne(ctx, filter, replacement, opts)
436}
437
438func executeUpdateOne(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*UpdateResult, error) {
439	opts := options.Update()
440	var filter map[string]interface{}
441	var update map[string]interface{}
442	for name, opt := range args {
443		switch name {
444		case "filter":
445			filter = opt.(map[string]interface{})
446		case "update":
447			update = opt.(map[string]interface{})
448		case "arrayFilters":
449			opts = opts.SetArrayFilters(options.ArrayFilters{Filters: opt.([]interface{})})
450		case "upsert":
451			opts = opts.SetUpsert(opt.(bool))
452		case "collation":
453			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
454		}
455	}
456
457	// For some reason, the filter and update documents are unmarshaled with floats
458	// rather than integers, but the documents that are used to initially populate the
459	// collection are unmarshaled correctly with integers. To ensure that the tests can
460	// correctly compare them, we iterate through the filter and replacement documents and
461	// change any valid integers stored as floats to actual integers.
462	replaceFloatsWithInts(filter)
463	replaceFloatsWithInts(update)
464
465	if opts.Upsert == nil {
466		opts = opts.SetUpsert(false)
467	}
468	if sess != nil {
469		sessCtx := sessionContext{
470			Context: context.WithValue(ctx, sessionKey{}, sess),
471			Session: sess,
472		}
473		return coll.UpdateOne(sessCtx, filter, update, opts)
474	}
475	return coll.UpdateOne(ctx, filter, update, opts)
476}
477
478func executeUpdateMany(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*UpdateResult, error) {
479	opts := options.Update()
480	var filter map[string]interface{}
481	var update map[string]interface{}
482	for name, opt := range args {
483		switch name {
484		case "filter":
485			filter = opt.(map[string]interface{})
486		case "update":
487			update = opt.(map[string]interface{})
488		case "arrayFilters":
489			opts = opts.SetArrayFilters(options.ArrayFilters{Filters: opt.([]interface{})})
490		case "upsert":
491			opts = opts.SetUpsert(opt.(bool))
492		case "collation":
493			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
494		}
495	}
496
497	// For some reason, the filter and update documents are unmarshaled with floats
498	// rather than integers, but the documents that are used to initially populate the
499	// collection are unmarshaled correctly with integers. To ensure that the tests can
500	// correctly compare them, we iterate through the filter and replacement documents and
501	// change any valid integers stored as floats to actual integers.
502	replaceFloatsWithInts(filter)
503	replaceFloatsWithInts(update)
504
505	if opts.Upsert == nil {
506		opts = opts.SetUpsert(false)
507	}
508	if sess != nil {
509		sessCtx := sessionContext{
510			Context: context.WithValue(ctx, sessionKey{}, sess),
511			Session: sess,
512		}
513		return coll.UpdateMany(sessCtx, filter, update, opts)
514	}
515	return coll.UpdateMany(ctx, filter, update, opts)
516}
517
518func executeAggregate(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*Cursor, error) {
519	var pipeline []interface{}
520	opts := options.Aggregate()
521	for name, opt := range args {
522		switch name {
523		case "pipeline":
524			pipeline = opt.([]interface{})
525		case "batchSize":
526			opts = opts.SetBatchSize(int32(opt.(float64)))
527		case "collation":
528			opts = opts.SetCollation(collationFromMap(opt.(map[string]interface{})))
529		}
530	}
531
532	if sess != nil {
533		sessCtx := sessionContext{
534			Context: context.WithValue(ctx, sessionKey{}, sess),
535			Session: sess,
536		}
537		return coll.Aggregate(sessCtx, pipeline, opts)
538	}
539	return coll.Aggregate(ctx, pipeline, opts)
540}
541
542func executeRunCommand(sess Session, db *Database, argmap map[string]interface{}, args json.RawMessage) *SingleResult {
543	var cmd bsonx.Doc
544	opts := options.RunCmd()
545	for name, opt := range argmap {
546		switch name {
547		case "command":
548			argBytes, err := args.MarshalJSON()
549			if err != nil {
550				return &SingleResult{err: err}
551			}
552
553			var argCmdStruct struct {
554				Cmd json.RawMessage `json:"command"`
555			}
556			err = json.NewDecoder(bytes.NewBuffer(argBytes)).Decode(&argCmdStruct)
557			if err != nil {
558				return &SingleResult{err: err}
559			}
560
561			err = bson.UnmarshalExtJSON(argCmdStruct.Cmd, true, &cmd)
562			if err != nil {
563				return &SingleResult{err: err}
564			}
565		case "readPreference":
566			opts = opts.SetReadPreference(getReadPref(opt))
567		}
568	}
569
570	if sess != nil {
571		sessCtx := sessionContext{
572			Context: context.WithValue(ctx, sessionKey{}, sess),
573			Session: sess,
574		}
575		return db.RunCommand(sessCtx, cmd, opts)
576	}
577	return db.RunCommand(ctx, cmd, opts)
578}
579
580func verifyBulkWriteResult(t *testing.T, res *BulkWriteResult, result json.RawMessage) {
581	expectedBytes, err := result.MarshalJSON()
582	require.NoError(t, err)
583
584	var expected BulkWriteResult
585	err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected)
586	require.NoError(t, err)
587
588	require.Equal(t, expected.DeletedCount, res.DeletedCount)
589	require.Equal(t, expected.InsertedCount, res.InsertedCount)
590	require.Equal(t, expected.MatchedCount, res.MatchedCount)
591	require.Equal(t, expected.ModifiedCount, res.ModifiedCount)
592	require.Equal(t, expected.UpsertedCount, res.UpsertedCount)
593
594	// replace floats with ints
595	for opID, upsertID := range expected.UpsertedIDs {
596		if floatID, ok := upsertID.(float64); ok {
597			expected.UpsertedIDs[opID] = int32(floatID)
598		}
599	}
600
601	for operationID, upsertID := range expected.UpsertedIDs {
602		require.Equal(t, upsertID, res.UpsertedIDs[operationID])
603	}
604}
605
606func verifyInsertOneResult(t *testing.T, res *InsertOneResult, result json.RawMessage) {
607	expectedBytes, err := result.MarshalJSON()
608	require.NoError(t, err)
609
610	var expected InsertOneResult
611	err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected)
612	require.NoError(t, err)
613
614	expectedID := expected.InsertedID
615	if f, ok := expectedID.(float64); ok && f == math.Floor(f) {
616		expectedID = int32(f)
617	}
618
619	if expectedID != nil {
620		require.NotNil(t, res)
621		require.Equal(t, expectedID, res.InsertedID)
622	}
623}
624
625func verifyInsertManyResult(t *testing.T, res *InsertManyResult, result json.RawMessage) {
626	expectedBytes, err := result.MarshalJSON()
627	require.NoError(t, err)
628
629	var expected struct{ InsertedIds map[string]interface{} }
630	err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected)
631	require.NoError(t, err)
632
633	if expected.InsertedIds != nil {
634		replaceFloatsWithInts(expected.InsertedIds)
635
636		for _, val := range expected.InsertedIds {
637			require.Contains(t, res.InsertedIDs, val)
638		}
639	}
640}
641
642func verifyCursorResult2(t *testing.T, cur *Cursor, result json.RawMessage) {
643	for _, expected := range docSliceFromRaw(t, result) {
644		require.NotNil(t, cur)
645		require.True(t, cur.Next(context.Background()))
646
647		var actual bsonx.Doc
648		require.NoError(t, cur.Decode(&actual))
649
650		compareDocs(t, expected, actual)
651	}
652
653	require.False(t, cur.Next(ctx))
654	require.NoError(t, cur.Err())
655}
656
657func verifyCursorResult(t *testing.T, cur *Cursor, result json.RawMessage) {
658	for _, expected := range docSliceFromRaw(t, result) {
659		require.NotNil(t, cur)
660		require.True(t, cur.Next(context.Background()))
661
662		var actual bsonx.Doc
663		require.NoError(t, cur.Decode(&actual))
664
665		compareDocs(t, expected, actual)
666	}
667
668	require.False(t, cur.Next(ctx))
669	require.NoError(t, cur.Err())
670}
671
672func verifySingleResult(t *testing.T, res *SingleResult, result json.RawMessage) {
673	jsonBytes, err := result.MarshalJSON()
674	require.NoError(t, err)
675
676	var actual bsonx.Doc
677	err = res.Decode(&actual)
678	if err == ErrNoDocuments {
679		var expected map[string]interface{}
680		err := json.NewDecoder(bytes.NewBuffer(jsonBytes)).Decode(&expected)
681		require.NoError(t, err)
682
683		require.Nil(t, expected)
684		return
685	}
686
687	require.NoError(t, err)
688
689	doc := bsonx.Doc{}
690	err = bson.UnmarshalExtJSON(jsonBytes, true, &doc)
691	require.NoError(t, err)
692
693	require.True(t, doc.Equal(actual))
694}
695
696func verifyDistinctResult(t *testing.T, res []interface{}, result json.RawMessage) {
697	resultBytes, err := result.MarshalJSON()
698	require.NoError(t, err)
699
700	var expected []interface{}
701	require.NoError(t, json.NewDecoder(bytes.NewBuffer(resultBytes)).Decode(&expected))
702
703	require.Equal(t, len(expected), len(res))
704
705	for i := range expected {
706		expectedElem := expected[i]
707		actualElem := res[i]
708
709		iExpected := testhelpers.GetIntFromInterface(expectedElem)
710		iActual := testhelpers.GetIntFromInterface(actualElem)
711
712		require.Equal(t, iExpected == nil, iActual == nil)
713		if iExpected != nil {
714			require.Equal(t, *iExpected, *iActual)
715			continue
716		}
717
718		require.Equal(t, expected[i], res[i])
719	}
720}
721
722func verifyDeleteResult(t *testing.T, res *DeleteResult, result json.RawMessage) {
723	expectedBytes, err := result.MarshalJSON()
724	require.NoError(t, err)
725
726	var expected DeleteResult
727	err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected)
728	require.NoError(t, err)
729
730	require.Equal(t, expected.DeletedCount, res.DeletedCount)
731}
732
733func verifyUpdateResult(t *testing.T, res *UpdateResult, result json.RawMessage) {
734	expectedBytes, err := result.MarshalJSON()
735	require.NoError(t, err)
736
737	var expected struct {
738		MatchedCount  int64
739		ModifiedCount int64
740		UpsertedCount int64
741	}
742	err = json.NewDecoder(bytes.NewBuffer(expectedBytes)).Decode(&expected)
743
744	require.Equal(t, expected.MatchedCount, res.MatchedCount)
745	require.Equal(t, expected.ModifiedCount, res.ModifiedCount)
746
747	actualUpsertedCount := int64(0)
748	if res.UpsertedID != nil {
749		actualUpsertedCount = 1
750	}
751
752	require.Equal(t, expected.UpsertedCount, actualUpsertedCount)
753}
754
755func verifyRunCommandResult(t *testing.T, res bson.Raw, result json.RawMessage) {
756	if len(result) == 0 {
757		return
758	}
759	jsonBytes, err := result.MarshalJSON()
760	require.NoError(t, err)
761
762	expected := bsonx.Doc{}
763	err = bson.UnmarshalExtJSON(jsonBytes, true, &expected)
764	require.NoError(t, err)
765
766	require.NotNil(t, res)
767	actual, err := bsonx.ReadDoc(res)
768	require.NoError(t, err)
769
770	// All runcommand results in tests are for key "n" only
771	compareElements(t, expected.LookupElement("n"), actual.LookupElement("n"))
772}
773
774func verifyCollectionContents(t *testing.T, coll *Collection, result json.RawMessage) {
775	cursor, err := coll.Find(context.Background(), bsonx.Doc{})
776	require.NoError(t, err)
777
778	verifyCursorResult(t, cursor, result)
779}
780
// sanitizeCollectionName turns a spec-test name into a usable collection name under the
// namespace prefix kind.
//
// Collection names cannot contain "$", so every "$" becomes "%". Namespaces are limited
// to 120 bytes, so the name is cut (byte-wise) until kind + "." + name is at most 119
// bytes long.
func sanitizeCollectionName(kind string, name string) string {
	sanitized := strings.Replace(name, "$", "%", -1)

	prefixLen := len(kind) + 1 // kind plus the "." separator
	if prefixLen+len(sanitized) >= 119 {
		sanitized = sanitized[:119-prefixLen]
	}
	return sanitized
}
792
793func compareElements(t *testing.T, expected bsonx.Elem, actual bsonx.Elem) {
794	if expected.Value.IsNumber() {
795		if expectedNum, ok := expected.Value.Int64OK(); ok {
796			switch actual.Value.Type() {
797			case bson.TypeInt32:
798				require.Equal(t, expectedNum, int64(actual.Value.Int32()), "For key %v", expected.Key)
799			case bson.TypeInt64:
800				require.Equal(t, expectedNum, actual.Value.Int64(), "For key %v\n", expected.Key)
801			case bson.TypeDouble:
802				require.Equal(t, expectedNum, int64(actual.Value.Double()), "For key %v\n", expected.Key)
803			}
804		} else {
805			expectedNum := expected.Value.Int32()
806			switch actual.Value.Type() {
807			case bson.TypeInt32:
808				require.Equal(t, expectedNum, actual.Value.Int32(), "For key %v", expected.Key)
809			case bson.TypeInt64:
810				require.Equal(t, expectedNum, int32(actual.Value.Int64()), "For key %v\n", expected.Key)
811			case bson.TypeDouble:
812				require.Equal(t, expectedNum, int32(actual.Value.Double()), "For key %v\n", expected.Key)
813			}
814		}
815	} else if conv, ok := expected.Value.DocumentOK(); ok {
816		actualConv, actualOk := actual.Value.DocumentOK()
817		require.True(t, actualOk)
818		compareDocs(t, conv, actualConv)
819	} else if conv, ok := expected.Value.ArrayOK(); ok {
820		actualConv, actualOk := actual.Value.ArrayOK()
821		require.True(t, actualOk)
822		compareArrays(t, conv, actualConv)
823	} else {
824		require.True(t, actual.Equal(expected), "For key %s, expected %v\nactual: %v", expected.Key, expected, actual)
825	}
826}
827
828func compareArrays(t *testing.T, expected bsonx.Arr, actual bsonx.Arr) {
829	if len(expected) != len(actual) {
830		t.Errorf("array length mismatch. expected %d got %d", len(expected), len(actual))
831		t.FailNow()
832	}
833
834	for idx := range expected {
835		expectedDoc := expected[idx].Document()
836		actualDoc := actual[idx].Document()
837		compareDocs(t, expectedDoc, actualDoc)
838	}
839}
840
841func collationFromMap(m map[string]interface{}) *options.Collation {
842	var collation options.Collation
843
844	if locale, found := m["locale"]; found {
845		collation.Locale = locale.(string)
846	}
847
848	if caseLevel, found := m["caseLevel"]; found {
849		collation.CaseLevel = caseLevel.(bool)
850	}
851
852	if caseFirst, found := m["caseFirst"]; found {
853		collation.CaseFirst = caseFirst.(string)
854	}
855
856	if strength, found := m["strength"]; found {
857		collation.Strength = int(strength.(float64))
858	}
859
860	if numericOrdering, found := m["numericOrdering"]; found {
861		collation.NumericOrdering = numericOrdering.(bool)
862	}
863
864	if alternate, found := m["alternate"]; found {
865		collation.Alternate = alternate.(string)
866	}
867
868	if maxVariable, found := m["maxVariable"]; found {
869		collation.MaxVariable = maxVariable.(string)
870	}
871
872	if normalization, found := m["normalization"]; found {
873		collation.Normalization = normalization.(bool)
874	}
875
876	if backwards, found := m["backwards"]; found {
877		collation.Backwards = backwards.(bool)
878	}
879
880	return &collation
881}
882
883func docSliceFromRaw(t *testing.T, raw json.RawMessage) []bsonx.Doc {
884	jsonBytes, err := raw.MarshalJSON()
885	require.NoError(t, err)
886
887	array := bsonx.Arr{}
888	err = bson.UnmarshalExtJSON(jsonBytes, true, &array)
889	require.NoError(t, err)
890
891	docs := make([]bsonx.Doc, 0)
892
893	for _, val := range array {
894		docs = append(docs, val.Document())
895	}
896
897	return docs
898}
899
900func docSliceToInterfaceSlice(docs []bsonx.Doc) []interface{} {
901	out := make([]interface{}, 0, len(docs))
902
903	for _, doc := range docs {
904		out = append(out, doc)
905	}
906
907	return out
908}
909
// replaceFloatsWithInts rewrites, in place, every whole-number float64 reachable from m
// through nested maps and []interface{} slices:
//   - values within the int32 range become int32;
//   - values within the int64 range become int64.
//
// Fractional values, NaN, ±Inf and values outside the int64 range stay float64, as does
// everything that is not a float64. A nil map is a no-op.
//
// encoding/json decodes every JSON number stored in an interface{} as float64, whereas
// the fixtures that seed collections are decoded with integer types. Normalizing lets
// the two compare equal. Previously, whole floats outside the int32 range (including
// ±Inf) were fed to int32(), an out-of-range conversion with implementation-defined
// results, and values inside arrays were never converted.
func replaceFloatsWithInts(m map[string]interface{}) {
	var normalize func(v interface{}) interface{}
	normalize = func(v interface{}) interface{} {
		switch val := v.(type) {
		case float64:
			if val != math.Floor(val) {
				// Fractional or NaN: not an integer.
				return val
			}
			switch {
			case val >= math.MinInt32 && val <= math.MaxInt32:
				return int32(val)
			case val >= math.MinInt64 && val < math.MaxInt64:
				// math.MaxInt64 rounds up to 2^63 as a float64, so the strict
				// bound keeps the conversion in range.
				return int64(val)
			}
			// ±Inf or beyond int64.
			return val
		case map[string]interface{}:
			for k, inner := range val {
				val[k] = normalize(inner)
			}
		case []interface{}:
			for i, inner := range val {
				val[i] = normalize(inner)
			}
		}
		return v
	}

	for key, val := range m {
		m[key] = normalize(val)
	}
}
923