1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package testutil 8 9import ( 10 "context" 11 "fmt" 12 "math" 13 "os" 14 "reflect" 15 "strconv" 16 "strings" 17 "sync" 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 "go.mongodb.org/mongo-driver/event" 22 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 23 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 24 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 25 "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" 26 "go.mongodb.org/mongo-driver/x/mongo/driver/operation" 27 "go.mongodb.org/mongo-driver/x/mongo/driver/session" 28 "go.mongodb.org/mongo-driver/x/mongo/driver/topology" 29) 30 31var connectionString connstring.ConnString 32var connectionStringOnce sync.Once 33var connectionStringErr error 34var liveTopology *topology.Topology 35var liveSessionPool *session.Pool 36var liveTopologyOnce sync.Once 37var liveTopologyErr error 38var monitoredTopology *topology.Topology 39var monitoredSessionPool *session.Pool 40var monitoredTopologyOnce sync.Once 41var monitoredTopologyErr error 42 43// AddOptionsToURI appends connection string options to a URI. 44func AddOptionsToURI(uri string, opts ...string) string { 45 if !strings.ContainsRune(uri, '?') { 46 if uri[len(uri)-1] != '/' { 47 uri += "/" 48 } 49 50 uri += "?" 51 } else { 52 uri += "&" 53 } 54 55 for _, opt := range opts { 56 uri += opt 57 } 58 59 return uri 60} 61 62// AddTLSConfigToURI checks for the environmental variable indicating that the tests are being run 63// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. 64func AddTLSConfigToURI(uri string) string { 65 caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE") 66 if len(caFile) == 0 { 67 return uri 68 } 69 70 return AddOptionsToURI(uri, "ssl=true&sslCertificateAuthorityFile=", caFile) 71} 72 73// AddCompressorToUri checks for the environment variable indicating that the tests are being run with compression 74// enabled. If so, it returns a new URI with the necessary configuration 75func AddCompressorToUri(uri string) string { 76 comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR") 77 if len(comp) == 0 { 78 return uri 79 } 80 81 return AddOptionsToURI(uri, "compressors=", comp) 82} 83 84// MonitoredTopology returns a new topology with the command monitor attached 85func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonitor) *topology.Topology { 86 cs := ConnString(t) 87 opts := []topology.Option{ 88 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), 89 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 90 return append( 91 opts, 92 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 93 return append( 94 opts, 95 topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { 96 return monitor 97 }), 98 topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache { 99 return ocsp.NewCache() 100 }), 101 ) 102 }), 103 ) 104 }), 105 } 106 107 monitoredTopology, err := topology.New(opts...) 108 if err != nil { 109 t.Fatal(err) 110 } else { 111 monitoredTopology.Connect() 112 113 err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). 114 Database(dbName).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background()) 115 116 require.NoError(t, err) 117 } 118 119 return monitoredTopology 120} 121 122// GlobalMonitoredTopology gets the globally configured topology and attaches a command monitor. 123func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topology.Topology { 124 cs := ConnString(t) 125 opts := []topology.Option{ 126 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), 127 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 128 return append( 129 opts, 130 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 131 return append( 132 opts, 133 topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { 134 return monitor 135 }), 136 topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache { 137 return ocsp.NewCache() 138 }), 139 ) 140 }), 141 ) 142 }), 143 } 144 145 monitoredTopologyOnce.Do(func() { 146 var err error 147 monitoredTopology, err = topology.New(opts...) 148 if err != nil { 149 monitoredTopologyErr = err 150 } else { 151 monitoredTopology.Connect() 152 153 err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). 154 Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background()) 155 156 require.NoError(t, err) 157 158 sub, err := monitoredTopology.Subscribe() 159 require.NoError(t, err) 160 monitoredSessionPool = session.NewPool(sub.Updates) 161 } 162 }) 163 164 if monitoredTopologyErr != nil { 165 t.Fatal(monitoredTopologyErr) 166 } 167 168 return monitoredTopology 169} 170 171// GlobalMonitoredSessionPool returns the globally configured session pool. 172// Must be called after GlobalMonitoredTopology() 173func GlobalMonitoredSessionPool() *session.Pool { 174 return monitoredSessionPool 175} 176 177// Topology gets the globally configured topology. 178func Topology(t *testing.T) *topology.Topology { 179 cs := ConnString(t) 180 opts := []topology.Option{ 181 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), 182 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 183 return append( 184 opts, 185 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 186 return append( 187 opts, 188 topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache { 189 return ocsp.NewCache() 190 }), 191 ) 192 }), 193 ) 194 }), 195 } 196 197 liveTopologyOnce.Do(func() { 198 var err error 199 liveTopology, err = topology.New(opts...) 200 if err != nil { 201 liveTopologyErr = err 202 } else { 203 liveTopology.Connect() 204 205 err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). 206 Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background()) 207 require.NoError(t, err) 208 209 sub, err := liveTopology.Subscribe() 210 require.NoError(t, err) 211 liveSessionPool = session.NewPool(sub.Updates) 212 } 213 }) 214 215 if liveTopologyErr != nil { 216 t.Fatal(liveTopologyErr) 217 } 218 219 return liveTopology 220} 221 222// SessionPool gets the globally configured session pool. Must be called after Topology(). 223func SessionPool() *session.Pool { 224 return liveSessionPool 225} 226 227// TopologyWithConnString takes a connection string and returns a connected 228// topology, or else bails out of testing 229func TopologyWithConnString(t *testing.T, cs connstring.ConnString) *topology.Topology { 230 opts := []topology.Option{ 231 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), 232 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 233 return append( 234 opts, 235 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 236 return append( 237 opts, 238 topology.WithOCSPCache(func(ocsp.Cache) ocsp.Cache { 239 return ocsp.NewCache() 240 }), 241 ) 242 }), 243 ) 244 }), 245 } 246 247 topology, err := topology.New(opts...) 248 if err != nil { 249 t.Fatal("Could not construct topology") 250 } 251 err = topology.Connect() 252 if err != nil { 253 t.Fatal("Could not start topology connection") 254 } 255 return topology 256} 257 258// ColName gets a collection name that should be unique 259// to the currently executing test. 260func ColName(t *testing.T) string { 261 // Get this indirectly to avoid copying a mutex 262 v := reflect.Indirect(reflect.ValueOf(t)) 263 name := v.FieldByName("name") 264 return name.String() 265} 266 267// ConnString gets the globally configured connection string. 268func ConnString(t *testing.T) connstring.ConnString { 269 connectionStringOnce.Do(func() { 270 connectionString, connectionStringErr = GetConnString() 271 mongodbURI := os.Getenv("MONGODB_URI") 272 if mongodbURI == "" { 273 mongodbURI = "mongodb://localhost:27017" 274 } 275 276 mongodbURI = AddTLSConfigToURI(mongodbURI) 277 mongodbURI = AddCompressorToUri(mongodbURI) 278 279 var err error 280 connectionString, err = connstring.ParseAndValidate(mongodbURI) 281 if err != nil { 282 connectionStringErr = err 283 } 284 }) 285 if connectionStringErr != nil { 286 t.Fatal(connectionStringErr) 287 } 288 289 return connectionString 290} 291 292func GetConnString() (connstring.ConnString, error) { 293 mongodbURI := os.Getenv("MONGODB_URI") 294 if mongodbURI == "" { 295 mongodbURI = "mongodb://localhost:27017" 296 } 297 298 mongodbURI = AddTLSConfigToURI(mongodbURI) 299 300 cs, err := connstring.ParseAndValidate(mongodbURI) 301 if err != nil { 302 return connstring.ConnString{}, err 303 } 304 305 return cs, nil 306} 307 308// DBName gets the globally configured database name. 309func DBName(t *testing.T) string { 310 return GetDBName(ConnString(t)) 311} 312 313func GetDBName(cs connstring.ConnString) string { 314 if cs.Database != "" { 315 return cs.Database 316 } 317 318 return fmt.Sprintf("mongo-go-driver-%d", os.Getpid()) 319} 320 321// Integration should be called at the beginning of integration 322// tests to ensure that they are skipped if integration testing is 323// turned off. 324func Integration(t *testing.T) { 325 if testing.Short() { 326 t.Skip("skipping integration test in short mode") 327 } 328} 329 330// compareVersions compares two version number strings (i.e. positive integers separated by 331// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is 332// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. 333// 334// Returns a positive int if version1 is greater than version2, a negative int if version1 is less 335// than version2, and 0 if version1 is equal to version2. 336func CompareVersions(t *testing.T, v1 string, v2 string) int { 337 n1 := strings.Split(v1, ".") 338 n2 := strings.Split(v2, ".") 339 340 for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ { 341 i1, err := strconv.Atoi(n1[i]) 342 require.NoError(t, err) 343 344 i2, err := strconv.Atoi(n2[i]) 345 require.NoError(t, err) 346 347 difference := i1 - i2 348 if difference != 0 { 349 return difference 350 } 351 } 352 353 return 0 354} 355