1package mssql
2
import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"reflect"
	"regexp"
	"strings"
	"testing"
	"time"

	mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
	dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
	"github.com/hashicorp/vault/sdk/helper/dbtxn"
)
18
19func TestInitialize(t *testing.T) {
20	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
21	defer cleanup()
22
23	type testCase struct {
24		req dbplugin.InitializeRequest
25	}
26
27	tests := map[string]testCase{
28		"happy path": {
29			req: dbplugin.InitializeRequest{
30				Config: map[string]interface{}{
31					"connection_url": connURL,
32				},
33				VerifyConnection: true,
34			},
35		},
36		"max_open_connections set": {
37			dbplugin.InitializeRequest{
38				Config: map[string]interface{}{
39					"connection_url":       connURL,
40					"max_open_connections": "5",
41				},
42				VerifyConnection: true,
43			},
44		},
45	}
46
47	for name, test := range tests {
48		t.Run(name, func(t *testing.T) {
49			db := new()
50			dbtesting.AssertInitialize(t, db, test.req)
51			defer dbtesting.AssertClose(t, db)
52
53			if !db.Initialized {
54				t.Fatal("Database should be initialized")
55			}
56		})
57	}
58}
59
// TestNewUser creates users through the plugin against a live MSSQL container.
// For each case it checks whether NewUser errors, that the returned username
// matches usernameRegex, that the response carries nothing but Username, and
// finally runs assertUser against the returned username and requested password.
//
// usernameTemplate is passed through verbatim as the plugin's
// "username_template" config value, including when it is empty.
func TestNewUser(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	type testCase struct {
		usernameTemplate string
		req              dbplugin.NewUserRequest
		usernameRegex    string
		expectErr        bool
		assertUser       func(t testing.TB, connURL, username, password string)
	}

	tests := map[string]testCase{
		"no creation statements": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{},
				Password:   "AG4qagho-dsvZ",
				Expiration: time.Now().Add(1 * time.Second),
			},
			usernameRegex: "^$",
			expectErr:     true,
			assertUser:    assertCredsDoNotExist,
		},
		"with creation statements": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{testMSSQLRole},
				},
				Password:   "AG4qagho-dsvZ",
				Expiration: time.Now().Add(1 * time.Second),
			},
			usernameRegex: "^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$",
			expectErr:     false,
			assertUser:    assertCredsExist,
		},
		"custom username template": {
			usernameTemplate: "{{random 10}}_{{.RoleName}}.{{.DisplayName | sha256}}",
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "tokenwithlotsofextracharactershere",
					RoleName:    "myrolenamewithlotsofextracharacters",
				},
				Statements: dbplugin.Statements{
					Commands: []string{testMSSQLRole},
				},
				Password:   "AG4qagho-dsvZ",
				Expiration: time.Now().Add(1 * time.Second),
			},
			// The template emits a literal "." between role name and hash;
			// escape it so the regex does not accept any character there.
			usernameRegex: `^[a-zA-Z0-9]{10}_myrolenamewithlotsofextracharacters\.80d15d22dba29ddbd4994f8009b5ff7b17922c267eb49fb805a9488bd55d11f9$`,
			expectErr:     false,
			assertUser:    assertCredsExist,
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			usernameRe, err := regexp.Compile(test.usernameRegex)
			if err != nil {
				t.Fatalf("failed to compile username regex %q: %s", test.usernameRegex, err)
			}

			initReq := dbplugin.InitializeRequest{
				Config: map[string]interface{}{
					"connection_url":    connURL,
					"username_template": test.usernameTemplate,
				},
				VerifyConnection: true,
			}

			db := new()
			dbtesting.AssertInitialize(t, db, initReq)
			defer dbtesting.AssertClose(t, db)

			createResp, err := db.NewUser(context.Background(), test.req)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			if !usernameRe.MatchString(createResp.Username) {
				t.Fatalf("Generated username %q did not match regex %q", createResp.Username, test.usernameRegex)
			}

			// Protect against future fields that aren't specified
			expectedResp := dbplugin.NewUserResponse{
				Username: createResp.Username,
			}
			if !reflect.DeepEqual(createResp, expectedResp) {
				t.Fatalf("Fields missing from expected response: Actual: %#v", createResp)
			}

			test.assertUser(t, connURL, createResp.Username, test.req.Password)
		})
	}
}
165
// TestUpdateUser_password rotates the password of a pre-created login through
// the plugin and checks, per case, whether UpdateUser errors, that the
// response is empty, and which password is valid afterwards (the new one on
// success, the initial one when the update is expected to fail).
func TestUpdateUser_password(t *testing.T) {
	const (
		dbUser       = "vaultuser"
		initPassword = "p4$sw0rd"
		newPassword  = "N90gkKLy8$angf"
	)

	// passwordChange builds an UpdateUserRequest for dbUser that sets the
	// given password using the given rotation statements (none if omitted).
	passwordChange := func(password string, commands ...string) dbplugin.UpdateUserRequest {
		return dbplugin.UpdateUserRequest{
			Username: dbUser,
			Password: &dbplugin.ChangePassword{
				NewPassword: password,
				Statements:  dbplugin.Statements{Commands: commands},
			},
		}
	}

	cases := map[string]struct {
		req              dbplugin.UpdateUserRequest
		expectErr        bool
		expectedPassword string
	}{
		"missing password": {
			req:              passwordChange(""),
			expectErr:        true,
			expectedPassword: initPassword,
		},
		"empty rotation statements": {
			req:              passwordChange(newPassword),
			expectedPassword: newPassword,
		},
		"username rotation": {
			req:              passwordChange(newPassword, "ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}'"),
			expectedPassword: newPassword,
		},
		"bad statements": {
			req:              passwordChange(newPassword, "ahosh98asjdffs"),
			expectErr:        true,
			expectedPassword: initPassword,
		},
	}

	for caseName, tc := range cases {
		t.Run(caseName, func(t *testing.T) {
			// A fresh container per case, presumably for isolation: every case
			// creates the same login and may change its password.
			cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
			defer cleanup()

			plugin := new()
			dbtesting.AssertInitialize(t, plugin, dbplugin.InitializeRequest{
				Config:           map[string]interface{}{"connection_url": connURL},
				VerifyConnection: true,
			})
			defer dbtesting.AssertClose(t, plugin)

			createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin)
			assertCredsExist(t, connURL, dbUser, initPassword)

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			resp, err := plugin.UpdateUser(ctx, tc.req)
			switch {
			case tc.expectErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.expectErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			// Protect against future fields that aren't specified
			if !reflect.DeepEqual(resp, dbplugin.UpdateUserResponse{}) {
				t.Fatalf("Fields missing from expected response: Actual: %#v", resp)
			}

			assertCredsExist(t, connURL, dbUser, tc.expectedPassword)
		})
	}
}
271
272func TestDeleteUser(t *testing.T) {
273	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
274	defer cleanup()
275
276	dbUser := "vaultuser"
277	initPassword := "p4$sw0rd"
278
279	initReq := dbplugin.InitializeRequest{
280		Config: map[string]interface{}{
281			"connection_url": connURL,
282		},
283		VerifyConnection: true,
284	}
285
286	db := new()
287	dbtesting.AssertInitialize(t, db, initReq)
288	defer dbtesting.AssertClose(t, db)
289
290	createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin)
291
292	assertCredsExist(t, connURL, dbUser, initPassword)
293
294	deleteReq := dbplugin.DeleteUserRequest{
295		Username: dbUser,
296	}
297
298	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
299	defer cancel()
300	deleteResp, err := db.DeleteUser(ctx, deleteReq)
301	if err != nil {
302		t.Fatalf("Failed to delete user: %s", err)
303	}
304
305	// Protect against future fields that aren't specified
306	expectedResp := dbplugin.DeleteUserResponse{}
307	if !reflect.DeepEqual(deleteResp, expectedResp) {
308		t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp)
309	}
310
311	assertCredsDoNotExist(t, connURL, dbUser, initPassword)
312}
313
// assertCredsExist fails the test immediately unless username/password can
// log in (open and ping a connection) to the server behind connURL.
func assertCredsExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	if loginErr := testCredsExist(connURL, username, password); loginErr != nil {
		t.Fatalf("Unable to log in as %q: %s", username, loginErr)
	}
}
321
// assertCredsDoNotExist fails the test immediately if username/password can
// still log in to the server behind connURL. Any error from the login attempt
// is accepted as proof that the credentials do not work.
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	if loginErr := testCredsExist(connURL, username, password); loginErr != nil {
		return
	}
	t.Fatalf("Able to log in when it shouldn't")
}
329
// testCredsExist attempts to log in to the server addressed by connURL as
// username/password. It returns nil when a connection can be opened and
// pinged, and otherwise the URL-construction, open, or ping error.
func testCredsExist(connURL, username, password string) error {
	credsURL, err := connURLWithCreds(connURL, username, password)
	if err != nil {
		return err
	}
	db, err := sql.Open("mssql", credsURL)
	if err != nil {
		return err
	}
	defer db.Close()
	return db.Ping()
}

