1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"bytes"
11	"io/ioutil"
12	"path"
13	"reflect"
14	"testing"
15	"time"
16
17	"go.mongodb.org/mongo-driver/bson"
18	"go.mongodb.org/mongo-driver/bson/bsontype"
19	"go.mongodb.org/mongo-driver/internal/testutil/assert"
20	"go.mongodb.org/mongo-driver/mongo/readconcern"
21	"go.mongodb.org/mongo-driver/mongo/writeconcern"
22	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
23	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
24)
25
// Locations of the read/write concern spec tests. The connection-string and
// document suites are subdirectories of readWriteConcernTestsDir.
const (
	readWriteConcernTestsDir = "../data/read-write-concern"
	connstringTestsDir       = "connection-string"
	documentTestsDir         = "document"
)

var (
	// serverDefaultConcern is the BSON encoding of an empty document (int32
	// length 5 followed by the terminating null byte).
	serverDefaultConcern = []byte{5, 0, 0, 0, 0} // server default read concern and write concern is empty document
	// specTestRegistry maps embedded documents to bson.Raw in the type map, so
	// embedded documents decoded into untyped (interface{}) values become bson.Raw.
	specTestRegistry     = bson.NewRegistryBuilder().
				RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})).Build()
)
37
// connectionStringTestFile is the top-level shape of a JSON file in the
// connection-string spec test directory: a list of test cases.
type connectionStringTestFile struct {
	Tests []connectionStringTest `bson:"tests"`
}
41
// connectionStringTest is one connection-string spec test case. When Valid is
// false, parsing URI must fail. Otherwise the parsed connection string's
// read/write concern settings are compared against ReadConcern and
// WriteConcern, each of which is only checked when present (non-nil).
type connectionStringTest struct {
	Description  string   `bson:"description"`
	URI          string   `bson:"uri"`
	Valid        bool     `bson:"valid"`
	ReadConcern  bson.Raw `bson:"readConcern"`
	WriteConcern bson.Raw `bson:"writeConcern"`
}
49
// documentTestFile is the top-level shape of a JSON file in the document spec
// test directory: a list of test cases.
type documentTestFile struct {
	Tests []documentTest `bson:"tests"`
}
53
// documentTest is one document spec test case. ReadConcern / WriteConcern
// describe the concern to build with the driver's options. The matching
// *Document field holds the encoding expected for valid cases.
// IsServerDefault and IsAcknowledged are optional expectations that
// runDocumentTest checks only when they are non-nil.
type documentTest struct {
	Description          string    `bson:"description"`
	Valid                bool      `bson:"valid"`
	ReadConcern          bson.Raw  `bson:"readConcern"`
	ReadConcernDocument  *bson.Raw `bson:"readConcernDocument"`
	WriteConcern         bson.Raw  `bson:"writeConcern"`
	WriteConcernDocument *bson.Raw `bson:"writeConcernDocument"`
	IsServerDefault      *bool     `bson:"isServerDefault"`
	IsAcknowledged       *bool     `bson:"isAcknowledged"`
}
64
65func TestReadWriteConcernSpec(t *testing.T) {
66	t.Run("connstring", func(t *testing.T) {
67		for _, file := range jsonFilesInDir(t, path.Join(readWriteConcernTestsDir, connstringTestsDir)) {
68			t.Run(file, func(t *testing.T) {
69				runConnectionStringTestFile(t, path.Join(readWriteConcernTestsDir, connstringTestsDir, file))
70			})
71		}
72	})
73	t.Run("document", func(t *testing.T) {
74		for _, file := range jsonFilesInDir(t, path.Join(readWriteConcernTestsDir, documentTestsDir)) {
75			t.Run(file, func(t *testing.T) {
76				runDocumentTestFile(t, path.Join(readWriteConcernTestsDir, documentTestsDir, file))
77			})
78		}
79	})
80}
81
82func runConnectionStringTestFile(t *testing.T, filePath string) {
83	content, err := ioutil.ReadFile(filePath)
84	assert.Nil(t, err, "ReadFile error for %v: %v", filePath, err)
85
86	var testFile connectionStringTestFile
87	err = bson.UnmarshalExtJSONWithRegistry(specTestRegistry, content, false, &testFile)
88	assert.Nil(t, err, "UnmarshalExtJSONWithRegistry error: %v", err)
89
90	for _, test := range testFile.Tests {
91		t.Run(test.Description, func(t *testing.T) {
92			runConnectionStringTest(t, test)
93		})
94	}
95}
96
97func runConnectionStringTest(t *testing.T, test connectionStringTest) {
98	cs, err := connstring.Parse(test.URI)
99	if !test.Valid {
100		assert.NotNil(t, err, "expected Parse error, got nil")
101		return
102	}
103	assert.Nil(t, err, "Parse error: %v", err)
104
105	if test.ReadConcern != nil {
106		expected := readConcernFromRaw(t, test.ReadConcern)
107		assert.Equal(t, expected.GetLevel(), cs.ReadConcernLevel,
108			"expected level %v, got %v", expected.GetLevel(), cs.ReadConcernLevel)
109	}
110	if test.WriteConcern != nil {
111		expectedWc := writeConcernFromRaw(t, test.WriteConcern)
112		if expectedWc.wSet {
113			expected := expectedWc.GetW()
114			if _, ok := expected.(int); ok {
115				assert.True(t, cs.WNumberSet, "expected WNumberSet, got false")
116				assert.Equal(t, expected, cs.WNumber, "expected w value %v, got %v", expected, cs.WNumber)
117			} else {
118				assert.False(t, cs.WNumberSet, "expected WNumberSet to be false, got true")
119				assert.Equal(t, expected, cs.WString, "expected w value %v, got %v", expected, cs.WString)
120			}
121		}
122		if expectedWc.timeoutSet {
123			assert.True(t, cs.WTimeoutSet, "expected WTimeoutSet, got false")
124			assert.Equal(t, expectedWc.GetWTimeout(), cs.WTimeout,
125				"expected timeout value %v, got %v", expectedWc.GetWTimeout(), cs.WTimeout)
126		}
127		if expectedWc.jSet {
128			assert.True(t, cs.JSet, "expected JSet, got false")
129			assert.Equal(t, expectedWc.GetJ(), cs.J, "expected j value %v, got %v", expectedWc.GetJ(), cs.J)
130		}
131	}
132}
133
134func runDocumentTestFile(t *testing.T, filePath string) {
135	content, err := ioutil.ReadFile(filePath)
136	assert.Nil(t, err, "ReadFile error: %v", err)
137
138	var testFile documentTestFile
139	err = bson.UnmarshalExtJSONWithRegistry(specTestRegistry, content, false, &testFile)
140	assert.Nil(t, err, "UnmarshalExtJSONWithRegistry error: %v", err)
141
142	for _, test := range testFile.Tests {
143		t.Run(test.Description, func(t *testing.T) {
144			runDocumentTest(t, test)
145		})
146	}
147}
148
149func runDocumentTest(t *testing.T, test documentTest) {
150	if test.ReadConcern != nil {
151		_, actual, err := readConcernFromRaw(t, test.ReadConcern).MarshalBSONValue()
152		if !test.Valid {
153			assert.NotNil(t, err, "expected MarshalBSONValue error, got nil")
154		} else {
155			assert.Nil(t, err, "MarshalBSONValue error: %v", err)
156			compareDocuments(t, *test.ReadConcernDocument, actual)
157		}
158
159		if test.IsServerDefault != nil {
160			gotServerDefault := bytes.Equal(actual, serverDefaultConcern)
161			assert.Equal(t, *test.IsServerDefault, gotServerDefault, "expected server default read concern, got %s", actual)
162		}
163	}
164	if test.WriteConcern != nil {
165		actualWc := writeConcernFromRaw(t, test.WriteConcern)
166		_, actual, err := actualWc.MarshalBSONValue()
167		if !test.Valid {
168			assert.NotNil(t, err, "expected MarshalBSONValue error, got nil")
169			return
170		}
171		if test.IsAcknowledged != nil {
172			actualAck := actualWc.Acknowledged()
173			assert.Equal(t, *test.IsAcknowledged, actualAck,
174				"expected acknowledged %v, got %v", *test.IsAcknowledged, actualAck)
175		}
176
177		expected := *test.WriteConcernDocument
178		if err == writeconcern.ErrEmptyWriteConcern {
179			elems, _ := expected.Elements()
180			if len(elems) == 0 {
181				assert.NotNil(t, test.IsServerDefault, "expected write concern %s, got empty", expected)
182				assert.True(t, *test.IsServerDefault, "expected write concern %s, got empty", expected)
183				return
184			}
185			if _, jErr := expected.LookupErr("j"); jErr == nil && len(elems) == 1 {
186				return
187			}
188		}
189
190		assert.Nil(t, err, "MarshalBSONValue error: %v", err)
191		if jVal, err := expected.LookupErr("j"); err == nil && !jVal.Boolean() {
192			actual = actual[:len(actual)-1]
193			actual = bsoncore.AppendBooleanElement(actual, "j", false)
194			actual, _ = bsoncore.AppendDocumentEnd(actual, 0)
195		}
196		compareDocuments(t, expected, actual)
197	}
198}
199
200func readConcernFromRaw(t *testing.T, rc bson.Raw) *readconcern.ReadConcern {
201	t.Helper()
202
203	var opts []readconcern.Option
204	elems, _ := rc.Elements()
205	for _, elem := range elems {
206		key := elem.Key()
207		val := elem.Value()
208
209		switch key {
210		case "level":
211			opts = append(opts, readconcern.Level(val.StringValue()))
212		default:
213			t.Fatalf("unrecognized read concern field %v", key)
214		}
215	}
216	return readconcern.New(opts...)
217}
218
// writeConcern wraps a driver WriteConcern with flags recording which fields
// were explicitly present in the spec-test document it was built from, so
// callers can check only the options that were actually specified.
type writeConcern struct {
	*writeconcern.WriteConcern
	jSet       bool // "journal" was present
	wSet       bool // "w" was present
	timeoutSet bool // "wtimeoutMS" was present
}
225
226func writeConcernFromRaw(t *testing.T, wcRaw bson.Raw) writeConcern {
227	var wc writeConcern
228	var opts []writeconcern.Option
229
230	elems, _ := wcRaw.Elements()
231	for _, elem := range elems {
232		key := elem.Key()
233		val := elem.Value()
234
235		switch key {
236		case "w":
237			wc.wSet = true
238			switch val.Type {
239			case bsontype.Int32:
240				w := int(val.Int32())
241				opts = append(opts, writeconcern.W(w))
242			case bsontype.String:
243				opts = append(opts, writeconcern.WTagSet(val.StringValue()))
244			default:
245				t.Fatalf("unexpected type for w: %v", val.Type)
246			}
247		case "wtimeoutMS":
248			wc.timeoutSet = true
249			timeout := time.Duration(val.Int32()) * time.Millisecond
250			opts = append(opts, writeconcern.WTimeout(timeout))
251		case "journal":
252			wc.jSet = true
253			j := val.Boolean()
254			opts = append(opts, writeconcern.J(j))
255		default:
256			t.Fatalf("unrecognized write concern field: %v", key)
257		}
258	}
259
260	wc.WriteConcern = writeconcern.New(opts...)
261	return wc
262}
263
// jsonFilesInDir returns the names (not paths) of the regular entries in dir
// whose extension is exactly ".json", in the order ioutil.ReadDir reports them
// (sorted by name). Subdirectories are skipped even if their name ends in
// ".json". The result is never nil. Failing to read dir fails the test.
func jsonFilesInDir(t *testing.T, dir string) []string {
	t.Helper()

	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		t.Fatalf("unable to read directory %v: %v", dir, err)
	}

	files := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
			continue
		}
		files = append(files, entry.Name())
	}
	return files
}
283