1package options
2
3import (
4	"bytes"
5	"context"
6	"crypto/tls"
7	"crypto/x509"
8	"encoding/pem"
9	"errors"
10	"fmt"
11	"io/ioutil"
12	"net"
13	"os"
14	"reflect"
15	"testing"
16	"time"
17
18	"github.com/google/go-cmp/cmp"
19	"github.com/google/go-cmp/cmp/cmpopts"
20	"go.mongodb.org/mongo-driver/bson"
21	"go.mongodb.org/mongo-driver/bson/bsoncodec"
22	"go.mongodb.org/mongo-driver/event"
23	"go.mongodb.org/mongo-driver/internal"
24	"go.mongodb.org/mongo-driver/internal/testutil/assert"
25	"go.mongodb.org/mongo-driver/mongo/readconcern"
26	"go.mongodb.org/mongo-driver/mongo/readpref"
27	"go.mongodb.org/mongo-driver/mongo/writeconcern"
28	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
29)
30
// tClientOptions is the reflect.Type of *ClientOptions. The Set test table uses
// it to check that each setter takes a *ClientOptions receiver and that the
// named field exists on the struct.
var tClientOptions = reflect.TypeOf((*ClientOptions)(nil))
32
// TestClientOptions covers ClientOptions.Validate, every Set* method (and,
// piggybacking on those cases, MergeClientOptions), URI application through
// ApplyURI, and the direct-connection validation rules.
func TestClientOptions(t *testing.T) {
	t.Run("ApplyURI/doesn't overwrite previous errors", func(t *testing.T) {
		uri := "not-mongo-db-uri://"
		want := internal.WrapErrorf(
			errors.New(`scheme must be "mongodb" or "mongodb+srv"`), "error parsing uri",
		)
		co := Client().ApplyURI(uri).ApplyURI("mongodb://localhost/")
		got := co.Validate()
		if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
			t.Errorf("Did not receive expected error. got %v; want %v", got, want)
		}
	})
	t.Run("Validate/returns error", func(t *testing.T) {
		want := errors.New("validate error")
		co := &ClientOptions{err: want}
		got := co.Validate()
		if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
			t.Errorf("Did not receive expected error. got %v; want %v", got, want)
		}
	})
	t.Run("Set", func(t *testing.T) {
		testCases := []struct {
			name        string
			fn          interface{} // method to be run
			arg         interface{} // argument for method
			field       string      // field to be set
			dereference bool        // Should we compare a pointer or the field
		}{
			{"AppName", (*ClientOptions).SetAppName, "example-application", "AppName", true},
			{"Auth", (*ClientOptions).SetAuth, Credential{Username: "foo", Password: "bar"}, "Auth", true},
			{"Compressors", (*ClientOptions).SetCompressors, []string{"zstd", "snappy", "zlib"}, "Compressors", true},
			{"ConnectTimeout", (*ClientOptions).SetConnectTimeout, 5 * time.Second, "ConnectTimeout", true},
			{"Dialer", (*ClientOptions).SetDialer, testDialer{Num: 12345}, "Dialer", true},
			{"HeartbeatInterval", (*ClientOptions).SetHeartbeatInterval, 5 * time.Second, "HeartbeatInterval", true},
			{"Hosts", (*ClientOptions).SetHosts, []string{"localhost:27017", "localhost:27018", "localhost:27019"}, "Hosts", true},
			{"LocalThreshold", (*ClientOptions).SetLocalThreshold, 5 * time.Second, "LocalThreshold", true},
			{"MaxConnIdleTime", (*ClientOptions).SetMaxConnIdleTime, 5 * time.Second, "MaxConnIdleTime", true},
			{"MaxPoolSize", (*ClientOptions).SetMaxPoolSize, uint64(250), "MaxPoolSize", true},
			{"MinPoolSize", (*ClientOptions).SetMinPoolSize, uint64(10), "MinPoolSize", true},
			{"PoolMonitor", (*ClientOptions).SetPoolMonitor, &event.PoolMonitor{}, "PoolMonitor", false},
			{"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false},
			{"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false},
			{"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false},
			{"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false},
			{"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true},
			{"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true},
			{"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true},
			{"Direct", (*ClientOptions).SetDirect, true, "Direct", true},
			{"SocketTimeout", (*ClientOptions).SetSocketTimeout, 5 * time.Second, "SocketTimeout", true},
			{"TLSConfig", (*ClientOptions).SetTLSConfig, &tls.Config{}, "TLSConfig", false},
			{"WriteConcern", (*ClientOptions).SetWriteConcern, writeconcern.New(writeconcern.WMajority()), "WriteConcern", false},
			{"ZlibLevel", (*ClientOptions).SetZlibLevel, 6, "ZlibLevel", true},
			{"DisableOCSPEndpointCheck", (*ClientOptions).SetDisableOCSPEndpointCheck, true, "DisableOCSPEndpointCheck", true},
		}

		opt1, opt2, optResult := Client(), Client(), Client()
		for idx, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				fn := reflect.ValueOf(tc.fn)
				if fn.Kind() != reflect.Func {
					t.Fatal("fn argument must be a function")
				}
				// args below is built with exactly two values, so any other arity
				// would make fn.Call panic rather than fail cleanly here.
				if fn.Type().NumIn() != 2 || fn.Type().In(0) != tClientOptions {
					t.Fatal("fn argument must have a *ClientOptions as the first argument and one other argument")
				}
				if _, exists := tClientOptions.Elem().FieldByName(tc.field); !exists {
					t.Fatalf("field (%s) does not exist in ClientOptions", tc.field)
				}
				args := make([]reflect.Value, 2)
				client := reflect.New(tClientOptions.Elem())
				args[0] = client
				want := reflect.ValueOf(tc.arg)
				args[1] = want

				if !want.IsValid() || !want.CanInterface() {
					t.Fatal("arg property of test case must be valid")
				}

				_ = fn.Call(args)

				// To avoid duplication we're piggybacking on the Set* tests to make the
				// MergeClientOptions test simpler and more thorough.
				// Odd-indexed test cases are applied to opt1; even-indexed test cases and
				// any index divisible by three are applied to opt2 (so odd multiples of
				// three land on both); every test case is applied to optResult. Merging
				// opt1 and opt2 must then reproduce optResult, which gives us coverage of
				// options set by the first option only, by the second only, and by both.
				if idx%2 != 0 {
					args[0] = reflect.ValueOf(opt1)
					_ = fn.Call(args)
				}
				if idx%2 == 0 || idx%3 == 0 {
					args[0] = reflect.ValueOf(opt2)
					_ = fn.Call(args)
				}
				args[0] = reflect.ValueOf(optResult)
				_ = fn.Call(args)

				got := client.Elem().FieldByName(tc.field)
				if !got.IsValid() || !got.CanInterface() {
					t.Fatal("cannot create concrete instance from retrieved field")
				}

				if got.Kind() == reflect.Ptr && tc.dereference {
					got = got.Elem()
				}

				if !cmp.Equal(
					got.Interface(), want.Interface(),
					cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
					cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
					cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }),
					cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }),
				) {
					t.Errorf("Field not set properly. got %v; want %v", got.Interface(), want.Interface())
				}
			})
		}
		t.Run("MergeClientOptions/all set", func(t *testing.T) {
			want := optResult
			got := MergeClientOptions(nil, opt1, opt2)
			if diff := cmp.Diff(
				got, want,
				cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
				cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
				cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }),
				cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }),
				cmp.AllowUnexported(ClientOptions{}),
			); diff != "" {
				t.Errorf("diff:\n%s", diff)
				t.Errorf("Merged client options do not match. got %v; want %v", got, want)
			}
		})

		// go-cmp doesn't support comparing errors (https://github.com/google/go-cmp/issues/24),
		// so the merged error is checked in a dedicated test.
		t.Run("MergeClientOptions/err", func(t *testing.T) {
			opt1, opt2 := Client(), Client()
			opt1.err = errors.New("Test error")

			got := MergeClientOptions(nil, opt1, opt2)
			// Check for nil first so a regression fails the test instead of panicking.
			if got.err == nil || got.err.Error() != "Test error" {
				t.Errorf("Merged client options do not match. got %v; want %v", got.err, opt1.err)
			}
		})
	})
	t.Run("ApplyURI", func(t *testing.T) {
		baseClient := func() *ClientOptions {
			return Client().SetHosts([]string{"localhost"})
		}
		testCases := []struct {
			name   string
			uri    string
			result *ClientOptions
		}{
			{
				"ParseError",
				"not-mongo-db-uri://",
				&ClientOptions{err: internal.WrapErrorf(
					errors.New(`scheme must be "mongodb" or "mongodb+srv"`), "error parsing uri",
				)},
			},
			{
				"ReadPreference Invalid Mode",
				"mongodb://localhost/?maxStaleness=200",
				&ClientOptions{
					err:   fmt.Errorf("unknown read preference %v", ""),
					Hosts: []string{"localhost"},
				},
			},
			{
				"ReadPreference Primary With Options",
				"mongodb://localhost/?readPreference=Primary&maxStaleness=200",
				&ClientOptions{
					err:   errors.New("can not specify tags, max staleness, or hedge with mode primary"),
					Hosts: []string{"localhost"},
				},
			},
			{
				"TLS addCertFromFile error",
				"mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/doesntexist",
				&ClientOptions{
					err:   &os.PathError{Op: "open", Path: "testdata/doesntexist"},
					Hosts: []string{"localhost"},
				},
			},
			{
				"TLS ClientCertificateKey",
				"mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/doesntexist",
				&ClientOptions{
					err:   &os.PathError{Op: "open", Path: "testdata/doesntexist"},
					Hosts: []string{"localhost"},
				},
			},
			{
				"AppName",
				"mongodb://localhost/?appName=awesome-example-application",
				baseClient().SetAppName("awesome-example-application"),
			},
			{
				"AuthMechanism",
				"mongodb://localhost/?authMechanism=mongodb-x509",
				baseClient().SetAuth(Credential{AuthSource: "$external", AuthMechanism: "mongodb-x509"}),
			},
			{
				"AuthMechanismProperties",
				"mongodb://foo@localhost/?authMechanism=gssapi&authMechanismProperties=SERVICE_NAME:mongodb-fake",
				baseClient().SetAuth(Credential{
					AuthSource:              "$external",
					AuthMechanism:           "gssapi",
					AuthMechanismProperties: map[string]string{"SERVICE_NAME": "mongodb-fake"},
					Username:                "foo",
				}),
			},
			{
				"AuthSource",
				"mongodb://foo@localhost/?authSource=random-database-example",
				baseClient().SetAuth(Credential{AuthSource: "random-database-example", Username: "foo"}),
			},
			{
				"Username",
				"mongodb://foo@localhost/",
				baseClient().SetAuth(Credential{AuthSource: "admin", Username: "foo"}),
			},
			{
				"Unescaped slash in username",
				"mongodb:///:pwd@localhost",
				&ClientOptions{err: internal.WrapErrorf(
					errors.New("unescaped slash in username"),
					"error parsing uri",
				)},
			},
			{
				"Password",
				"mongodb://foo:bar@localhost/",
				baseClient().SetAuth(Credential{
					AuthSource: "admin", Username: "foo",
					Password: "bar", PasswordSet: true,
				}),
			},
			{
				"Single character username and password",
				"mongodb://f:b@localhost/",
				baseClient().SetAuth(Credential{
					AuthSource: "admin", Username: "f",
					Password: "b", PasswordSet: true,
				}),
			},
			{
				"Connect",
				"mongodb://localhost/?connect=direct",
				baseClient().SetDirect(true),
			},
			{
				"ConnectTimeout",
				"mongodb://localhost/?connectTimeoutms=5000",
				baseClient().SetConnectTimeout(5 * time.Second),
			},
			{
				"Compressors",
				"mongodb://localhost/?compressors=zlib,snappy",
				baseClient().SetCompressors([]string{"zlib", "snappy"}).SetZlibLevel(6),
			},
			{
				"DatabaseNoAuth",
				"mongodb://localhost/example-database",
				baseClient(),
			},
			{
				"DatabaseAsDefault",
				"mongodb://foo@localhost/example-database",
				baseClient().SetAuth(Credential{AuthSource: "example-database", Username: "foo"}),
			},
			{
				"HeartbeatInterval",
				"mongodb://localhost/?heartbeatIntervalms=12000",
				baseClient().SetHeartbeatInterval(12 * time.Second),
			},
			{
				"Hosts",
				"mongodb://localhost:27017,localhost:27018,localhost:27019/",
				baseClient().SetHosts([]string{"localhost:27017", "localhost:27018", "localhost:27019"}),
			},
			{
				"LocalThreshold",
				"mongodb://localhost/?localThresholdMS=200",
				baseClient().SetLocalThreshold(200 * time.Millisecond),
			},
			{
				"MaxConnIdleTime",
				"mongodb://localhost/?maxIdleTimeMS=300000",
				baseClient().SetMaxConnIdleTime(5 * time.Minute),
			},
			{
				"MaxPoolSize",
				"mongodb://localhost/?maxPoolSize=256",
				baseClient().SetMaxPoolSize(256),
			},
			{
				"ReadConcern",
				"mongodb://localhost/?readConcernLevel=linearizable",
				baseClient().SetReadConcern(readconcern.Linearizable()),
			},
			{
				"ReadPreference",
				"mongodb://localhost/?readPreference=secondaryPreferred",
				baseClient().SetReadPreference(readpref.SecondaryPreferred()),
			},
			{
				"ReadPreferenceTagSets",
				"mongodb://localhost/?readPreference=secondaryPreferred&readPreferenceTags=foo:bar",
				baseClient().SetReadPreference(readpref.SecondaryPreferred(readpref.WithTags("foo", "bar"))),
			},
			{
				"MaxStaleness",
				"mongodb://localhost/?readPreference=secondaryPreferred&maxStaleness=250",
				baseClient().SetReadPreference(readpref.SecondaryPreferred(readpref.WithMaxStaleness(250 * time.Second))),
			},
			{
				"RetryWrites",
				"mongodb://localhost/?retryWrites=true",
				baseClient().SetRetryWrites(true),
			},
			{
				"ReplicaSet",
				"mongodb://localhost/?replicaSet=rs01",
				baseClient().SetReplicaSet("rs01"),
			},
			{
				"ServerSelectionTimeout",
				"mongodb://localhost/?serverSelectionTimeoutMS=45000",
				baseClient().SetServerSelectionTimeout(45 * time.Second),
			},
			{
				"SocketTimeout",
				"mongodb://localhost/?socketTimeoutMS=15000",
				baseClient().SetSocketTimeout(15 * time.Second),
			},
			{
				"TLS CACertificate",
				"mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/ca.pem",
				baseClient().SetTLSConfig(&tls.Config{
					RootCAs: createCertPool(t, "testdata/ca.pem"),
				}),
			},
			{
				"TLS Insecure",
				"mongodb://localhost/?ssl=true&sslInsecure=true",
				baseClient().SetTLSConfig(&tls.Config{InsecureSkipVerify: true}),
			},
			{
				"TLS ClientCertificateKey",
				"mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/nopass/certificate.pem",
				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
			},
			{
				"TLS ClientCertificateKey with password",
				"mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/certificate.pem&sslClientCertificateKeyPassword=passphrase",
				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
			},
			{
				"TLS Username",
				"mongodb://localhost/?ssl=true&authMechanism=mongodb-x509&sslClientCertificateKeyFile=testdata/nopass/certificate.pem",
				baseClient().SetAuth(Credential{
					AuthMechanism: "mongodb-x509", AuthSource: "$external",
					Username: `C=US,ST=New York,L=New York City, Inc,O=MongoDB\,OU=WWW`,
				}),
			},
			{
				"WriteConcern J",
				"mongodb://localhost/?journal=true",
				baseClient().SetWriteConcern(writeconcern.New(writeconcern.J(true))),
			},
			{
				"WriteConcern WString",
				"mongodb://localhost/?w=majority",
				baseClient().SetWriteConcern(writeconcern.New(writeconcern.WMajority())),
			},
			{
				"WriteConcern W",
				"mongodb://localhost/?w=3",
				baseClient().SetWriteConcern(writeconcern.New(writeconcern.W(3))),
			},
			{
				"WriteConcern WTimeout",
				"mongodb://localhost/?wTimeoutMS=45000",
				baseClient().SetWriteConcern(writeconcern.New(writeconcern.WTimeout(45 * time.Second))),
			},
			{
				"ZLibLevel",
				"mongodb://localhost/?zlibCompressionLevel=4",
				baseClient().SetZlibLevel(4),
			},
			{
				"TLS tlsCertificateFile and tlsPrivateKeyFile",
				"mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem",
				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
			},
			{
				"TLS only tlsCertificateFile",
				"mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem",
				&ClientOptions{err: internal.WrapErrorf(
					errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified"),
					"error validating uri",
				)},
			},
			{
				"TLS only tlsPrivateKeyFile",
				"mongodb://localhost/?tlsPrivateKeyFile=testdata/nopass/key.pem",
				&ClientOptions{err: internal.WrapErrorf(
					errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified"),
					"error validating uri",
				)},
			},
			{
				"TLS tlsCertificateFile and tlsPrivateKeyFile and tlsCertificateKeyFile",
				"mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem&tlsCertificateKeyFile=testdata/nopass/certificate.pem",
				&ClientOptions{err: internal.WrapErrorf(
					errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided "+
						"along with tlsCertificateFile or tlsPrivateKeyFile"),
					"error validating uri",
				)},
			},
			{
				"disable OCSP endpoint check",
				"mongodb://localhost/?tlsDisableOCSPEndpointCheck=true",
				baseClient().SetDisableOCSPEndpointCheck(true),
			},
			{
				"directConnection",
				"mongodb://localhost/?directConnection=true",
				baseClient().SetDirect(true),
			},
			{
				"TLS CA file with multiple certificiates",
				"mongodb://localhost/?tlsCAFile=testdata/ca-with-intermediates.pem",
				baseClient().SetTLSConfig(&tls.Config{
					RootCAs: createCertPool(t, "testdata/ca-with-intermediates-first.pem",
						"testdata/ca-with-intermediates-second.pem", "testdata/ca-with-intermediates-third.pem"),
				}),
			},
			{
				"TLS empty CA file",
				"mongodb://localhost/?tlsCAFile=testdata/empty-ca.pem",
				&ClientOptions{
					Hosts: []string{"localhost"},
					err:   errors.New("the specified CA file does not contain any valid certificates"),
				},
			},
			{
				"TLS CA file with no certificates",
				"mongodb://localhost/?tlsCAFile=testdata/ca-key.pem",
				&ClientOptions{
					Hosts: []string{"localhost"},
					err:   errors.New("the specified CA file does not contain any valid certificates"),
				},
			},
			{
				"TLS malformed CA file",
				"mongodb://localhost/?tlsCAFile=testdata/malformed-ca.pem",
				&ClientOptions{
					Hosts: []string{"localhost"},
					err:   errors.New("the specified CA file does not contain any valid certificates"),
				},
			},
		}

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				result := Client().ApplyURI(tc.uri)

				// Manually add the URI and ConnString to the test expectations to avoid adding them in each test
				// definition. The ConnString should only be recorded if there was no error while parsing.
				tc.result.uri = tc.uri
				cs, err := connstring.ParseAndValidate(tc.uri)
				if err == nil {
					tc.result.cs = &cs
				}

				if diff := cmp.Diff(
					tc.result, result,
					cmp.AllowUnexported(ClientOptions{}, readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
					cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
					cmp.Comparer(compareTLSConfig),
					cmp.Comparer(compareErrors),
					cmpopts.IgnoreFields(connstring.ConnString{}, "SSLClientCertificateKeyPassword"),
				); diff != "" {
					t.Errorf("URI did not apply correctly: (-want +got)\n%s", diff)
				}
			})
		}
	})
	t.Run("direct connection validation", func(t *testing.T) {
		t.Run("multiple hosts", func(t *testing.T) {
			expectedErr := errors.New("a direct connection cannot be made if multiple hosts are specified")

			testCases := []struct {
				name string
				opts *ClientOptions
			}{
				{"hosts in URI", Client().ApplyURI("mongodb://localhost,localhost2")},
				{"hosts in options", Client().SetHosts([]string{"localhost", "localhost2"})},
			}
			for _, tc := range testCases {
				t.Run(tc.name, func(t *testing.T) {
					err := tc.opts.SetDirect(true).Validate()
					assert.NotNil(t, err, "expected error, got nil")
					assert.Equal(t, expectedErr.Error(), err.Error(), "expected error %v, got %v", expectedErr, err)
				})
			}
		})
		t.Run("srv", func(t *testing.T) {
			expectedErr := errors.New("a direct connection cannot be made if an SRV URI is used")
			// Use a non-SRV URI and manually set the scheme because using an SRV URI would force an SRV lookup.
			opts := Client().ApplyURI("mongodb://localhost:27017")
			opts.cs.Scheme = connstring.SchemeMongoDBSRV

			err := opts.SetDirect(true).Validate()
			assert.NotNil(t, err, "expected error, got nil")
			assert.Equal(t, expectedErr.Error(), err.Error(), "expected error %v, got %v", expectedErr, err)
		})
	})
}
555
// createCertPool returns a new x509.CertPool populated with one certificate
// per path, added in the order given. Each certificate is obtained through
// loadCert, which reports any read/parse problem via t.
func createCertPool(t *testing.T, paths ...string) *x509.CertPool {
	t.Helper()

	certPool := x509.NewCertPool()
	for i := range paths {
		cert := loadCert(t, paths[i])
		certPool.AddCert(cert)
	}
	return certPool
}
565
// loadCert reads file, decodes the first PEM block it contains, and parses
// that block as an X.509 certificate.
//
// If the file contains no PEM block, the test fails immediately. A parse error
// is reported through assert.Nil.
func loadCert(t *testing.T, file string) *x509.Certificate {
	t.Helper()

	data := readFile(t, file)
	block, _ := pem.Decode(data)
	if block == nil {
		// pem.Decode returns a nil block when no PEM data is present (including
		// empty input); dereferencing it would panic instead of failing cleanly.
		t.Fatalf("no PEM block found in %s", file)
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	assert.Nil(t, err, "ParseCertificate error for %s: %v", file, err)
	return cert
}
575
// readFile returns the contents of path, reporting a read error through
// assert.Nil. It is marked as a test helper (like createCertPool and loadCert)
// so failures are attributed to the caller's line rather than to this function.
func readFile(t *testing.T, path string) []byte {
	t.Helper()

	data, err := ioutil.ReadFile(path)
	assert.Nil(t, err, "ReadFile error for %s: %v", path, err)
	return data
}
581
582type testDialer struct {
583	Num int
584}
585
586func (testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
587	return nil, nil
588}
589
// compareTLSConfig is a cmp.Comparer for *tls.Config values. It reports whether
// two configs agree on the fields the ApplyURI tests care about:
//   - whether RootCAs is set, and if so the ordered list of subjects,
//   - the number of client Certificates, and
//   - InsecureSkipVerify.
//
// All other fields are ignored.
//
// NOTE(review): a nil config is treated as matching any config, including a
// non-nil one. Some ApplyURI expectations appear to rely on this, for example
// the "TLS Username" case, whose URI sets ssl=true but whose expected options
// carry no TLS config. TODO: tighten this once those expectations are updated.
func compareTLSConfig(cfg1, cfg2 *tls.Config) bool {
	if cfg1 == nil || cfg2 == nil {
		// Both nil: trivially equal. Exactly one nil: lenient match (see NOTE above).
		return true
	}

	// A CA pool on one side but not the other is a mismatch. This must compare
	// cfg1 against cfg2; a nil pool on cfg1 would otherwise skip the subject
	// comparison entirely.
	if (cfg1.RootCAs == nil) != (cfg2.RootCAs == nil) {
		return false
	}

	if cfg1.RootCAs != nil {
		subjects1 := cfg1.RootCAs.Subjects()
		subjects2 := cfg2.RootCAs.Subjects()
		if len(subjects1) != len(subjects2) {
			return false
		}
		for i := range subjects1 {
			if !bytes.Equal(subjects1[i], subjects2[i]) {
				return false
			}
		}
	}

	return len(cfg1.Certificates) == len(cfg2.Certificates) &&
		cfg1.InsecureSkipVerify == cfg2.InsecureSkipVerify
}
627
// compareErrors is a cmp.Comparer for error values.
//
// Two nil errors are equal, and a nil error never equals a non-nil one. When
// both errors are *os.PathError, only Op and Path are compared and the wrapped
// Err is ignored. Any other pair is equal exactly when the messages match.
func compareErrors(err1, err2 error) bool {
	if err1 == nil || err2 == nil {
		return err1 == nil && err2 == nil
	}

	pathErr1, isPath1 := err1.(*os.PathError)
	pathErr2, isPath2 := err2.(*os.PathError)
	if isPath1 && isPath2 {
		return pathErr1.Op == pathErr2.Op && pathErr1.Path == pathErr2.Path
	}

	return err1.Error() == err2.Error()
}
649