1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package testutil
8
9import (
10	"context"
11	"fmt"
12	"math"
13	"os"
14	"reflect"
15	"strconv"
16	"strings"
17	"sync"
18	"testing"
19
20	"github.com/stretchr/testify/require"
21	"go.mongodb.org/mongo-driver/event"
22	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
23	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
24	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
25	"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
27	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
28	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
29)
30
31var connectionString connstring.ConnString
32var connectionStringOnce sync.Once
33var connectionStringErr error
34var liveTopology *topology.Topology
35var liveSessionPool *session.Pool
36var liveTopologyOnce sync.Once
37var liveTopologyErr error
38var monitoredTopology *topology.Topology
39var monitoredSessionPool *session.Pool
40var monitoredTopologyOnce sync.Once
41var monitoredTopologyErr error
42
43// AddOptionsToURI appends connection string options to a URI.
44func AddOptionsToURI(uri string, opts ...string) string {
45	if !strings.ContainsRune(uri, '?') {
46		if uri[len(uri)-1] != '/' {
47			uri += "/"
48		}
49
50		uri += "?"
51	} else {
52		uri += "&"
53	}
54
55	for _, opt := range opts {
56		uri += opt
57	}
58
59	return uri
60}
61
62// AddTLSConfigToURI checks for the environmental variable indicating that the tests are being run
63// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration.
64func AddTLSConfigToURI(uri string) string {
65	caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE")
66	if len(caFile) == 0 {
67		return uri
68	}
69
70	return AddOptionsToURI(uri, "ssl=true&sslCertificateAuthorityFile=", caFile)
71}
72
73// AddCompressorToUri checks for the environment variable indicating that the tests are being run with compression
74// enabled. If so, it returns a new URI with the necessary configuration
75func AddCompressorToUri(uri string) string {
76	comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
77	if len(comp) == 0 {
78		return uri
79	}
80
81	return AddOptionsToURI(uri, "compressors=", comp)
82}
83
84// MonitoredTopology returns a new topology with the command monitor attached
85func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonitor) *topology.Topology {
86	cs := ConnString(t)
87	opts := []topology.Option{
88		topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
89		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
90			return append(
91				opts,
92				topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
93					return append(
94						opts,
95						topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
96							return monitor
97						}),
98						topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache {
99							return ocsp.NewCache()
100						}),
101					)
102				}),
103			)
104		}),
105	}
106
107	monitoredTopology, err := topology.New(opts...)
108	if err != nil {
109		t.Fatal(err)
110	} else {
111		monitoredTopology.Connect()
112
113		err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
114			Database(dbName).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background())
115
116		require.NoError(t, err)
117	}
118
119	return monitoredTopology
120}
121
122// GlobalMonitoredTopology gets the globally configured topology and attaches a command monitor.
123func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topology.Topology {
124	cs := ConnString(t)
125	opts := []topology.Option{
126		topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
127		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
128			return append(
129				opts,
130				topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
131					return append(
132						opts,
133						topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
134							return monitor
135						}),
136						topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache {
137							return ocsp.NewCache()
138						}),
139					)
140				}),
141			)
142		}),
143	}
144
145	monitoredTopologyOnce.Do(func() {
146		var err error
147		monitoredTopology, err = topology.New(opts...)
148		if err != nil {
149			monitoredTopologyErr = err
150		} else {
151			monitoredTopology.Connect()
152
153			err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
154				Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background())
155
156			require.NoError(t, err)
157
158			sub, err := monitoredTopology.Subscribe()
159			require.NoError(t, err)
160			monitoredSessionPool = session.NewPool(sub.Updates)
161		}
162	})
163
164	if monitoredTopologyErr != nil {
165		t.Fatal(monitoredTopologyErr)
166	}
167
168	return monitoredTopology
169}
170
171// GlobalMonitoredSessionPool returns the globally configured session pool.
172// Must be called after GlobalMonitoredTopology()
173func GlobalMonitoredSessionPool() *session.Pool {
174	return monitoredSessionPool
175}
176
177// Topology gets the globally configured topology.
178func Topology(t *testing.T) *topology.Topology {
179	cs := ConnString(t)
180	opts := []topology.Option{
181		topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
182		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
183			return append(
184				opts,
185				topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
186					return append(
187						opts,
188						topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache {
189							return ocsp.NewCache()
190						}),
191					)
192				}),
193			)
194		}),
195	}
196
197	liveTopologyOnce.Do(func() {
198		var err error
199		liveTopology, err = topology.New(opts...)
200		if err != nil {
201			liveTopologyErr = err
202		} else {
203			liveTopology.Connect()
204
205			err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
206				Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background())
207			require.NoError(t, err)
208
209			sub, err := liveTopology.Subscribe()
210			require.NoError(t, err)
211			liveSessionPool = session.NewPool(sub.Updates)
212		}
213	})
214
215	if liveTopologyErr != nil {
216		t.Fatal(liveTopologyErr)
217	}
218
219	return liveTopology
220}
221
222// SessionPool gets the globally configured session pool. Must be called after Topology().
223func SessionPool() *session.Pool {
224	return liveSessionPool
225}
226
227// TopologyWithConnString takes a connection string and returns a connected
228// topology, or else bails out of testing
229func TopologyWithConnString(t *testing.T, cs connstring.ConnString) *topology.Topology {
230	opts := []topology.Option{
231		topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
232		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
233			return append(
234				opts,
235				topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
236					return append(
237						opts,
238						topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache {
239							return ocsp.NewCache()
240						}),
241					)
242				}),
243			)
244		}),
245	}
246
247	topology, err := topology.New(opts...)
248	if err != nil {
249		t.Fatal("Could not construct topology")
250	}
251	err = topology.Connect()
252	if err != nil {
253		t.Fatal("Could not start topology connection")
254	}
255	return topology
256}
257
258// ColName gets a collection name that should be unique
259// to the currently executing test.
260func ColName(t *testing.T) string {
261	// Get this indirectly to avoid copying a mutex
262	v := reflect.Indirect(reflect.ValueOf(t))
263	name := v.FieldByName("name")
264	return name.String()
265}
266
267// ConnString gets the globally configured connection string.
268func ConnString(t *testing.T) connstring.ConnString {
269	connectionStringOnce.Do(func() {
270		connectionString, connectionStringErr = GetConnString()
271		mongodbURI := os.Getenv("MONGODB_URI")
272		if mongodbURI == "" {
273			mongodbURI = "mongodb://localhost:27017"
274		}
275
276		mongodbURI = AddTLSConfigToURI(mongodbURI)
277		mongodbURI = AddCompressorToUri(mongodbURI)
278
279		var err error
280		connectionString, err = connstring.ParseAndValidate(mongodbURI)
281		if err != nil {
282			connectionStringErr = err
283		}
284	})
285	if connectionStringErr != nil {
286		t.Fatal(connectionStringErr)
287	}
288
289	return connectionString
290}
291
292func GetConnString() (connstring.ConnString, error) {
293	mongodbURI := os.Getenv("MONGODB_URI")
294	if mongodbURI == "" {
295		mongodbURI = "mongodb://localhost:27017"
296	}
297
298	mongodbURI = AddTLSConfigToURI(mongodbURI)
299
300	cs, err := connstring.ParseAndValidate(mongodbURI)
301	if err != nil {
302		return connstring.ConnString{}, err
303	}
304
305	return cs, nil
306}
307
308// DBName gets the globally configured database name.
309func DBName(t *testing.T) string {
310	return GetDBName(ConnString(t))
311}
312
313func GetDBName(cs connstring.ConnString) string {
314	if cs.Database != "" {
315		return cs.Database
316	}
317
318	return fmt.Sprintf("mongo-go-driver-%d", os.Getpid())
319}
320
321// Integration should be called at the beginning of integration
322// tests to ensure that they are skipped if integration testing is
323// turned off.
324func Integration(t *testing.T) {
325	if testing.Short() {
326		t.Skip("skipping integration test in short mode")
327	}
328}
329
330// compareVersions compares two version number strings (i.e. positive integers separated by
331// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is
332// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11.
333//
334// Returns a positive int if version1 is greater than version2, a negative int if version1 is less
335// than version2, and 0 if version1 is equal to version2.
336func CompareVersions(t *testing.T, v1 string, v2 string) int {
337	n1 := strings.Split(v1, ".")
338	n2 := strings.Split(v2, ".")
339
340	for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ {
341		i1, err := strconv.Atoi(n1[i])
342		require.NoError(t, err)
343
344		i2, err := strconv.Atoi(n2[i])
345		require.NoError(t, err)
346
347		difference := i1 - i2
348		if difference != 0 {
349			return difference
350		}
351	}
352
353	return 0
354}
355