1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mtest
8
9import (
10	"context"
11	"errors"
12	"fmt"
13	"math"
14	"os"
15	"strconv"
16	"strings"
17	"time"
18
19	"go.mongodb.org/mongo-driver/bson"
20	"go.mongodb.org/mongo-driver/mongo"
21	"go.mongodb.org/mongo-driver/mongo/options"
22	"go.mongodb.org/mongo-driver/mongo/readpref"
23	"go.mongodb.org/mongo-driver/mongo/writeconcern"
24	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
25	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
27	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
28)
29
// TestDb is the name of the database used by default in tests.
const TestDb = "test"
34
35// testContext holds the global context for the integration tests. The testContext members should only be initialized
36// once during the global setup in TestMain. These variables should only be accessed indirectly through MongoTest
37// instances.
38var testContext struct {
39	connString       connstring.ConnString
40	topo             *topology.Topology
41	topoKind         TopologyKind
42	client           *mongo.Client // client used for setup and teardown
43	serverVersion    string
44	authEnabled      bool
45	sslEnabled       bool
46	enterpriseServer bool
47}
48
49func setupClient(cs connstring.ConnString, opts *options.ClientOptions) (*mongo.Client, error) {
50	wcMajority := writeconcern.New(writeconcern.WMajority())
51	return mongo.Connect(Background, opts.ApplyURI(cs.Original).SetWriteConcern(wcMajority))
52}
53
54// Setup initializes the current testing context.
55// This function must only be called one time and must be called before any tests run.
56func Setup() error {
57	var err error
58	testContext.connString, err = getConnString()
59	if err != nil {
60		return fmt.Errorf("error getting connection string: %v", err)
61	}
62
63	connectionOpts := []topology.ConnectionOption{
64		topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache {
65			return ocsp.NewCache()
66		}),
67	}
68	serverOpts := []topology.ServerOption{
69		topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption {
70			return append(opts, connectionOpts...)
71		}),
72	}
73	testContext.topo, err = topology.New(
74		topology.WithConnString(func(connstring.ConnString) connstring.ConnString {
75			return testContext.connString
76		}),
77		topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
78			return append(opts, serverOpts...)
79		}),
80	)
81	if err != nil {
82		return fmt.Errorf("error creating topology: %v", err)
83	}
84	if err = testContext.topo.Connect(); err != nil {
85		return fmt.Errorf("error connecting topology: %v", err)
86	}
87
88	testContext.client, err = setupClient(testContext.connString, options.Client())
89	if err != nil {
90		return fmt.Errorf("error connecting test client: %v", err)
91	}
92
93	pingCtx, cancel := context.WithTimeout(Background, 2*time.Second)
94	defer cancel()
95	if err := testContext.client.Ping(pingCtx, readpref.Primary()); err != nil {
96		return fmt.Errorf("ping error: %v; make sure the deployment is running on URI %v", err,
97			testContext.connString.Original)
98	}
99
100	if testContext.serverVersion, err = getServerVersion(); err != nil {
101		return fmt.Errorf("error getting server version: %v", err)
102	}
103
104	switch testContext.topo.Kind() {
105	case description.Single:
106		testContext.topoKind = Single
107	case description.ReplicaSet, description.ReplicaSetWithPrimary, description.ReplicaSetNoPrimary:
108		testContext.topoKind = ReplicaSet
109	case description.Sharded:
110		testContext.topoKind = Sharded
111	default:
112		return fmt.Errorf("could not detect topology kind; current topology: %s", testContext.topo.String())
113	}
114
115	if testContext.topoKind == ReplicaSet && CompareServerVersions(testContext.serverVersion, "4.0") >= 0 {
116		err = testContext.client.Database("admin").RunCommand(Background, bson.D{
117			{"setParameter", 1},
118			{"transactionLifetimeLimitSeconds", 3},
119		}).Err()
120		if err != nil {
121			return fmt.Errorf("error setting transactionLifetimeLimitSeconds: %v", err)
122		}
123	}
124
125	testContext.authEnabled = os.Getenv("AUTH") == "auth"
126	testContext.sslEnabled = os.Getenv("SSL") == "ssl"
127	biRes, err := testContext.client.Database("admin").RunCommand(Background, bson.D{{"buildInfo", 1}}).DecodeBytes()
128	if err != nil {
129		return fmt.Errorf("buildInfo error: %v", err)
130	}
131	modulesRaw, err := biRes.LookupErr("modules")
132	if err == nil {
133		// older server versions don't report "modules" field in buildInfo result
134		modules, _ := modulesRaw.Array().Values()
135		for _, module := range modules {
136			if module.StringValue() == "enterprise" {
137				testContext.enterpriseServer = true
138				break
139			}
140		}
141	}
142	return nil
143}
144
145// Teardown cleans up resources initialized by Setup.
146// This function must be called once after all tests have finished running.
147func Teardown() error {
148	if err := testContext.client.Database(TestDb).Drop(Background); err != nil {
149		return fmt.Errorf("error dropping test database: %v", err)
150	}
151	if err := testContext.client.Disconnect(Background); err != nil {
152		return fmt.Errorf("error disconnecting test client: %v", err)
153	}
154	if err := testContext.topo.Disconnect(Background); err != nil {
155		return fmt.Errorf("error disconnecting test topology: %v", err)
156	}
157	return nil
158}
159
160func getServerVersion() (string, error) {
161	var serverStatus bson.Raw
162	err := testContext.client.Database(TestDb).RunCommand(
163		Background,
164		bson.D{{"serverStatus", 1}},
165	).Decode(&serverStatus)
166	if err != nil {
167		return "", err
168	}
169
170	version, err := serverStatus.LookupErr("version")
171	if err != nil {
172		return "", errors.New("no version string in serverStatus response")
173	}
174
175	return version.StringValue(), nil
176}
177
// addOptions appends connection string options to a URI.
//
// The opts fragments are concatenated verbatim, without any separator between them, so a single option may be
// passed in pieces, e.g. addOptions(uri, "compressors=", "zlib"). Callers that pass more than one option are
// responsible for the "&" between them.
//
// Before the fragments are appended the URI is prepared to take another option:
//   - if it has no query string yet, "?" is added, preceded by "/" only when nothing after the scheme contains
//     a "/" already, so that a database path such as "mongodb://host/db" is left intact;
//   - if it already has a query string, "&" is added unless the URI already ends in "?" or "&".
func addOptions(uri string, opts ...string) string {
	switch {
	case strings.ContainsRune(uri, '?'):
		// The query string has been started; only a separator from the previous option may be missing.
		if !strings.HasSuffix(uri, "?") && !strings.HasSuffix(uri, "&") {
			uri += "&"
		}
	default:
		// Skip the "scheme://" prefix so that its slashes are not mistaken for the path separator.
		hostAndPath := uri
		if idx := strings.Index(uri, "://"); idx >= 0 {
			hostAndPath = uri[idx+len("://"):]
		}
		if !strings.ContainsRune(hostAndPath, '/') {
			uri += "/"
		}
		uri += "?"
	}

	return uri + strings.Join(opts, "")
}
196
197// addTLSConfig checks for the environmental variable indicating that the tests are being run
198// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration.
199func addTLSConfig(uri string) string {
200	caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE")
201	if len(caFile) == 0 {
202		return uri
203	}
204
205	return addOptions(uri, "ssl=true&sslCertificateAuthorityFile=", caFile)
206}
207
208// addCompressors checks for the environment variable indicating that the tests are being run with compression
209// enabled. If so, it returns a new URI with the necessary configuration
210func addCompressors(uri string) string {
211	comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
212	if len(comp) == 0 {
213		return uri
214	}
215
216	return addOptions(uri, "compressors=", comp)
217}
218
219// ConnString gets the globally configured connection string.
220func getConnString() (connstring.ConnString, error) {
221	uri := os.Getenv("MONGODB_URI")
222	if uri == "" {
223		uri = "mongodb://localhost:27017"
224	}
225	uri = addTLSConfig(uri)
226	uri = addCompressors(uri)
227	return connstring.ParseAndValidate(uri)
228}
229
// CompareServerVersions compares two version number strings (i.e. positive integers separated by
// periods). Only the components present in both versions are compared, so 3.2 is considered equal
// to 3.2.11, whereas 3.2.0 is considered less than 3.2.11.
//
// Returns a positive int if v1 is greater than v2, a negative int if v1 is less than v2, and 0 if
// v1 is equal to v2. Components are compared from left to right; if a component of v1 that takes
// part in the comparison is not an integer the result is 1, and otherwise, if the matching
// component of v2 is not an integer, the result is -1.
func CompareServerVersions(v1 string, v2 string) int {
	parts1 := strings.Split(v1, ".")
	parts2 := strings.Split(v2, ".")

	// Number of leading components that both versions have.
	shared := int(math.Min(float64(len(parts1)), float64(len(parts2))))

	for idx, raw1 := range parts1[:shared] {
		left, err := strconv.Atoi(raw1)
		if err != nil {
			return 1
		}

		right, err := strconv.Atoi(parts2[idx])
		if err != nil {
			return -1
		}

		if left != right {
			return left - right
		}
	}

	return 0
}
259