1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package testutil 8 9import ( 10 "context" 11 "fmt" 12 "math" 13 "os" 14 "reflect" 15 "strconv" 16 "strings" 17 "sync" 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 "go.mongodb.org/mongo-driver/event" 22 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" 23 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 24 "go.mongodb.org/mongo-driver/x/mongo/driver/description" 25 "go.mongodb.org/mongo-driver/x/mongo/driver/operation" 26 "go.mongodb.org/mongo-driver/x/mongo/driver/session" 27 "go.mongodb.org/mongo-driver/x/mongo/driver/topology" 28) 29 30var connectionString connstring.ConnString 31var connectionStringOnce sync.Once 32var connectionStringErr error 33var liveTopology *topology.Topology 34var liveSessionPool *session.Pool 35var liveTopologyOnce sync.Once 36var liveTopologyErr error 37var monitoredTopology *topology.Topology 38var monitoredSessionPool *session.Pool 39var monitoredTopologyOnce sync.Once 40var monitoredTopologyErr error 41 42// AddOptionsToURI appends connection string options to a URI. 43func AddOptionsToURI(uri string, opts ...string) string { 44 if !strings.ContainsRune(uri, '?') { 45 if uri[len(uri)-1] != '/' { 46 uri += "/" 47 } 48 49 uri += "?" 50 } else { 51 uri += "&" 52 } 53 54 for _, opt := range opts { 55 uri += opt 56 } 57 58 return uri 59} 60 61// AddTLSConfigToURI checks for the environmental variable indicating that the tests are being run 62// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. 63func AddTLSConfigToURI(uri string) string { 64 caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE") 65 if len(caFile) == 0 { 66 return uri 67 } 68 69 return AddOptionsToURI(uri, "ssl=true&sslCertificateAuthorityFile=", caFile) 70} 71 72// AddCompressorToUri checks for the environment variable indicating that the tests are being run with compression 73// enabled. If so, it returns a new URI with the necessary configuration 74func AddCompressorToUri(uri string) string { 75 comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR") 76 if len(comp) == 0 { 77 return uri 78 } 79 80 return AddOptionsToURI(uri, "compressors=", comp) 81} 82 83// MonitoredTopology returns a new topology with the command monitor attached 84func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonitor) *topology.Topology { 85 cs := ConnString(t) 86 opts := []topology.Option{ 87 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), 88 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 89 return append( 90 opts, 91 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 92 return append( 93 opts, 94 topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { 95 return monitor 96 }), 97 ) 98 }), 99 ) 100 }), 101 } 102 103 monitoredTopology, err := topology.New(opts...) 104 if err != nil { 105 t.Fatal(err) 106 } else { 107 monitoredTopology.Connect() 108 109 err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). 110 Database(dbName).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background()) 111 112 require.NoError(t, err) 113 } 114 115 return monitoredTopology 116} 117 118// GlobalMonitoredTopology gets the globally configured topology and attaches a command monitor. 119func GlobalMonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topology.Topology { 120 cs := ConnString(t) 121 opts := []topology.Option{ 122 topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), 123 topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { 124 return append( 125 opts, 126 topology.WithConnectionOptions(func(opts ...topology.ConnectionOption) []topology.ConnectionOption { 127 return append( 128 opts, 129 topology.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { 130 return monitor 131 }), 132 ) 133 }), 134 ) 135 }), 136 } 137 138 monitoredTopologyOnce.Do(func() { 139 var err error 140 monitoredTopology, err = topology.New(opts...) 141 if err != nil { 142 monitoredTopologyErr = err 143 } else { 144 monitoredTopology.Connect() 145 146 err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). 147 Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background()) 148 149 require.NoError(t, err) 150 151 sub, err := monitoredTopology.Subscribe() 152 require.NoError(t, err) 153 monitoredSessionPool = session.NewPool(sub.Updates) 154 } 155 }) 156 157 if monitoredTopologyErr != nil { 158 t.Fatal(monitoredTopologyErr) 159 } 160 161 return monitoredTopology 162} 163 164// GlobalMonitoredSessionPool returns the globally configured session pool. 165// Must be called after GlobalMonitoredTopology() 166func GlobalMonitoredSessionPool() *session.Pool { 167 return monitoredSessionPool 168} 169 170// Topology gets the globally configured topology. 171func Topology(t *testing.T) *topology.Topology { 172 cs := ConnString(t) 173 liveTopologyOnce.Do(func() { 174 var err error 175 liveTopology, err = topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs })) 176 if err != nil { 177 liveTopologyErr = err 178 } else { 179 liveTopology.Connect() 180 181 err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). 182 Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background()) 183 require.NoError(t, err) 184 185 sub, err := liveTopology.Subscribe() 186 require.NoError(t, err) 187 liveSessionPool = session.NewPool(sub.Updates) 188 } 189 }) 190 191 if liveTopologyErr != nil { 192 t.Fatal(liveTopologyErr) 193 } 194 195 return liveTopology 196} 197 198// SessionPool gets the globally configured session pool. Must be called after Topology(). 199func SessionPool() *session.Pool { 200 return liveSessionPool 201} 202 203// TopologyWithConnString takes a connection string and returns a connected 204// topology, or else bails out of testing 205func TopologyWithConnString(t *testing.T, cs connstring.ConnString) *topology.Topology { 206 topology, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs })) 207 if err != nil { 208 t.Fatal("Could not construct topology") 209 } 210 err = topology.Connect() 211 if err != nil { 212 t.Fatal("Could not start topology connection") 213 } 214 return topology 215} 216 217// ColName gets a collection name that should be unique 218// to the currently executing test. 219func ColName(t *testing.T) string { 220 // Get this indirectly to avoid copying a mutex 221 v := reflect.Indirect(reflect.ValueOf(t)) 222 name := v.FieldByName("name") 223 return name.String() 224} 225 226// ConnString gets the globally configured connection string. 227func ConnString(t *testing.T) connstring.ConnString { 228 connectionStringOnce.Do(func() { 229 connectionString, connectionStringErr = GetConnString() 230 mongodbURI := os.Getenv("MONGODB_URI") 231 if mongodbURI == "" { 232 mongodbURI = "mongodb://localhost:27017" 233 } 234 235 mongodbURI = AddTLSConfigToURI(mongodbURI) 236 mongodbURI = AddCompressorToUri(mongodbURI) 237 238 var err error 239 connectionString, err = connstring.Parse(mongodbURI) 240 if err != nil { 241 connectionStringErr = err 242 } 243 }) 244 if connectionStringErr != nil { 245 t.Fatal(connectionStringErr) 246 } 247 248 return connectionString 249} 250 251func GetConnString() (connstring.ConnString, error) { 252 mongodbURI := os.Getenv("MONGODB_URI") 253 if mongodbURI == "" { 254 mongodbURI = "mongodb://localhost:27017" 255 } 256 257 mongodbURI = AddTLSConfigToURI(mongodbURI) 258 259 cs, err := connstring.Parse(mongodbURI) 260 if err != nil { 261 return connstring.ConnString{}, err 262 } 263 264 return cs, nil 265} 266 267// DBName gets the globally configured database name. 268func DBName(t *testing.T) string { 269 return GetDBName(ConnString(t)) 270} 271 272func GetDBName(cs connstring.ConnString) string { 273 if cs.Database != "" { 274 return cs.Database 275 } 276 277 return fmt.Sprintf("mongo-go-driver-%d", os.Getpid()) 278} 279 280// Integration should be called at the beginning of integration 281// tests to ensure that they are skipped if integration testing is 282// turned off. 283func Integration(t *testing.T) { 284 if testing.Short() { 285 t.Skip("skipping integration test in short mode") 286 } 287} 288 289// compareVersions compares two version number strings (i.e. positive integers separated by 290// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is 291// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. 292// 293// Returns a positive int if version1 is greater than version2, a negative int if version1 is less 294// than version2, and 0 if version1 is equal to version2. 295func CompareVersions(t *testing.T, v1 string, v2 string) int { 296 n1 := strings.Split(v1, ".") 297 n2 := strings.Split(v2, ".") 298 299 for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ { 300 i1, err := strconv.Atoi(n1[i]) 301 require.NoError(t, err) 302 303 i2, err := strconv.Atoi(n2[i]) 304 require.NoError(t, err) 305 306 difference := i1 - i2 307 if difference != 0 { 308 return difference 309 } 310 } 311 312 return 0 313} 314