1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package testhelpers // import "go.mongodb.org/mongo-driver/internal/testutil/helpers" 8 9import ( 10 "fmt" 11 "io/ioutil" 12 "math" 13 "path" 14 "strings" 15 "time" 16 17 "testing" 18 19 "io" 20 21 "reflect" 22 23 "github.com/stretchr/testify/require" 24 "go.mongodb.org/mongo-driver/bson" 25 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 26) 27 28// Test helpers 29 30// IsNil returns true if the object is nil 31func IsNil(object interface{}) bool { 32 if object == nil { 33 return true 34 } 35 36 value := reflect.ValueOf(object) 37 kind := value.Kind() 38 39 // checking to see if type is Chan, Func, Interface, Map, Ptr, or Slice 40 if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() { 41 return true 42 } 43 44 return false 45} 46 47// RequireNotNil throws an error if var is nil 48func RequireNotNil(t *testing.T, variable interface{}, msgFormat string, msgVars ...interface{}) { 49 if IsNil(variable) { 50 t.Fatalf(msgFormat, msgVars...) 51 } 52} 53 54// RequireNil throws an error if var is not nil 55func RequireNil(t *testing.T, variable interface{}, msgFormat string, msgVars ...interface{}) { 56 t.Helper() 57 if !IsNil(variable) { 58 t.Fatalf(msgFormat, msgVars...) 59 } 60} 61 62// FindJSONFilesInDir finds the JSON files in a directory. 63func FindJSONFilesInDir(t *testing.T, dir string) []string { 64 files := make([]string, 0) 65 66 entries, err := ioutil.ReadDir(dir) 67 require.NoError(t, err) 68 69 for _, entry := range entries { 70 if entry.IsDir() || path.Ext(entry.Name()) != ".json" { 71 continue 72 } 73 74 files = append(files, entry.Name()) 75 } 76 77 return files 78} 79 80// RequireNoErrorOnClose ensures there is not an error when calling Close. 
81func RequireNoErrorOnClose(t *testing.T, c io.Closer) { 82 require.NoError(t, c.Close()) 83} 84 85// VerifyConnStringOptions verifies the options on the connection string. 86func VerifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map[string]interface{}) { 87 // Check that all options are present. 88 for key, value := range options { 89 90 key = strings.ToLower(key) 91 switch key { 92 case "appname": 93 require.Equal(t, value, cs.AppName) 94 case "authsource": 95 require.Equal(t, value, cs.AuthSource) 96 case "authmechanism": 97 require.Equal(t, value, cs.AuthMechanism) 98 case "authmechanismproperties": 99 convertedMap := value.(map[string]interface{}) 100 require.Equal(t, 101 mapInterfaceToString(convertedMap), 102 cs.AuthMechanismProperties) 103 case "compressors": 104 require.Equal(t, convertToStringSlice(value), cs.Compressors) 105 case "connecttimeoutms": 106 require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond)) 107 case "directconnection": 108 require.True(t, cs.DirectConnectionSet) 109 require.Equal(t, value, cs.DirectConnection) 110 case "heartbeatfrequencyms": 111 require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond)) 112 case "journal": 113 require.True(t, cs.JSet) 114 require.Equal(t, value, cs.J) 115 case "localthresholdms": 116 require.True(t, cs.LocalThresholdSet) 117 require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond)) 118 case "maxidletimems": 119 require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond)) 120 case "maxpoolsize": 121 require.True(t, cs.MaxPoolSizeSet) 122 require.Equal(t, value, cs.MaxPoolSize) 123 case "maxstalenessseconds": 124 require.True(t, cs.MaxStalenessSet) 125 require.Equal(t, value, float64(cs.MaxStaleness/time.Second)) 126 case "minpoolsize": 127 require.True(t, cs.MinPoolSizeSet) 128 require.Equal(t, value, int64(cs.MinPoolSize)) 129 case "readpreference": 130 require.Equal(t, value, cs.ReadPreference) 131 case "readpreferencetags": 132 sm, 
ok := value.([]interface{}) 133 require.True(t, ok) 134 tags := make([]map[string]string, 0, len(sm)) 135 for _, i := range sm { 136 m, ok := i.(map[string]interface{}) 137 require.True(t, ok) 138 tags = append(tags, mapInterfaceToString(m)) 139 } 140 require.Equal(t, tags, cs.ReadPreferenceTagSets) 141 case "readconcernlevel": 142 require.Equal(t, value, cs.ReadConcernLevel) 143 case "replicaset": 144 require.Equal(t, value, cs.ReplicaSet) 145 case "retrywrites": 146 require.True(t, cs.RetryWritesSet) 147 require.Equal(t, value, cs.RetryWrites) 148 case "serverselectiontimeoutms": 149 require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond)) 150 case "ssl", "tls": 151 require.Equal(t, value, cs.SSL) 152 case "sockettimeoutms": 153 require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond)) 154 case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure": 155 require.True(t, cs.SSLInsecureSet) 156 require.Equal(t, value, cs.SSLInsecure) 157 case "tlscafile": 158 require.True(t, cs.SSLCaFileSet) 159 require.Equal(t, value, cs.SSLCaFile) 160 case "tlscertificatekeyfile": 161 require.True(t, cs.SSLClientCertificateKeyFileSet) 162 require.Equal(t, value, cs.SSLClientCertificateKeyFile) 163 case "tlscertificatekeyfilepassword": 164 require.True(t, cs.SSLClientCertificateKeyPasswordSet) 165 require.Equal(t, value, cs.SSLClientCertificateKeyPassword()) 166 case "w": 167 if cs.WNumberSet { 168 valueInt := GetIntFromInterface(value) 169 require.NotNil(t, valueInt) 170 require.Equal(t, *valueInt, int64(cs.WNumber)) 171 } else { 172 require.Equal(t, value, cs.WString) 173 } 174 case "wtimeoutms": 175 require.Equal(t, value, float64(cs.WTimeout/time.Millisecond)) 176 case "waitqueuetimeoutms": 177 case "zlibcompressionlevel": 178 require.Equal(t, value, float64(cs.ZlibLevel)) 179 case "zstdcompressionlevel": 180 require.Equal(t, value, float64(cs.ZstdLevel)) 181 case "tlsdisableocspendpointcheck": 182 require.Equal(t, value, 
cs.SSLDisableOCSPEndpointCheck) 183 default: 184 opt, ok := cs.UnknownOptions[key] 185 require.True(t, ok) 186 require.Contains(t, opt, fmt.Sprint(value)) 187 } 188 } 189} 190 191// RawSliceToInterfaceSlice converts a []bson.Raw to []interface{}. 192func RawSliceToInterfaceSlice(elems []bson.Raw) []interface{} { 193 out := make([]interface{}, 0, len(elems)) 194 for _, elem := range elems { 195 out = append(out, elem) 196 } 197 return out 198} 199 200// RawToInterfaceSlice converts a bson.Raw that is internally an array to []interface{}. 201func RawToInterfaceSlice(doc bson.Raw) []interface{} { 202 values, _ := doc.Values() 203 204 out := make([]interface{}, 0, len(values)) 205 for _, val := range values { 206 out = append(out, val.Document()) 207 } 208 209 return out 210} 211 212// Convert each interface{} value in the map to a string. 213func mapInterfaceToString(m map[string]interface{}) map[string]string { 214 out := make(map[string]string) 215 216 for key, value := range m { 217 out[key] = fmt.Sprint(value) 218 } 219 220 return out 221} 222 223func convertToStringSlice(i interface{}) []string { 224 s, ok := i.([]interface{}) 225 if !ok { 226 return nil 227 } 228 ret := make([]string, 0, len(s)) 229 for _, v := range s { 230 str, ok := v.(string) 231 if !ok { 232 continue 233 } 234 ret = append(ret, str) 235 } 236 return ret 237} 238 239// GetIntFromInterface attempts to convert an empty interface value to an integer. 240// 241// Returns nil if it is not possible. 
func GetIntFromInterface(i interface{}) *int64 {
	var f float64

	switch v := i.(type) {
	case int:
		out := int64(v)
		return &out
	case int32:
		out := int64(v)
		return &out
	case int64:
		return &v
	case float32:
		// Widening float32 to float64 is exact, so both float kinds share
		// the checks below.
		f = float64(v)
	case float64:
		f = v
	default:
		return nil
	}

	// Only whole numbers within int64 range convert. Anything else returns nil
	// rather than a pointer to 0.
	//   - NaN fails the Floor comparison.
	//   - ±Inf fail the range checks.
	//   - float64(math.MaxInt64) rounds up to 2^63, which is itself out of
	//     range, hence >=.
	//   - float64(math.MinInt64) is exactly -2^63, which is in range.
	// Converting an out-of-range float to int64 is implementation-dependent
	// in Go, so these checks must come first.
	if math.Floor(f) != f || f >= float64(math.MaxInt64) || f < float64(math.MinInt64) {
		return nil
	}

	out := int64(f)
	return &out
}