1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package mtest 8 9import ( 10 "context" 11 "errors" 12 "fmt" 13 "math" 14 "os" 15 "strconv" 16 "strings" 17 "time" 18 19 "go.mongodb.org/mongo-driver/bson" 20 "go.mongodb.org/mongo-driver/mongo" 21 "go.mongodb.org/mongo-driver/mongo/options" 22 "go.mongodb.org/mongo-driver/mongo/readpref" 23 "go.mongodb.org/mongo-driver/mongo/writeconcern" 24 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 25 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 26 "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" 27 "go.mongodb.org/mongo-driver/x/mongo/driver/topology" 28) 29 30const ( 31 // TestDb specifies the name of default test database. 32 TestDb = "test" 33) 34 35// testContext holds the global context for the integration tests. The testContext members should only be initialized 36// once during the global setup in TestMain. These variables should only be accessed indirectly through MongoTest 37// instances. 38var testContext struct { 39 connString connstring.ConnString 40 topo *topology.Topology 41 topoKind TopologyKind 42 client *mongo.Client // client used for setup and teardown 43 serverVersion string 44 authEnabled bool 45 sslEnabled bool 46 enterpriseServer bool 47} 48 49func setupClient(cs connstring.ConnString, opts *options.ClientOptions) (*mongo.Client, error) { 50 wcMajority := writeconcern.New(writeconcern.WMajority()) 51 return mongo.Connect(Background, opts.ApplyURI(cs.Original).SetWriteConcern(wcMajority)) 52} 53 54// Setup initializes the current testing context. 55// This function must only be called one time and must be called before any tests run. 56func Setup() error { 57 var err error 58 testContext.connString, err = getConnString() 59 if err != nil { 60 return fmt.Errorf("error getting connection string: %v", err) 61 } 62 63 connectionOpts := []topology.ConnectionOption{ 64 topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache { 65 return ocsp.NewCache() 66 }), 67 } 68 serverOpts := []topology.ServerOption{ 69 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 70 return append(opts, connectionOpts...) 71 }), 72 } 73 testContext.topo, err = topology.New( 74 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { 75 return testContext.connString 76 }), 77 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 78 return append(opts, serverOpts...) 79 }), 80 ) 81 if err != nil { 82 return fmt.Errorf("error creating topology: %v", err) 83 } 84 if err = testContext.topo.Connect(); err != nil { 85 return fmt.Errorf("error connecting topology: %v", err) 86 } 87 88 testContext.client, err = setupClient(testContext.connString, options.Client()) 89 if err != nil { 90 return fmt.Errorf("error connecting test client: %v", err) 91 } 92 93 pingCtx, cancel := context.WithTimeout(Background, 2*time.Second) 94 defer cancel() 95 if err := testContext.client.Ping(pingCtx, readpref.Primary()); err != nil { 96 return fmt.Errorf("ping error: %v; make sure the deployment is running on URI %v", err, 97 testContext.connString.Original) 98 } 99 100 if testContext.serverVersion, err = getServerVersion(); err != nil { 101 return fmt.Errorf("error getting server version: %v", err) 102 } 103 104 switch testContext.topo.Kind() { 105 case description.Single: 106 testContext.topoKind = Single 107 case description.ReplicaSet, description.ReplicaSetWithPrimary, description.ReplicaSetNoPrimary: 108 testContext.topoKind = ReplicaSet 109 case description.Sharded: 110 testContext.topoKind = Sharded 111 default: 112 return fmt.Errorf("could not detect topology kind; current topology: %s", testContext.topo.String()) 113 } 114 115 if testContext.topoKind == ReplicaSet && CompareServerVersions(testContext.serverVersion, "4.0") >= 0 { 116 err = testContext.client.Database("admin").RunCommand(Background, bson.D{ 117 {"setParameter", 1}, 118 {"transactionLifetimeLimitSeconds", 3}, 119 }).Err() 120 if err != nil { 121 return fmt.Errorf("error setting transactionLifetimeLimitSeconds: %v", err) 122 } 123 } 124 125 testContext.authEnabled = os.Getenv("AUTH") == "auth" 126 testContext.sslEnabled = os.Getenv("SSL") == "ssl" 127 biRes, err := testContext.client.Database("admin").RunCommand(Background, bson.D{{"buildInfo", 1}}).DecodeBytes() 128 if err != nil { 129 return fmt.Errorf("buildInfo error: %v", err) 130 } 131 modulesRaw, err := biRes.LookupErr("modules") 132 if err == nil { 133 // older server versions don't report "modules" field in buildInfo result 134 modules, _ := modulesRaw.Array().Values() 135 for _, module := range modules { 136 if module.StringValue() == "enterprise" { 137 testContext.enterpriseServer = true 138 break 139 } 140 } 141 } 142 return nil 143} 144 145// Teardown cleans up resources initialized by Setup. 146// This function must be called once after all tests have finished running. 147func Teardown() error { 148 if err := testContext.client.Database(TestDb).Drop(Background); err != nil { 149 return fmt.Errorf("error dropping test database: %v", err) 150 } 151 if err := testContext.client.Disconnect(Background); err != nil { 152 return fmt.Errorf("error disconnecting test client: %v", err) 153 } 154 if err := testContext.topo.Disconnect(Background); err != nil { 155 return fmt.Errorf("error disconnecting test topology: %v", err) 156 } 157 return nil 158} 159 160func getServerVersion() (string, error) { 161 var serverStatus bson.Raw 162 err := testContext.client.Database(TestDb).RunCommand( 163 Background, 164 bson.D{{"serverStatus", 1}}, 165 ).Decode(&serverStatus) 166 if err != nil { 167 return "", err 168 } 169 170 version, err := serverStatus.LookupErr("version") 171 if err != nil { 172 return "", errors.New("no version string in serverStatus response") 173 } 174 175 return version.StringValue(), nil 176} 177 178// addOptions appends connection string options to a URI. 179func addOptions(uri string, opts ...string) string { 180 if !strings.ContainsRune(uri, '?') { 181 if uri[len(uri)-1] != '/' { 182 uri += "/" 183 } 184 185 uri += "?" 186 } else { 187 uri += "&" 188 } 189 190 for _, opt := range opts { 191 uri += opt 192 } 193 194 return uri 195} 196 197// addTLSConfig checks for the environmental variable indicating that the tests are being run 198// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. 199func addTLSConfig(uri string) string { 200 caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE") 201 if len(caFile) == 0 { 202 return uri 203 } 204 205 return addOptions(uri, "ssl=true&sslCertificateAuthorityFile=", caFile) 206} 207 208// addCompressors checks for the environment variable indicating that the tests are being run with compression 209// enabled. If so, it returns a new URI with the necessary configuration 210func addCompressors(uri string) string { 211 comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR") 212 if len(comp) == 0 { 213 return uri 214 } 215 216 return addOptions(uri, "compressors=", comp) 217} 218 219// ConnString gets the globally configured connection string. 220func getConnString() (connstring.ConnString, error) { 221 uri := os.Getenv("MONGODB_URI") 222 if uri == "" { 223 uri = "mongodb://localhost:27017" 224 } 225 uri = addTLSConfig(uri) 226 uri = addCompressors(uri) 227 return connstring.ParseAndValidate(uri) 228} 229 230// CompareServerVersions compares two version number strings (i.e. positive integers separated by 231// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is 232// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. 233// 234// Returns a positive int if version1 is greater than version2, a negative int if version1 is less 235// than version2, and 0 if version1 is equal to version2. 236func CompareServerVersions(v1 string, v2 string) int { 237 n1 := strings.Split(v1, ".") 238 n2 := strings.Split(v2, ".") 239 240 for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ { 241 i1, err := strconv.Atoi(n1[i]) 242 if err != nil { 243 return 1 244 } 245 246 i2, err := strconv.Atoi(n2[i]) 247 if err != nil { 248 return -1 249 } 250 251 difference := i1 - i2 252 if difference != 0 { 253 return difference 254 } 255 } 256 257 return 0 258} 259