1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package integration
8
9import (
10	"context"
11	"io/ioutil"
12	"os"
13	"path"
14	"runtime"
15	"strings"
16	"testing"
17
18	"go.mongodb.org/mongo-driver/bson"
19	"go.mongodb.org/mongo-driver/internal/testutil/assert"
20	"go.mongodb.org/mongo-driver/mongo/description"
21	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
22	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
23	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
24)
25
// seedlistDiscoveryTestsDir is the relative path of the directory that holds the
// initial DNS seedlist discovery spec test JSON files run by this file.
const seedlistDiscoveryTestsDir = "../../data/initial-dns-seedlist-discovery"
29
30type seedlistTest struct {
31	URI     string   `bson:"uri"`
32	Seeds   []string `bson:"seeds"`
33	Hosts   []string `bson:"hosts"`
34	Error   bool     `bson:"error"`
35	Options bson.Raw `bson:"options"`
36}
37
38func TestInitialDNSSeedlistDiscoverySpec(t *testing.T) {
39	mtOpts := mtest.NewOptions().Topologies(mtest.ReplicaSet).CreateClient(false)
40	mt := mtest.New(t, mtOpts)
41	defer mt.Close()
42
43	for _, file := range jsonFilesInDir(mt, seedlistDiscoveryTestsDir) {
44		mt.RunOpts(file, noClientOpts, func(mt *mtest.T) {
45			runSeedlistDiscoveryTest(mt, path.Join(seedlistDiscoveryTestsDir, file))
46		})
47	}
48}
49
50func runSeedlistDiscoveryTest(mt *mtest.T, file string) {
51	content, err := ioutil.ReadFile(file)
52	assert.Nil(mt, err, "ReadFile error for %v: %v", file, err)
53
54	var test seedlistTest
55	err = bson.UnmarshalExtJSONWithRegistry(specTestRegistry, content, false, &test)
56	assert.Nil(mt, err, "UnmarshalExtJSONWithRegistry error: %v", err)
57
58	if runtime.GOOS == "windows" && strings.HasSuffix(file, "/two-txt-records.json") {
59		mt.Skip("skipping to avoid windows multiple TXT record lookup bug")
60	}
61	if strings.HasPrefix(runtime.Version(), "go1.11") && strings.HasSuffix(file, "/one-txt-record-multiple-strings.json") {
62		mt.Skip("skipping to avoid go1.11 problem with multiple strings in one TXT record")
63	}
64
65	cs, err := connstring.ParseAndValidate(test.URI)
66	if test.Error {
67		assert.NotNil(mt, err, "expected URI parsing error, got nil")
68		return
69	}
70
71	assert.Nil(mt, err, "ParseAndValidate error: %v", err)
72	assert.Equal(mt, connstring.SchemeMongoDBSRV, cs.Scheme,
73		"expected scheme %v, got %v", connstring.SchemeMongoDBSRV, cs.Scheme)
74
75	// DNS records may be out of order from the test file's ordering
76	expectedSeedlist := buildSet(test.Seeds)
77	actualSeedlist := buildSet(cs.Hosts)
78	assert.Equal(mt, expectedSeedlist, actualSeedlist, "expected seedlist %v, got %v", expectedSeedlist, actualSeedlist)
79	verifyConnstringOptions(mt, test.Options, cs)
80	setSSLSettings(mt, &cs, test)
81
82	// make a topology from the options
83	topo, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
84	assert.Nil(mt, err, "topology.New error: %v", err)
85	err = topo.Connect()
86	assert.Nil(mt, err, "topology.Connect error: %v", err)
87	defer func() { _ = topo.Disconnect(mtest.Background) }()
88
89	for _, host := range test.Hosts {
90		_, err := getServerByAddress(host, topo)
91		assert.Nil(mt, err, "did not find host %v", host)
92	}
93}
94
// buildSet returns a set holding every string in list; duplicates collapse to one
// entry. The result is always a non-nil map (empty for an empty or nil list), so
// two results compare equal whenever they contain the same members.
func buildSet(list []string) map[string]struct{} {
	// Pre-size: at most len(list) distinct entries will be inserted.
	set := make(map[string]struct{}, len(list))
	for _, s := range list {
		set[s] = struct{}{}
	}
	return set
}
102
103func verifyConnstringOptions(mt *mtest.T, expected bson.Raw, cs connstring.ConnString) {
104	mt.Helper()
105
106	elems, _ := expected.Elements()
107	for _, elem := range elems {
108		key := elem.Key()
109		opt := elem.Value()
110
111		switch key {
112		case "replicaSet":
113			rs := opt.StringValue()
114			assert.Equal(mt, rs, cs.ReplicaSet, "expected replicaSet value %v, got %v", rs, cs.ReplicaSet)
115		case "ssl":
116			ssl := opt.Boolean()
117			assert.Equal(mt, ssl, cs.SSL, "expected ssl value %v, got %v", ssl, cs.SSL)
118		case "authSource":
119			source := opt.StringValue()
120			assert.Equal(mt, source, cs.AuthSource, "expected auth source value %v, got %v", source, cs.AuthSource)
121		case "directConnection":
122			dc := opt.Boolean()
123			assert.True(mt, cs.DirectConnectionSet, "expected cs.DirectConnectionSet to be true, got false")
124			assert.Equal(mt, dc, cs.DirectConnection, "expected cs.DirectConnection to be %v, got %v", dc, cs.DirectConnection)
125		default:
126			mt.Fatalf("unrecognized connstring option %v", key)
127		}
128	}
129}
130
131// Because the Go driver tests can be run either against a server with SSL enabled or without, a
132// number of configurations have to be checked to ensure that the SRV tests are run properly.
133//
134// First, the "ssl" option in the JSON test description has to be checked. If this option is not
135// present, we assume that the test will assert an error, so we proceed with the test as normal.
136// If the option is false, then we skip the test if the server is running with SSL enabled.
137// If the option is true, then we skip the test if the server is running without SSL enabled; if
138// the server is running with SSL enabled, then we manually set the necessary SSL options in the
139// connection string.
140func setSSLSettings(mt *mtest.T, cs *connstring.ConnString, test seedlistTest) {
141	ssl, err := test.Options.LookupErr("ssl")
142	if err != nil {
143		// No "ssl" option is specified
144		return
145	}
146	testCaseExpectsSSL := ssl.Boolean()
147	envSSL := os.Getenv("SSL") == "ssl"
148
149	// Skip non-SSL tests if the server is running with SSL.
150	if !testCaseExpectsSSL && envSSL {
151		mt.Skip("skipping test that does not expect ssl in an ssl environment")
152	}
153
154	// Skip SSL tests if the server is running without SSL.
155	if testCaseExpectsSSL && !envSSL {
156		mt.Skip("skipping test that expectes ssl in a non-ssl environment")
157	}
158
159	// If SSL tests are running, set the CA file.
160	if testCaseExpectsSSL && envSSL {
161		cs.SSLInsecure = true
162	}
163}
164
165func getServerByAddress(address string, topo *topology.Topology) (description.Server, error) {
166	selectByName := description.ServerSelectorFunc(func(_ description.Topology, servers []description.Server) ([]description.Server, error) {
167		for _, s := range servers {
168			if s.Addr.String() == address {
169				return []description.Server{s}, nil
170			}
171		}
172		return []description.Server{}, nil
173	})
174
175	selectedServer, err := topo.SelectServerLegacy(context.Background(), selectByName)
176	if err != nil {
177		return description.Server{}, err
178	}
179	return selectedServer.Server.Description(), nil
180}
181