// connURLWithCreds rebuilds connURL as a sqlserver:// URL whose userinfo is
// replaced by username/password, keeping everything after the first "@".
// Credentials are URL-escaped so characters such as '@', '/', '?' or ':' in a
// password cannot be mistaken for URL structure.
func connURLWithCreds(connURL, username, password string) (string, error) {
	parts := strings.SplitN(connURL, "@", 2)
	if len(parts) != 2 {
		// Deliberately omit connURL from the message: it embeds admin credentials.
		return "", fmt.Errorf("connection URL has no %q separating credentials from host", "@")
	}
	userInfo := url.UserPassword(username, password).String()
	return fmt.Sprintf("sqlserver://%s@%s", userInfo, parts[1]), nil
}
341
// createTestMSSQLUser connects to connURL and executes query inside a single
// committed transaction via dbtxn.ExecuteTxQuery. It passes a map with
// "name" -> username and "password" -> password, presumably to fill the
// {{name}}/{{password}} placeholders used by the SQL templates in this file.
// Any failure aborts the test immediately.
func createTestMSSQLUser(t *testing.T, connURL string, username, password, query string) {
	t.Helper()

	db, err := sql.Open("mssql", connURL)
	if err != nil {
		t.Fatal(err)
	}
	// Registered only after the error check: on failure db is nil, and the
	// deferred Close (run by t.Fatal's Goexit) would panic.
	defer db.Close()

	// Start a transaction
	ctx := context.Background()
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Safe to defer unconditionally: after a successful Commit, Rollback just
	// returns sql.ErrTxDone, which is intentionally ignored.
	defer func() {
		_ = tx.Rollback()
	}()

	m := map[string]string{
		"name":     username,
		"password": password,
	}
	if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
		t.Fatal(err)
	}
	// Commit the transaction
	if err := tx.Commit(); err != nil {
		t.Fatal(err)
	}
}
371
// SQL templates for the MSSQL tests. {{name}} and {{password}} are
// placeholders to be filled in before execution.
const (
	// testMSSQLRole creates a server login, a database user mapped to it, and
	// grants SELECT/INSERT/UPDATE/DELETE on the dbo schema.
	testMSSQLRole = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`

	// testMSSQLDrop drops the database user and then its server login.
	// NOTE(review): not referenced in the visible part of this file.
	testMSSQLDrop = `
DROP USER [{{name}}];
DROP LOGIN [{{name}}];
`

	// testMSSQLLogin creates only a server login, with no database user.
	testMSSQLLogin = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
`
)
385