1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package testutil
8
9import (
10	"context"
11	"fmt"
12	"math"
13	"os"
14	"reflect"
15	"strconv"
16	"strings"
17	"sync"
18	"testing"
19
20	"github.com/stretchr/testify/require"
21	"go.mongodb.org/mongo-driver/event"
22	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
23	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
24	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
25	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
27	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
28)
29
30var connectionString connstring.ConnString
31var connectionStringOnce sync.Once
32var connectionStringErr error
33var liveTopology *topology.Topology
34var liveSessionPool *session.Pool
35var liveTopologyOnce sync.Once
36var liveTopologyErr error
37var monitoredTopology *topology.Topology
38var monitoredSessionPool *session.Pool
39var monitoredTopologyOnce sync.Once
40var monitoredTopologyErr error
41
// AddOptionsToURI appends connection string options to a URI and returns the result.
//
// The opts are concatenated verbatim with no separator between them, so a caller adding
// several options must supply the '&' separators itself, e.g.
// AddOptionsToURI(uri, "ssl=true&sslCertificateAuthorityFile=", caFile).
//
// If uri has no query string yet, one is started with '?'. A '/' is inserted before that '?'
// only when the URI has no path component ("mongodb://host:27017"). A URI that already names
// a database ("mongodb://host:27017/db") keeps that path intact. If uri already has a query
// string, '&' separates the new options, unless uri already ends with '?' or '&'.
func AddOptionsToURI(uri string, opts ...string) string {
	var sep string
	switch {
	case !strings.ContainsRune(uri, '?'):
		// Skip past "://" so the scheme's own slashes are not mistaken for a path.
		hostAndPath := uri
		if i := strings.Index(uri, "://"); i >= 0 {
			hostAndPath = uri[i+len("://"):]
		}
		if !strings.ContainsRune(hostAndPath, '/') {
			sep = "/"
		}
		sep += "?"
	case strings.HasSuffix(uri, "?") || strings.HasSuffix(uri, "&"):
		// The query string already ends with a separator. Adding another would yield an
		// empty option.
	default:
		sep = "&"
	}

	return uri + sep + strings.Join(opts, "")
}
60
61// AddTLSConfigToURI checks for the environmental variable indicating that the tests are being run
62// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration.
63func AddTLSConfigToURI(uri string) string {
64	caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE")
65	if len(caFile) == 0 {
66		return uri
67	}
68
69	return AddOptionsToURI(uri, "ssl=true&sslCertificateAuthorityFile=", caFile)
70}
71
72// AddCompressorToUri checks for the environment variable indicating that the tests are being run with compression
73// enabled. If so, it returns a new URI with the necessary configuration
74func AddCompressorToUri(uri string) string {
75	comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
76	if len(comp) == 0 {
77		return uri
78	}
79
80	return AddOptionsToURI(uri, "compressors=", comp)
81}
82
83// MonitoredTopology returns a new topology with the command monitor attached
84func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonitor) *topology.Topology {
85	cs := ConnString(t)
86	opts := []topology.Option{
87		topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
88		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
89			return append(
90				opts,
91				topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
92					return append(
93						opts,
94						topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
95							return monitor
96						}),
97					)
98				}),
99			)
100		}),
101	}
102
103	monitoredTopology, err := topology.New(opts...)
104	if err != nil {
105		t.Fatal(err)
106	} else {
107		monitoredTopology.Connect()
108
109		err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
110			Database(dbName).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background())
111
112		require.NoError(t, err)
113	}
114
115	return monitoredTopology
116}
117
118// GlobalMonitoredTopology gets the globally configured topology and attaches a command monitor.
119func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topology.Topology {
120	cs := ConnString(t)
121	opts := []topology.Option{
122		topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
123		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
124			return append(
125				opts,
126				topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
127					return append(
128						opts,
129						topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
130							return monitor
131						}),
132					)
133				}),
134			)
135		}),
136	}
137
138	monitoredTopologyOnce.Do(func() {
139		var err error
140		monitoredTopology, err = topology.New(opts...)
141		if err != nil {
142			monitoredTopologyErr = err
143		} else {
144			monitoredTopology.Connect()
145
146			err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
147				Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background())
148
149			require.NoError(t, err)
150
151			sub, err := monitoredTopology.Subscribe()
152			require.NoError(t, err)
153			monitoredSessionPool = session.NewPool(sub.Updates)
154		}
155	})
156
157	if monitoredTopologyErr != nil {
158		t.Fatal(monitoredTopologyErr)
159	}
160
161	return monitoredTopology
162}
163
164// GlobalMonitoredSessionPool returns the globally configured session pool.
165// Must be called after GlobalMonitoredTopology()
166func GlobalMonitoredSessionPool() *session.Pool {
167	return monitoredSessionPool
168}
169
170// Topology gets the globally configured topology.
171func Topology(t *testing.T) *topology.Topology {
172	cs := ConnString(t)
173	liveTopologyOnce.Do(func() {
174		var err error
175		liveTopology, err = topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
176		if err != nil {
177			liveTopologyErr = err
178		} else {
179			liveTopology.Connect()
180
181			err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))).
182				Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background())
183			require.NoError(t, err)
184
185			sub, err := liveTopology.Subscribe()
186			require.NoError(t, err)
187			liveSessionPool = session.NewPool(sub.Updates)
188		}
189	})
190
191	if liveTopologyErr != nil {
192		t.Fatal(liveTopologyErr)
193	}
194
195	return liveTopology
196}
197
198// SessionPool gets the globally configured session pool. Must be called after Topology().
199func SessionPool() *session.Pool {
200	return liveSessionPool
201}
202
203// TopologyWithConnString takes a connection string and returns a connected
204// topology, or else bails out of testing
205func TopologyWithConnString(t *testing.T, cs connstring.ConnString) *topology.Topology {
206	topology, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
207	if err != nil {
208		t.Fatal("Could not construct topology")
209	}
210	err = topology.Connect()
211	if err != nil {
212		t.Fatal("Could not start topology connection")
213	}
214	return topology
215}
216
// ColName returns the full name of the running test, including any "/subtest" suffixes. It
// is intended as a collection name that is unique to that test.
//
// NOTE(review): this reads testing.T's unexported "name" field, which is the value t.Name()
// reports. Switching to t.Name() would also require dropping the file's reflect import.
func ColName(t *testing.T) string {
	// Reflect through the pointer (rather than dereferencing t) so the embedded mutex is
	// never copied.
	return reflect.ValueOf(t).Elem().FieldByName("name").String()
}
225
226// ConnString gets the globally configured connection string.
227func ConnString(t *testing.T) connstring.ConnString {
228	connectionStringOnce.Do(func() {
229		connectionString, connectionStringErr = GetConnString()
230		mongodbURI := os.Getenv("MONGODB_URI")
231		if mongodbURI == "" {
232			mongodbURI = "mongodb://localhost:27017"
233		}
234
235		mongodbURI = AddTLSConfigToURI(mongodbURI)
236		mongodbURI = AddCompressorToUri(mongodbURI)
237
238		var err error
239		connectionString, err = connstring.Parse(mongodbURI)
240		if err != nil {
241			connectionStringErr = err
242		}
243	})
244	if connectionStringErr != nil {
245		t.Fatal(connectionStringErr)
246	}
247
248	return connectionString
249}
250
251func GetConnString() (connstring.ConnString, error) {
252	mongodbURI := os.Getenv("MONGODB_URI")
253	if mongodbURI == "" {
254		mongodbURI = "mongodb://localhost:27017"
255	}
256
257	mongodbURI = AddTLSConfigToURI(mongodbURI)
258
259	cs, err := connstring.Parse(mongodbURI)
260	if err != nil {
261		return connstring.ConnString{}, err
262	}
263
264	return cs, nil
265}
266
267// DBName gets the globally configured database name.
268func DBName(t *testing.T) string {
269	return GetDBName(ConnString(t))
270}
271
272func GetDBName(cs connstring.ConnString) string {
273	if cs.Database != "" {
274		return cs.Database
275	}
276
277	return fmt.Sprintf("mongo-go-driver-%d", os.Getpid())
278}
279
// Integration skips the calling test when tests are run with -short. Call it at the start of
// every integration test so that those tests can be turned off.
func Integration(t *testing.T) {
	if !testing.Short() {
		return
	}
	t.Skip("skipping integration test in short mode")
}
288
// CompareVersions compares two version number strings, i.e. non-negative integers separated
// by periods.
//
// Comparisons are done to the lesser precision of the two versions. For example, 3.2 is
// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11.
//
// It returns a positive int if v1 is greater than v2, a negative int if v1 is less than v2,
// and 0 if they are equal to the compared precision. If a compared component is not an
// integer (for example the "0-rc1" in "4.2.0-rc1"), the test fails and the offending version
// is named in the message.
func CompareVersions(t *testing.T, v1 string, v2 string) int {
	t.Helper()

	n1 := strings.Split(v1, ".")
	n2 := strings.Split(v2, ".")

	// Only the components present in both versions are compared. The bound is computed once
	// instead of on every loop iteration.
	shared := int(math.Min(float64(len(n1)), float64(len(n2))))
	for i := 0; i < shared; i++ {
		i1, err := strconv.Atoi(n1[i])
		if err != nil {
			t.Fatalf("invalid version %q: %v", v1, err)
		}

		i2, err := strconv.Atoi(n2[i])
		if err != nil {
			t.Fatalf("invalid version %q: %v", v2, err)
		}

		if difference := i1 - i2; difference != 0 {
			return difference
		}
	}

	return 0
}
314