1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mtest
8
9import (
10	"context"
11	"errors"
12	"fmt"
13	"math"
14	"os"
15	"strconv"
16	"strings"
17	"time"
18
19	"go.mongodb.org/mongo-driver/bson"
20	"go.mongodb.org/mongo-driver/mongo"
21	"go.mongodb.org/mongo-driver/mongo/options"
22	"go.mongodb.org/mongo-driver/mongo/readpref"
23	"go.mongodb.org/mongo-driver/mongo/writeconcern"
24	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
25	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
26	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
27)
28
// TestDb is the name of the default database used by the integration tests.
const TestDb = "test"
33
34// testContext holds the global context for the integration tests. The testContext members should only be initialized
35// once during the global setup in TestMain. These variables should only be accessed indirectly through MongoTest
36// instances.
37var testContext struct {
38	connString       connstring.ConnString
39	topo             *topology.Topology
40	topoKind         TopologyKind
41	client           *mongo.Client // client used for setup and teardown
42	serverVersion    string
43	authEnabled      bool
44	enterpriseServer bool
45}
46
47func setupClient(cs connstring.ConnString, opts *options.ClientOptions) (*mongo.Client, error) {
48	wcMajority := writeconcern.New(writeconcern.WMajority())
49	return mongo.Connect(Background, opts.ApplyURI(cs.Original).SetWriteConcern(wcMajority))
50}
51
52// Setup initializes the current testing context.
53// This function must only be called one time and must be called before any tests run.
54func Setup() error {
55	var err error
56	testContext.connString, err = getConnString()
57	if err != nil {
58		return fmt.Errorf("error getting connection string: %v", err)
59	}
60	testContext.topo, err = topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString {
61		return testContext.connString
62	}))
63	if err != nil {
64		return fmt.Errorf("error creating topology: %v", err)
65	}
66	if err = testContext.topo.Connect(); err != nil {
67		return fmt.Errorf("error connecting topology: %v", err)
68	}
69
70	testContext.client, err = setupClient(testContext.connString, options.Client())
71	if err != nil {
72		return fmt.Errorf("error connecting test client: %v", err)
73	}
74
75	pingCtx, cancel := context.WithTimeout(Background, 2*time.Second)
76	defer cancel()
77	if err := testContext.client.Ping(pingCtx, readpref.Primary()); err != nil {
78		return fmt.Errorf("ping error: %v; make sure the deployment is running on URI %v", err,
79			testContext.connString.Original)
80	}
81
82	if testContext.serverVersion, err = getServerVersion(); err != nil {
83		return fmt.Errorf("error getting server version: %v", err)
84	}
85
86	switch testContext.topo.Kind() {
87	case description.Single:
88		testContext.topoKind = Single
89	case description.ReplicaSet, description.ReplicaSetWithPrimary, description.ReplicaSetNoPrimary:
90		testContext.topoKind = ReplicaSet
91	case description.Sharded:
92		testContext.topoKind = Sharded
93	}
94
95	if testContext.topoKind == ReplicaSet && compareVersions(testContext.serverVersion, "4.0") >= 0 {
96		err = testContext.client.Database("admin").RunCommand(Background, bson.D{
97			{"setParameter", 1},
98			{"transactionLifetimeLimitSeconds", 3},
99		}).Err()
100		if err != nil {
101			return fmt.Errorf("error setting transactionLifetimeLimitSeconds: %v", err)
102		}
103	}
104
105	testContext.authEnabled = len(os.Getenv("MONGO_GO_DRIVER_CA_FILE")) != 0
106	biRes, err := testContext.client.Database("admin").RunCommand(Background, bson.D{{"buildInfo", 1}}).DecodeBytes()
107	if err != nil {
108		return fmt.Errorf("buildInfo error: %v", err)
109	}
110	modulesRaw, err := biRes.LookupErr("modules")
111	if err == nil {
112		// older server versions don't report "modules" field in buildInfo result
113		modules, _ := modulesRaw.Array().Values()
114		for _, module := range modules {
115			if module.StringValue() == "enterprise" {
116				testContext.enterpriseServer = true
117				break
118			}
119		}
120	}
121	return nil
122}
123
124// Teardown cleans up resources initialized by Setup.
125// This function must be called once after all tests have finished running.
126func Teardown() error {
127	if err := testContext.client.Database(TestDb).Drop(Background); err != nil {
128		return fmt.Errorf("error dropping test database: %v", err)
129	}
130	if err := testContext.client.Disconnect(Background); err != nil {
131		return fmt.Errorf("error disconnecting test client: %v", err)
132	}
133	if err := testContext.topo.Disconnect(Background); err != nil {
134		return fmt.Errorf("error disconnecting test topology: %v", err)
135	}
136	return nil
137}
138
139func getServerVersion() (string, error) {
140	var serverStatus bson.Raw
141	err := testContext.client.Database(TestDb).RunCommand(
142		Background,
143		bson.D{{"serverStatus", 1}},
144	).Decode(&serverStatus)
145	if err != nil {
146		return "", err
147	}
148
149	version, err := serverStatus.LookupErr("version")
150	if err != nil {
151		return "", errors.New("no version string in serverStatus response")
152	}
153
154	return version.StringValue(), nil
155}
156
// addOptions appends connection string options to a URI and returns the result.
//
// The opts fragments are concatenated verbatim, so callers may pass a key (e.g. "compressors=")
// and its value as separate arguments.
//
// If uri has no query string yet, "?" is appended. It is preceded by "/" when the part after
// the scheme contains no path separator, because the connection string format requires "/"
// before the options. If uri already has a query string, the new options are joined to it with
// "&", unless uri already ends in "?" or "&".
func addOptions(uri string, opts ...string) string {
	var sep string
	if strings.ContainsRune(uri, '?') {
		if !strings.HasSuffix(uri, "?") && !strings.HasSuffix(uri, "&") {
			sep = "&"
		}
	} else {
		// Look for "/" only after the scheme. Checking just the last character would turn
		// "mongodb://host/db" into "mongodb://host/db/?", which changes the database name.
		// Indexing the last character would also panic on an empty URI.
		hosts := uri
		if idx := strings.Index(uri, "://"); idx >= 0 {
			hosts = uri[idx+len("://"):]
		}
		if !strings.ContainsRune(hosts, '/') {
			sep = "/"
		}
		sep += "?"
	}

	return uri + sep + strings.Join(opts, "")
}
175
176// addTLSConfig checks for the environmental variable indicating that the tests are being run
177// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration.
178func addTLSConfig(uri string) string {
179	caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE")
180	if len(caFile) == 0 {
181		return uri
182	}
183
184	return addOptions(uri, "ssl=true&sslCertificateAuthorityFile=", caFile)
185}
186
187// addCompressors checks for the environment variable indicating that the tests are being run with compression
188// enabled. If so, it returns a new URI with the necessary configuration
189func addCompressors(uri string) string {
190	comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
191	if len(comp) == 0 {
192		return uri
193	}
194
195	return addOptions(uri, "compressors=", comp)
196}
197
198// ConnString gets the globally configured connection string.
199func getConnString() (connstring.ConnString, error) {
200	uri := os.Getenv("MONGODB_URI")
201	if uri == "" {
202		uri = "mongodb://localhost:27017"
203	}
204	uri = addTLSConfig(uri)
205	uri = addCompressors(uri)
206	return connstring.Parse(uri)
207}
208
// compareVersions compares two dot-separated version strings component by component.
//
// Only the components present in both strings are compared, so the result has the precision of
// the shorter version: "3.2" equals "3.2.11", while "3.2.0" is less than "3.2.11".
//
// The result is the difference between the first pair of components that differ. It is
// therefore positive when v1 is greater, negative when v1 is less, and 0 when all shared
// components match. If a component of v1 is not an integer, 1 is returned. Otherwise, if the
// corresponding component of v2 is not an integer, -1 is returned.
func compareVersions(v1 string, v2 string) int {
	parts1, parts2 := strings.Split(v1, "."), strings.Split(v2, ".")
	shared := int(math.Min(float64(len(parts1)), float64(len(parts2))))

	for idx := 0; idx < shared; idx++ {
		a, errA := strconv.Atoi(parts1[idx])
		if errA != nil {
			return 1
		}
		b, errB := strconv.Atoi(parts2[idx])
		if errB != nil {
			return -1
		}
		if a != b {
			return a - b
		}
	}
	return 0
}
238