1package mssql 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "strings" 9 10 _ "github.com/denisenkom/go-mssqldb" 11 multierror "github.com/hashicorp/go-multierror" 12 dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" 13 "github.com/hashicorp/vault/sdk/database/helper/connutil" 14 "github.com/hashicorp/vault/sdk/database/helper/dbutil" 15 "github.com/hashicorp/vault/sdk/helper/dbtxn" 16 "github.com/hashicorp/vault/sdk/helper/strutil" 17 "github.com/hashicorp/vault/sdk/helper/template" 18) 19 20const ( 21 msSQLTypeName = "mssql" 22 23 defaultUserNameTemplate = `{{ printf "v-%s-%s-%s-%s" (.DisplayName | truncate 20) (.RoleName | truncate 20) (random 20) (unix_time) | truncate 128 }}` 24) 25 26var _ dbplugin.Database = &MSSQL{} 27 28// MSSQL is an implementation of Database interface 29type MSSQL struct { 30 *connutil.SQLConnectionProducer 31 32 usernameProducer template.StringTemplate 33} 34 35func New() (interface{}, error) { 36 db := new() 37 // Wrap the plugin with middleware to sanitize errors 38 dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) 39 40 return dbType, nil 41} 42 43func new() *MSSQL { 44 connProducer := &connutil.SQLConnectionProducer{} 45 connProducer.Type = msSQLTypeName 46 47 return &MSSQL{ 48 SQLConnectionProducer: connProducer, 49 } 50} 51 52// Type returns the TypeName for this backend 53func (m *MSSQL) Type() (string, error) { 54 return msSQLTypeName, nil 55} 56 57func (m *MSSQL) secretValues() map[string]string { 58 return map[string]string{ 59 m.Password: "[password]", 60 } 61} 62 63func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) { 64 db, err := m.Connection(ctx) 65 if err != nil { 66 return nil, err 67 } 68 69 return db.(*sql.DB), nil 70} 71 72func (m *MSSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) { 73 newConf, err := m.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) 74 if err != nil { 75 return dbplugin.InitializeResponse{}, err 76 } 77 78 usernameTemplate, err := strutil.GetString(req.Config, "username_template") 79 if err != nil { 80 return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err) 81 } 82 if usernameTemplate == "" { 83 usernameTemplate = defaultUserNameTemplate 84 } 85 86 up, err := template.NewTemplate(template.Template(usernameTemplate)) 87 if err != nil { 88 return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err) 89 } 90 m.usernameProducer = up 91 92 _, err = m.usernameProducer.Generate(dbplugin.UsernameMetadata{}) 93 if err != nil { 94 return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template - did you reference a field that isn't available? : %w", err) 95 } 96 97 resp := dbplugin.InitializeResponse{ 98 Config: newConf, 99 } 100 return resp, nil 101} 102 103// NewUser generates the username/password on the underlying MSSQL secret backend as instructed by 104// the statements provided. 105func (m *MSSQL) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) { 106 m.Lock() 107 defer m.Unlock() 108 109 db, err := m.getConnection(ctx) 110 if err != nil { 111 return dbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err) 112 } 113 114 if len(req.Statements.Commands) == 0 { 115 return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement 116 } 117 118 username, err := m.usernameProducer.Generate(req.UsernameConfig) 119 if err != nil { 120 return dbplugin.NewUserResponse{}, err 121 } 122 123 expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700") 124 125 tx, err := db.BeginTx(ctx, nil) 126 if err != nil { 127 return dbplugin.NewUserResponse{}, err 128 } 129 defer tx.Rollback() 130 131 for _, stmt := range req.Statements.Commands { 132 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 133 query = strings.TrimSpace(query) 134 if len(query) == 0 { 135 continue 136 } 137 138 m := map[string]string{ 139 "name": username, 140 "password": req.Password, 141 "expiration": expirationStr, 142 } 143 144 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 145 return dbplugin.NewUserResponse{}, err 146 } 147 } 148 } 149 150 if err := tx.Commit(); err != nil { 151 return dbplugin.NewUserResponse{}, err 152 } 153 154 resp := dbplugin.NewUserResponse{ 155 Username: username, 156 } 157 158 return resp, nil 159} 160 161// DeleteUser attempts to drop the specified user. It will first attempt to disable login, 162// then kill pending connections from that user, and finally drop the user and login from the 163// database instance. 164func (m *MSSQL) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) { 165 if len(req.Statements.Commands) == 0 { 166 err := m.revokeUserDefault(ctx, req.Username) 167 return dbplugin.DeleteUserResponse{}, err 168 } 169 170 db, err := m.getConnection(ctx) 171 if err != nil { 172 return dbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err) 173 } 174 175 merr := &multierror.Error{} 176 177 // Execute each query 178 for _, stmt := range req.Statements.Commands { 179 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 180 query = strings.TrimSpace(query) 181 if len(query) == 0 { 182 continue 183 } 184 185 m := map[string]string{ 186 "name": req.Username, 187 } 188 if err := dbtxn.ExecuteDBQuery(ctx, db, m, query); err != nil { 189 merr = multierror.Append(merr, err) 190 } 191 } 192 } 193 194 return dbplugin.DeleteUserResponse{}, merr.ErrorOrNil() 195} 196 197func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { 198 // Get connection 199 db, err := m.getConnection(ctx) 200 if err != nil { 201 return err 202 } 203 204 // First disable server login 205 disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) 206 if err != nil { 207 return err 208 } 209 defer disableStmt.Close() 210 if _, err := disableStmt.ExecContext(ctx); err != nil { 211 return err 212 } 213 214 // Query for sessions for the login so that we can kill any outstanding 215 // sessions. There cannot be any active sessions before we drop the logins 216 // This isn't done in a transaction because even if we fail along the way, 217 // we want to remove as much access as possible 218 sessionStmt, err := db.PrepareContext(ctx, 219 "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = @p1;") 220 if err != nil { 221 return err 222 } 223 defer sessionStmt.Close() 224 225 sessionRows, err := sessionStmt.QueryContext(ctx, username) 226 if err != nil { 227 return err 228 } 229 defer sessionRows.Close() 230 231 var revokeStmts []string 232 for sessionRows.Next() { 233 var sessionID int 234 err = sessionRows.Scan(&sessionID) 235 if err != nil { 236 return err 237 } 238 revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) 239 } 240 241 // Query for database users using undocumented stored procedure for now since 242 // it is the easiest way to get this information; 243 // we need to drop the database users before we can drop the login and the role 244 // This isn't done in a transaction because even if we fail along the way, 245 // we want to remove as much access as possible 246 stmt, err := db.PrepareContext(ctx, "EXEC master.dbo.sp_msloginmappings @p1;") 247 if err != nil { 248 return err 249 } 250 defer stmt.Close() 251 252 rows, err := stmt.QueryContext(ctx, username) 253 if err != nil { 254 return err 255 } 256 defer rows.Close() 257 258 for rows.Next() { 259 var loginName, dbName, qUsername, aliasName sql.NullString 260 err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) 261 if err != nil { 262 return err 263 } 264 if !dbName.Valid { 265 continue 266 } 267 revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName.String, username, username)) 268 } 269 270 // we do not stop on error, as we want to remove as 271 // many permissions as possible right now 272 var lastStmtError error 273 for _, query := range revokeStmts { 274 if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { 275 lastStmtError = err 276 } 277 } 278 279 // can't drop if not all database users are dropped 280 if rows.Err() != nil { 281 return fmt.Errorf("could not generate sql statements for all rows: %w", rows.Err()) 282 } 283 if lastStmtError != nil { 284 return fmt.Errorf("could not perform all sql statements: %w", lastStmtError) 285 } 286 287 // Drop this login 288 stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username)) 289 if err != nil { 290 return err 291 } 292 defer stmt.Close() 293 if _, err := stmt.ExecContext(ctx); err != nil { 294 return err 295 } 296 297 return nil 298} 299 300func (m *MSSQL) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) { 301 if req.Password == nil && req.Expiration == nil { 302 return dbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested") 303 } 304 if req.Password != nil { 305 err := m.updateUserPass(ctx, req.Username, req.Password) 306 return dbplugin.UpdateUserResponse{}, err 307 } 308 // Expiration is a no-op 309 return dbplugin.UpdateUserResponse{}, nil 310} 311 312func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error { 313 stmts := changePass.Statements.Commands 314 if len(stmts) == 0 { 315 stmts = []string{alterLoginSQL} 316 } 317 318 password := changePass.NewPassword 319 320 if username == "" || password == "" { 321 return errors.New("must provide both username and password") 322 } 323 324 m.Lock() 325 defer m.Unlock() 326 327 db, err := m.getConnection(ctx) 328 if err != nil { 329 return err 330 } 331 332 var exists bool 333 334 err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists) 335 336 if err != nil && err != sql.ErrNoRows { 337 return err 338 } 339 340 tx, err := db.BeginTx(ctx, nil) 341 if err != nil { 342 return err 343 } 344 345 defer func() { 346 _ = tx.Rollback() 347 }() 348 349 for _, stmt := range stmts { 350 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 351 query = strings.TrimSpace(query) 352 if len(query) == 0 { 353 continue 354 } 355 356 m := map[string]string{ 357 "name": username, 358 "username": username, 359 "password": password, 360 } 361 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 362 return fmt.Errorf("failed to execute query: %w", err) 363 } 364 } 365 } 366 367 if err := tx.Commit(); err != nil { 368 return fmt.Errorf("failed to commit transaction: %w", err) 369 } 370 371 return nil 372} 373 374const dropUserSQL = ` 375USE [%s] 376IF EXISTS 377 (SELECT name 378 FROM sys.database_principals 379 WHERE name = N'%s') 380BEGIN 381 DROP USER [%s] 382END 383` 384 385const dropLoginSQL = ` 386IF EXISTS 387 (SELECT name 388 FROM master.sys.server_principals 389 WHERE name = N'%s') 390BEGIN 391 DROP LOGIN [%s] 392END 393` 394 395const alterLoginSQL = ` 396ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}' 397` 398