1package postgresql 2 3import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "strings" 9 "time" 10 11 "github.com/hashicorp/errwrap" 12 "github.com/hashicorp/vault/api" 13 "github.com/hashicorp/vault/sdk/database/dbplugin" 14 "github.com/hashicorp/vault/sdk/database/helper/connutil" 15 "github.com/hashicorp/vault/sdk/database/helper/credsutil" 16 "github.com/hashicorp/vault/sdk/database/helper/dbutil" 17 "github.com/hashicorp/vault/sdk/helper/dbtxn" 18 "github.com/hashicorp/vault/sdk/helper/strutil" 19 "github.com/lib/pq" 20) 21 22const ( 23 postgreSQLTypeName = "postgres" 24 defaultPostgresRenewSQL = ` 25ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; 26` 27 defaultPostgresRotateRootCredentialsSQL = ` 28ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; 29` 30 31 defaultPostgresRotateCredentialsSQL = ` 32ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}'; 33` 34) 35 36var _ dbplugin.Database = &PostgreSQL{} 37 38// New implements builtinplugins.BuiltinFactory 39func New() (interface{}, error) { 40 db := new() 41 // Wrap the plugin with middleware to sanitize errors 42 dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) 43 return dbType, nil 44} 45 46func new() *PostgreSQL { 47 connProducer := &connutil.SQLConnectionProducer{} 48 connProducer.Type = postgreSQLTypeName 49 50 credsProducer := &credsutil.SQLCredentialsProducer{ 51 DisplayNameLen: 8, 52 RoleNameLen: 8, 53 UsernameLen: 63, 54 Separator: "-", 55 } 56 57 db := &PostgreSQL{ 58 SQLConnectionProducer: connProducer, 59 CredentialsProducer: credsProducer, 60 } 61 62 return db 63} 64 65// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin 66func Run(apiTLSConfig *api.TLSConfig) error { 67 dbType, err := New() 68 if err != nil { 69 return err 70 } 71 72 dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) 73 74 return nil 75} 76 77type PostgreSQL struct { 78 *connutil.SQLConnectionProducer 79 credsutil.CredentialsProducer 80} 81 82func (p *PostgreSQL) Type() (string, error) { 83 return postgreSQLTypeName, nil 84} 85 86func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { 87 db, err := p.Connection(ctx) 88 if err != nil { 89 return nil, err 90 } 91 92 return db.(*sql.DB), nil 93} 94 95// SetCredentials uses provided information to set/create a user in the 96// database. Unlike CreateUser, this method requires a username be provided and 97// uses the name given, instead of generating a name. This is used for creating 98// and setting the password of static accounts, as well as rolling back 99// passwords in the database in the event an updated database fails to save in 100// Vault's storage. 101func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { 102 if len(statements.Rotation) == 0 { 103 return "", "", errors.New("empty rotation statements") 104 } 105 106 username = staticUser.Username 107 password = staticUser.Password 108 if username == "" || password == "" { 109 return "", "", errors.New("must provide both username and password") 110 } 111 112 // Grab the lock 113 p.Lock() 114 defer p.Unlock() 115 116 // Get the connection 117 db, err := p.getConnection(ctx) 118 if err != nil { 119 return "", "", err 120 } 121 122 // Check if the role exists 123 var exists bool 124 err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) 125 if err != nil && err != sql.ErrNoRows { 126 return "", "", err 127 } 128 129 // Vault requires the database user already exist, and that the credentials 130 // used to execute the rotation statements has sufficient privileges. 131 stmts := statements.Rotation 132 133 // Start a transaction 134 tx, err := db.BeginTx(ctx, nil) 135 if err != nil { 136 return "", "", err 137 } 138 defer func() { 139 _ = tx.Rollback() 140 }() 141 142 // Execute each query 143 for _, stmt := range stmts { 144 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 145 query = strings.TrimSpace(query) 146 if len(query) == 0 { 147 continue 148 } 149 150 m := map[string]string{ 151 "name": staticUser.Username, 152 "password": password, 153 } 154 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 155 return "", "", err 156 } 157 } 158 } 159 160 // Commit the transaction 161 if err := tx.Commit(); err != nil { 162 return "", "", err 163 } 164 165 return username, password, nil 166} 167 168func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { 169 statements = dbutil.StatementCompatibilityHelper(statements) 170 171 if len(statements.Creation) == 0 { 172 return "", "", dbutil.ErrEmptyCreationStatement 173 } 174 175 // Grab the lock 176 p.Lock() 177 defer p.Unlock() 178 179 username, err = p.GenerateUsername(usernameConfig) 180 if err != nil { 181 return "", "", err 182 } 183 184 password, err = p.GeneratePassword() 185 if err != nil { 186 return "", "", err 187 } 188 189 expirationStr, err := p.GenerateExpiration(expiration) 190 if err != nil { 191 return "", "", err 192 } 193 194 // Get the connection 195 db, err := p.getConnection(ctx) 196 if err != nil { 197 return "", "", err 198 } 199 200 // Start a transaction 201 tx, err := db.BeginTx(ctx, nil) 202 if err != nil { 203 return "", "", err 204 205 } 206 defer func() { 207 tx.Rollback() 208 }() 209 210 // Execute each query 211 for _, stmt := range statements.Creation { 212 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 213 query = strings.TrimSpace(query) 214 if len(query) == 0 { 215 continue 216 } 217 218 m := map[string]string{ 219 "name": username, 220 "password": password, 221 "expiration": expirationStr, 222 } 223 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 224 return "", "", err 225 } 226 } 227 } 228 229 // Commit the transaction 230 if err := tx.Commit(); err != nil { 231 return "", "", err 232 } 233 234 return username, password, nil 235} 236 237func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { 238 p.Lock() 239 defer p.Unlock() 240 241 statements = dbutil.StatementCompatibilityHelper(statements) 242 243 renewStmts := statements.Renewal 244 if len(renewStmts) == 0 { 245 renewStmts = []string{defaultPostgresRenewSQL} 246 } 247 248 db, err := p.getConnection(ctx) 249 if err != nil { 250 return err 251 } 252 253 tx, err := db.BeginTx(ctx, nil) 254 if err != nil { 255 return err 256 } 257 defer func() { 258 tx.Rollback() 259 }() 260 261 expirationStr, err := p.GenerateExpiration(expiration) 262 if err != nil { 263 return err 264 } 265 266 for _, stmt := range renewStmts { 267 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 268 query = strings.TrimSpace(query) 269 if len(query) == 0 { 270 continue 271 } 272 273 m := map[string]string{ 274 "name": username, 275 "expiration": expirationStr, 276 } 277 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 278 return err 279 } 280 } 281 } 282 283 return tx.Commit() 284} 285 286func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { 287 // Grab the lock 288 p.Lock() 289 defer p.Unlock() 290 291 statements = dbutil.StatementCompatibilityHelper(statements) 292 293 if len(statements.Revocation) == 0 { 294 return p.defaultRevokeUser(ctx, username) 295 } 296 297 return p.customRevokeUser(ctx, username, statements.Revocation) 298} 299 300func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { 301 db, err := p.getConnection(ctx) 302 if err != nil { 303 return err 304 } 305 306 tx, err := db.BeginTx(ctx, nil) 307 if err != nil { 308 return err 309 } 310 defer func() { 311 tx.Rollback() 312 }() 313 314 for _, stmt := range revocationStmts { 315 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 316 query = strings.TrimSpace(query) 317 if len(query) == 0 { 318 continue 319 } 320 321 m := map[string]string{ 322 "name": username, 323 } 324 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 325 return err 326 } 327 } 328 } 329 330 return tx.Commit() 331} 332 333func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { 334 db, err := p.getConnection(ctx) 335 if err != nil { 336 return err 337 } 338 339 // Check if the role exists 340 var exists bool 341 err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) 342 if err != nil && err != sql.ErrNoRows { 343 return err 344 } 345 346 if !exists { 347 return nil 348 } 349 350 // Query for permissions; we need to revoke permissions before we can drop 351 // the role 352 // This isn't done in a transaction because even if we fail along the way, 353 // we want to remove as much access as possible 354 stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") 355 if err != nil { 356 return err 357 } 358 defer stmt.Close() 359 360 rows, err := stmt.QueryContext(ctx, username) 361 if err != nil { 362 return err 363 } 364 defer rows.Close() 365 366 const initialNumRevocations = 16 367 revocationStmts := make([]string, 0, initialNumRevocations) 368 for rows.Next() { 369 var schema string 370 err = rows.Scan(&schema) 371 if err != nil { 372 // keep going; remove as many permissions as possible right now 373 continue 374 } 375 revocationStmts = append(revocationStmts, fmt.Sprintf( 376 `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, 377 pq.QuoteIdentifier(schema), 378 pq.QuoteIdentifier(username))) 379 380 revocationStmts = append(revocationStmts, fmt.Sprintf( 381 `REVOKE USAGE ON SCHEMA %s FROM %s;`, 382 pq.QuoteIdentifier(schema), 383 pq.QuoteIdentifier(username))) 384 } 385 386 // for good measure, revoke all privileges and usage on schema public 387 revocationStmts = append(revocationStmts, fmt.Sprintf( 388 `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, 389 pq.QuoteIdentifier(username))) 390 391 revocationStmts = append(revocationStmts, fmt.Sprintf( 392 "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", 393 pq.QuoteIdentifier(username))) 394 395 revocationStmts = append(revocationStmts, fmt.Sprintf( 396 "REVOKE USAGE ON SCHEMA public FROM %s;", 397 pq.QuoteIdentifier(username))) 398 399 // get the current database name so we can issue a REVOKE CONNECT for 400 // this username 401 var dbname sql.NullString 402 if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil { 403 return err 404 } 405 406 if dbname.Valid { 407 revocationStmts = append(revocationStmts, fmt.Sprintf( 408 `REVOKE CONNECT ON DATABASE %s FROM %s;`, 409 pq.QuoteIdentifier(dbname.String), 410 pq.QuoteIdentifier(username))) 411 } 412 413 // again, here, we do not stop on error, as we want to remove as 414 // many permissions as possible right now 415 var lastStmtError error 416 for _, query := range revocationStmts { 417 if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { 418 lastStmtError = err 419 } 420 } 421 422 // can't drop if not all privileges are revoked 423 if rows.Err() != nil { 424 return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err()) 425 } 426 if lastStmtError != nil { 427 return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError) 428 } 429 430 // Drop this user 431 stmt, err = db.PrepareContext(ctx, fmt.Sprintf( 432 `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) 433 if err != nil { 434 return err 435 } 436 defer stmt.Close() 437 if _, err := stmt.ExecContext(ctx); err != nil { 438 return err 439 } 440 441 return nil 442} 443 444func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { 445 p.Lock() 446 defer p.Unlock() 447 448 if len(p.Username) == 0 || len(p.Password) == 0 { 449 return nil, errors.New("username and password are required to rotate") 450 } 451 452 rotateStatents := statements 453 if len(rotateStatents) == 0 { 454 rotateStatents = []string{defaultPostgresRotateRootCredentialsSQL} 455 } 456 457 db, err := p.getConnection(ctx) 458 if err != nil { 459 return nil, err 460 } 461 462 tx, err := db.BeginTx(ctx, nil) 463 if err != nil { 464 return nil, err 465 } 466 defer func() { 467 tx.Rollback() 468 }() 469 470 password, err := p.GeneratePassword() 471 if err != nil { 472 return nil, err 473 } 474 475 for _, stmt := range rotateStatents { 476 for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { 477 query = strings.TrimSpace(query) 478 if len(query) == 0 { 479 continue 480 } 481 m := map[string]string{ 482 "username": p.Username, 483 "password": password, 484 } 485 if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { 486 return nil, err 487 } 488 } 489 } 490 491 if err := tx.Commit(); err != nil { 492 return nil, err 493 } 494 495 // Close the database connection to ensure no new connections come in 496 if err := db.Close(); err != nil { 497 return nil, err 498 } 499 500 p.RawConfig["password"] = password 501 return p.RawConfig, nil 502} 503