1package mssql
2
3import (
4	"context"
5	"database/sql"
6	"errors"
7	"fmt"
8	"strings"
9
10	_ "github.com/denisenkom/go-mssqldb"
11	multierror "github.com/hashicorp/go-multierror"
12	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
13	"github.com/hashicorp/vault/sdk/database/helper/connutil"
14	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
15	"github.com/hashicorp/vault/sdk/helper/dbtxn"
16	"github.com/hashicorp/vault/sdk/helper/strutil"
17	"github.com/hashicorp/vault/sdk/helper/template"
18)
19
// msSQLTypeName is the database type name returned by Type and stored on the
// embedded connection producer.
const msSQLTypeName = "mssql"

// defaultUserNameTemplate is used when the plugin config does not supply a
// "username_template". It renders a name of the form
// "v-<display name>-<role name>-<random>-<unix time>", where the display and
// role names are each truncated to 20 characters, and the whole result is
// truncated to 128 characters.
const defaultUserNameTemplate = `{{ printf "v-%s-%s-%s-%s" (.DisplayName | truncate 20) (.RoleName | truncate 20) (random 20) (unix_time) | truncate 128 }}`
25
26var _ dbplugin.Database = &MSSQL{}
27
28// MSSQL is an implementation of Database interface
29type MSSQL struct {
30	*connutil.SQLConnectionProducer
31
32	usernameProducer template.StringTemplate
33}
34
35func New() (interface{}, error) {
36	db := new()
37	// Wrap the plugin with middleware to sanitize errors
38	dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
39
40	return dbType, nil
41}
42
43func new() *MSSQL {
44	connProducer := &connutil.SQLConnectionProducer{}
45	connProducer.Type = msSQLTypeName
46
47	return &MSSQL{
48		SQLConnectionProducer: connProducer,
49	}
50}
51
52// Type returns the TypeName for this backend
53func (m *MSSQL) Type() (string, error) {
54	return msSQLTypeName, nil
55}
56
57func (m *MSSQL) secretValues() map[string]string {
58	return map[string]string{
59		m.Password: "[password]",
60	}
61}
62
63func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
64	db, err := m.Connection(ctx)
65	if err != nil {
66		return nil, err
67	}
68
69	return db.(*sql.DB), nil
70}
71
72func (m *MSSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
73	newConf, err := m.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
74	if err != nil {
75		return dbplugin.InitializeResponse{}, err
76	}
77
78	usernameTemplate, err := strutil.GetString(req.Config, "username_template")
79	if err != nil {
80		return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err)
81	}
82	if usernameTemplate == "" {
83		usernameTemplate = defaultUserNameTemplate
84	}
85
86	up, err := template.NewTemplate(template.Template(usernameTemplate))
87	if err != nil {
88		return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err)
89	}
90	m.usernameProducer = up
91
92	_, err = m.usernameProducer.Generate(dbplugin.UsernameMetadata{})
93	if err != nil {
94		return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template - did you reference a field that isn't available? : %w", err)
95	}
96
97	resp := dbplugin.InitializeResponse{
98		Config: newConf,
99	}
100	return resp, nil
101}
102
103// NewUser generates the username/password on the underlying MSSQL secret backend as instructed by
104// the statements provided.
105func (m *MSSQL) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
106	m.Lock()
107	defer m.Unlock()
108
109	db, err := m.getConnection(ctx)
110	if err != nil {
111		return dbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
112	}
113
114	if len(req.Statements.Commands) == 0 {
115		return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
116	}
117
118	username, err := m.usernameProducer.Generate(req.UsernameConfig)
119	if err != nil {
120		return dbplugin.NewUserResponse{}, err
121	}
122
123	expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")
124
125	tx, err := db.BeginTx(ctx, nil)
126	if err != nil {
127		return dbplugin.NewUserResponse{}, err
128	}
129	defer tx.Rollback()
130
131	for _, stmt := range req.Statements.Commands {
132		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
133			query = strings.TrimSpace(query)
134			if len(query) == 0 {
135				continue
136			}
137
138			m := map[string]string{
139				"name":       username,
140				"password":   req.Password,
141				"expiration": expirationStr,
142			}
143
144			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
145				return dbplugin.NewUserResponse{}, err
146			}
147		}
148	}
149
150	if err := tx.Commit(); err != nil {
151		return dbplugin.NewUserResponse{}, err
152	}
153
154	resp := dbplugin.NewUserResponse{
155		Username: username,
156	}
157
158	return resp, nil
159}
160
161// DeleteUser attempts to drop the specified user. It will first attempt to disable login,
162// then kill pending connections from that user, and finally drop the user and login from the
163// database instance.
164func (m *MSSQL) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
165	if len(req.Statements.Commands) == 0 {
166		err := m.revokeUserDefault(ctx, req.Username)
167		return dbplugin.DeleteUserResponse{}, err
168	}
169
170	db, err := m.getConnection(ctx)
171	if err != nil {
172		return dbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
173	}
174
175	merr := &multierror.Error{}
176
177	// Execute each query
178	for _, stmt := range req.Statements.Commands {
179		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
180			query = strings.TrimSpace(query)
181			if len(query) == 0 {
182				continue
183			}
184
185			m := map[string]string{
186				"name": req.Username,
187			}
188			if err := dbtxn.ExecuteDBQuery(ctx, db, m, query); err != nil {
189				merr = multierror.Append(merr, err)
190			}
191		}
192	}
193
194	return dbplugin.DeleteUserResponse{}, merr.ErrorOrNil()
195}
196
197func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
198	// Get connection
199	db, err := m.getConnection(ctx)
200	if err != nil {
201		return err
202	}
203
204	// First disable server login
205	disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
206	if err != nil {
207		return err
208	}
209	defer disableStmt.Close()
210	if _, err := disableStmt.ExecContext(ctx); err != nil {
211		return err
212	}
213
214	// Query for sessions for the login so that we can kill any outstanding
215	// sessions.  There cannot be any active sessions before we drop the logins
216	// This isn't done in a transaction because even if we fail along the way,
217	// we want to remove as much access as possible
218	sessionStmt, err := db.PrepareContext(ctx,
219		"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = @p1;")
220	if err != nil {
221		return err
222	}
223	defer sessionStmt.Close()
224
225	sessionRows, err := sessionStmt.QueryContext(ctx, username)
226	if err != nil {
227		return err
228	}
229	defer sessionRows.Close()
230
231	var revokeStmts []string
232	for sessionRows.Next() {
233		var sessionID int
234		err = sessionRows.Scan(&sessionID)
235		if err != nil {
236			return err
237		}
238		revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
239	}
240
241	// Query for database users using undocumented stored procedure for now since
242	// it is the easiest way to get this information;
243	// we need to drop the database users before we can drop the login and the role
244	// This isn't done in a transaction because even if we fail along the way,
245	// we want to remove as much access as possible
246	stmt, err := db.PrepareContext(ctx, "EXEC master.dbo.sp_msloginmappings @p1;")
247	if err != nil {
248		return err
249	}
250	defer stmt.Close()
251
252	rows, err := stmt.QueryContext(ctx, username)
253	if err != nil {
254		return err
255	}
256	defer rows.Close()
257
258	for rows.Next() {
259		var loginName, dbName, qUsername, aliasName sql.NullString
260		err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
261		if err != nil {
262			return err
263		}
264		if !dbName.Valid {
265			continue
266		}
267		revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName.String, username, username))
268	}
269
270	// we do not stop on error, as we want to remove as
271	// many permissions as possible right now
272	var lastStmtError error
273	for _, query := range revokeStmts {
274		if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil {
275			lastStmtError = err
276		}
277	}
278
279	// can't drop if not all database users are dropped
280	if rows.Err() != nil {
281		return fmt.Errorf("could not generate sql statements for all rows: %w", rows.Err())
282	}
283	if lastStmtError != nil {
284		return fmt.Errorf("could not perform all sql statements: %w", lastStmtError)
285	}
286
287	// Drop this login
288	stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
289	if err != nil {
290		return err
291	}
292	defer stmt.Close()
293	if _, err := stmt.ExecContext(ctx); err != nil {
294		return err
295	}
296
297	return nil
298}
299
300func (m *MSSQL) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
301	if req.Password == nil && req.Expiration == nil {
302		return dbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
303	}
304	if req.Password != nil {
305		err := m.updateUserPass(ctx, req.Username, req.Password)
306		return dbplugin.UpdateUserResponse{}, err
307	}
308	// Expiration is a no-op
309	return dbplugin.UpdateUserResponse{}, nil
310}
311
312func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error {
313	stmts := changePass.Statements.Commands
314	if len(stmts) == 0 {
315		stmts = []string{alterLoginSQL}
316	}
317
318	password := changePass.NewPassword
319
320	if username == "" || password == "" {
321		return errors.New("must provide both username and password")
322	}
323
324	m.Lock()
325	defer m.Unlock()
326
327	db, err := m.getConnection(ctx)
328	if err != nil {
329		return err
330	}
331
332	var exists bool
333
334	err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists)
335
336	if err != nil && err != sql.ErrNoRows {
337		return err
338	}
339
340	tx, err := db.BeginTx(ctx, nil)
341	if err != nil {
342		return err
343	}
344
345	defer func() {
346		_ = tx.Rollback()
347	}()
348
349	for _, stmt := range stmts {
350		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
351			query = strings.TrimSpace(query)
352			if len(query) == 0 {
353				continue
354			}
355
356			m := map[string]string{
357				"name":     username,
358				"username": username,
359				"password": password,
360			}
361			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
362				return fmt.Errorf("failed to execute query: %w", err)
363			}
364		}
365	}
366
367	if err := tx.Commit(); err != nil {
368		return fmt.Errorf("failed to commit transaction: %w", err)
369	}
370
371	return nil
372}
373
// SQL templates used by the built-in revocation and password-rotation paths.
const (
	// dropUserSQL drops a database user if it exists. It is a fmt format
	// string taking, in order: the database name (inside [...]), the user name
	// for the existence check (inside N'...'), and the user name for DROP USER
	// (inside [...]).
	dropUserSQL = `
USE [%s]
IF EXISTS
  (SELECT name
   FROM sys.database_principals
   WHERE name = N'%s')
BEGIN
  DROP USER [%s]
END
`

	// dropLoginSQL drops a server login if it exists. It is a fmt format
	// string taking the login name for the existence check (inside N'...')
	// and then the login name for DROP LOGIN (inside [...]).
	dropLoginSQL = `
IF EXISTS
  (SELECT name
   FROM master.sys.server_principals
   WHERE name = N'%s')
BEGIN
  DROP LOGIN [%s]
END
`

	// alterLoginSQL is the default password-rotation statement. Its
	// {{username}} and {{password}} placeholders are filled in from the
	// values map passed to dbtxn.
	alterLoginSQL = `
ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}'
`
)
398