1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package auth_test
8
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"path"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
	testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
)
21
type (
	// credential is the credential a spec test case expects the URI to
	// produce. Password is a pointer so that an absent (null) password can be
	// told apart from an empty one; MechProps holds the raw
	// "mechanism_properties" object exactly as decoded from JSON.
	credential struct {
		Username  string
		Password  *string
		Source    string
		Mechanism string
		MechProps map[string]interface{} `json:"mechanism_properties"`
	}

	// testCase is a single entry of an auth spec file: a connection URI,
	// whether it is expected to validate, and the expected credential
	// (nil when the URI should yield no credential at all).
	testCase struct {
		Description string
		URI         string
		Valid       bool
		Credential  *credential
	}

	// testContainer is the top-level shape of an auth spec JSON file.
	testContainer struct {
		Tests []testCase
	}
)
40
// authTestsDir is the directory containing the auth spec JSON test files.
// It is a relative path, resolved against the working directory of the test
// binary (the package directory under `go test`).
//
// Note: the spec test that exercised the deprecated gssapiServiceName
// property has been removed from data/auth/auth_tests.json.
const (
	authTestsDir = "../../../../data/auth/"
)
43
44func runTestsInFile(t *testing.T, dirname string, filename string) {
45	filepath := path.Join(dirname, filename)
46	content, err := ioutil.ReadFile(filepath)
47	require.NoError(t, err)
48
49	var container testContainer
50	require.NoError(t, json.Unmarshal(content, &container))
51
52	// Remove ".json" from filename.
53	filename = filename[:len(filename)-5]
54
55	for _, testCase := range container.Tests {
56		runTest(t, filename, &testCase)
57	}
58}
59
60func runTest(t *testing.T, filename string, test *testCase) {
61	t.Run(test.Description, func(t *testing.T) {
62		opts := options.Client().ApplyURI(test.URI)
63		if test.Valid {
64			require.NoError(t, opts.Validate())
65		} else {
66			require.Error(t, opts.Validate())
67			return
68		}
69
70		if test.Credential == nil {
71			require.Nil(t, opts.Auth)
72			return
73		}
74		require.NotNil(t, opts.Auth)
75		require.Equal(t, test.Credential.Username, opts.Auth.Username)
76
77		if test.Credential.Password == nil {
78			require.False(t, opts.Auth.PasswordSet)
79		} else {
80			require.True(t, opts.Auth.PasswordSet)
81			require.Equal(t, *test.Credential.Password, opts.Auth.Password)
82		}
83
84		require.Equal(t, test.Credential.Source, opts.Auth.AuthSource)
85
86		require.Equal(t, test.Credential.Mechanism, opts.Auth.AuthMechanism)
87
88		if len(test.Credential.MechProps) > 0 {
89			require.Equal(t, mapInterfaceToString(test.Credential.MechProps), opts.Auth.AuthMechanismProperties)
90		} else {
91			require.Equal(t, 0, len(opts.Auth.AuthMechanismProperties))
92		}
93	})
94}
95
// mapInterfaceToString returns a new map with the same keys as m, where each
// value is rendered with fmt.Sprint (e.g. a bool true becomes "true"). A nil
// or empty m yields an empty, non-nil map.
func mapInterfaceToString(m map[string]interface{}) map[string]string {
	// The output has exactly len(m) entries, so size it up front.
	out := make(map[string]string, len(m))
	for key, value := range m {
		out[key] = fmt.Sprint(value)
	}
	return out
}
106
107func verifyMechProperties(t *testing.T, cs connstring.ConnString, mechProps map[string]interface{}) {
108	// Check that all options are present.
109	for key, value := range mechProps {
110		require.Equal(t, value, cs.AuthMechanismProperties[key])
111	}
112}
113
114// Test case for all connection string spec tests.
115func TestAuthSpec(t *testing.T) {
116	for _, file := range testhelpers.FindJSONFilesInDir(t, authTestsDir) {
117		runTestsInFile(t, authTestsDir, file)
118	}
119}
120