1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package testhelpers // import "go.mongodb.org/mongo-driver/internal/testutil/helpers"
8
9import (
10	"fmt"
11	"io/ioutil"
12	"math"
13	"path"
14	"strings"
15	"time"
16
17	"testing"
18
19	"io"
20
21	"reflect"
22
23	"github.com/stretchr/testify/require"
24	"go.mongodb.org/mongo-driver/bson"
25	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
26)
27
28// Test helpers
29
// IsNil reports whether object is nil. That covers an untyped nil interface
// value, and also a typed nil stored in an interface when its kind is channel,
// function, interface, map, pointer, or slice. Values of any other kind are
// never considered nil.
func IsNil(object interface{}) bool {
	if object == nil {
		return true
	}

	rv := reflect.ValueOf(object)
	switch rv.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
		// Only these kinds may be passed to Value.IsNil without panicking.
		return rv.IsNil()
	default:
		return false
	}
}
46
47// RequireNotNil throws an error if var is nil
48func RequireNotNil(t *testing.T, variable interface{}, msgFormat string, msgVars ...interface{}) {
49	if IsNil(variable) {
50		t.Fatalf(msgFormat, msgVars...)
51	}
52}
53
54// RequireNil throws an error if var is not nil
55func RequireNil(t *testing.T, variable interface{}, msgFormat string, msgVars ...interface{}) {
56	t.Helper()
57	if !IsNil(variable) {
58		t.Fatalf(msgFormat, msgVars...)
59	}
60}
61
62// FindJSONFilesInDir finds the JSON files in a directory.
63func FindJSONFilesInDir(t *testing.T, dir string) []string {
64	files := make([]string, 0)
65
66	entries, err := ioutil.ReadDir(dir)
67	require.NoError(t, err)
68
69	for _, entry := range entries {
70		if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
71			continue
72		}
73
74		files = append(files, entry.Name())
75	}
76
77	return files
78}
79
80// RequireNoErrorOnClose ensures there is not an error when calling Close.
81func RequireNoErrorOnClose(t *testing.T, c io.Closer) {
82	require.NoError(t, c.Close())
83}
84
85// VerifyConnStringOptions verifies the options on the connection string.
86func VerifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map[string]interface{}) {
87	// Check that all options are present.
88	for key, value := range options {
89
90		key = strings.ToLower(key)
91		switch key {
92		case "appname":
93			require.Equal(t, value, cs.AppName)
94		case "authsource":
95			require.Equal(t, value, cs.AuthSource)
96		case "authmechanism":
97			require.Equal(t, value, cs.AuthMechanism)
98		case "authmechanismproperties":
99			convertedMap := value.(map[string]interface{})
100			require.Equal(t,
101				mapInterfaceToString(convertedMap),
102				cs.AuthMechanismProperties)
103		case "compressors":
104			require.Equal(t, convertToStringSlice(value), cs.Compressors)
105		case "connecttimeoutms":
106			require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond))
107		case "directconnection":
108			require.True(t, cs.DirectConnectionSet)
109			require.Equal(t, value, cs.DirectConnection)
110		case "heartbeatfrequencyms":
111			require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond))
112		case "journal":
113			require.True(t, cs.JSet)
114			require.Equal(t, value, cs.J)
115		case "localthresholdms":
116			require.True(t, cs.LocalThresholdSet)
117			require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond))
118		case "maxidletimems":
119			require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond))
120		case "maxpoolsize":
121			require.True(t, cs.MaxPoolSizeSet)
122			require.Equal(t, value, cs.MaxPoolSize)
123		case "maxstalenessseconds":
124			require.True(t, cs.MaxStalenessSet)
125			require.Equal(t, value, float64(cs.MaxStaleness/time.Second))
126		case "minpoolsize":
127			require.True(t, cs.MinPoolSizeSet)
128			require.Equal(t, value, int64(cs.MinPoolSize))
129		case "readpreference":
130			require.Equal(t, value, cs.ReadPreference)
131		case "readpreferencetags":
132			sm, ok := value.([]interface{})
133			require.True(t, ok)
134			tags := make([]map[string]string, 0, len(sm))
135			for _, i := range sm {
136				m, ok := i.(map[string]interface{})
137				require.True(t, ok)
138				tags = append(tags, mapInterfaceToString(m))
139			}
140			require.Equal(t, tags, cs.ReadPreferenceTagSets)
141		case "readconcernlevel":
142			require.Equal(t, value, cs.ReadConcernLevel)
143		case "replicaset":
144			require.Equal(t, value, cs.ReplicaSet)
145		case "retrywrites":
146			require.True(t, cs.RetryWritesSet)
147			require.Equal(t, value, cs.RetryWrites)
148		case "serverselectiontimeoutms":
149			require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
150		case "ssl", "tls":
151			require.Equal(t, value, cs.SSL)
152		case "sockettimeoutms":
153			require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond))
154		case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure":
155			require.True(t, cs.SSLInsecureSet)
156			require.Equal(t, value, cs.SSLInsecure)
157		case "tlscafile":
158			require.True(t, cs.SSLCaFileSet)
159			require.Equal(t, value, cs.SSLCaFile)
160		case "tlscertificatekeyfile":
161			require.True(t, cs.SSLClientCertificateKeyFileSet)
162			require.Equal(t, value, cs.SSLClientCertificateKeyFile)
163		case "tlscertificatekeyfilepassword":
164			require.True(t, cs.SSLClientCertificateKeyPasswordSet)
165			require.Equal(t, value, cs.SSLClientCertificateKeyPassword())
166		case "w":
167			if cs.WNumberSet {
168				valueInt := GetIntFromInterface(value)
169				require.NotNil(t, valueInt)
170				require.Equal(t, *valueInt, int64(cs.WNumber))
171			} else {
172				require.Equal(t, value, cs.WString)
173			}
174		case "wtimeoutms":
175			require.Equal(t, value, float64(cs.WTimeout/time.Millisecond))
176		case "waitqueuetimeoutms":
177		case "zlibcompressionlevel":
178			require.Equal(t, value, float64(cs.ZlibLevel))
179		case "zstdcompressionlevel":
180			require.Equal(t, value, float64(cs.ZstdLevel))
181		case "tlsdisableocspendpointcheck":
182			require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck)
183		default:
184			opt, ok := cs.UnknownOptions[key]
185			require.True(t, ok)
186			require.Contains(t, opt, fmt.Sprint(value))
187		}
188	}
189}
190
191// RawSliceToInterfaceSlice converts a []bson.Raw to []interface{}.
192func RawSliceToInterfaceSlice(elems []bson.Raw) []interface{} {
193	out := make([]interface{}, 0, len(elems))
194	for _, elem := range elems {
195		out = append(out, elem)
196	}
197	return out
198}
199
200// RawToInterfaceSlice converts a bson.Raw that is internally an array to []interface{}.
201func RawToInterfaceSlice(doc bson.Raw) []interface{} {
202	values, _ := doc.Values()
203
204	out := make([]interface{}, 0, len(values))
205	for _, val := range values {
206		out = append(out, val.Document())
207	}
208
209	return out
210}
211
// mapInterfaceToString returns a new map with the same keys as m, where each
// value is rendered with fmt.Sprint. A nil or empty m yields an empty,
// non-nil map.
func mapInterfaceToString(m map[string]interface{}) map[string]string {
	// The output has exactly len(m) entries, so size it up front.
	out := make(map[string]string, len(m))

	for key, value := range m {
		out[key] = fmt.Sprint(value)
	}

	return out
}
222
// convertToStringSlice returns the string elements of i, in order, when i is
// a []interface{}; elements that are not strings are skipped. It returns nil
// when i is not a []interface{} (including when i is a []string).
func convertToStringSlice(i interface{}) []string {
	elems, isSlice := i.([]interface{})
	if !isSlice {
		return nil
	}

	strs := make([]string, 0, len(elems))
	for idx := range elems {
		if str, isStr := elems[idx].(string); isStr {
			strs = append(strs, str)
		}
	}
	return strs
}
238
// GetIntFromInterface attempts to convert an empty interface value to an integer.
//
// int, int32, and int64 values are converted directly. float32 and float64
// values are accepted only when they hold a whole number within the int64
// range; fractional values, NaN, infinities, and out-of-range values yield
// nil. Any other type also yields nil.
func GetIntFromInterface(i interface{}) *int64 {
	var f float64

	switch v := i.(type) {
	case int:
		out := int64(v)
		return &out
	case int32:
		out := int64(v)
		return &out
	case int64:
		out := v
		return &out
	case float32:
		f = float64(v)
	case float64:
		f = v
	default:
		return nil
	}

	// Reject anything that cannot be converted exactly:
	//   - Floor(f) != f catches fractional values and NaN (NaN never equals
	//     itself).
	//   - float64(math.MinInt64) is exactly -2^63, the smallest int64, so the
	//     lower bound is inclusive.
	//   - float64(math.MaxInt64) rounds up to 2^63, which does NOT fit in an
	//     int64, so the upper bound must be exclusive.
	//   - ±Inf fail the range checks.
	// Out-of-range float-to-int conversions are implementation-defined in Go,
	// so these checks must happen before converting.
	if math.Floor(f) != f || f < float64(math.MinInt64) || f >= float64(math.MaxInt64) {
		return nil
	}

	out := int64(f)
	return &out
}
272