1package postgresql
2
3import (
4	"context"
5	"database/sql"
6	"errors"
7	"fmt"
8	"strings"
9	"time"
10
11	"github.com/hashicorp/errwrap"
12	"github.com/hashicorp/vault/api"
13	"github.com/hashicorp/vault/sdk/database/dbplugin"
14	"github.com/hashicorp/vault/sdk/database/helper/connutil"
15	"github.com/hashicorp/vault/sdk/database/helper/credsutil"
16	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
17	"github.com/hashicorp/vault/sdk/helper/dbtxn"
18	"github.com/hashicorp/vault/sdk/helper/strutil"
19	"github.com/lib/pq"
20)
21
// postgreSQLTypeName is the database type name reported by this plugin and
// assigned to its connection producer.
const postgreSQLTypeName = "postgres"

// Fallback SQL templates, used when the caller supplies no statements of its
// own. The double-brace placeholders are filled from the map that callers in
// this file pass to dbtxn.ExecuteTxQuery.
const (
	// defaultPostgresRenewSQL moves a role's VALID UNTIL to {{expiration}}.
	defaultPostgresRenewSQL = `
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`

	// defaultPostgresRotateRootCredentialsSQL sets a new password for the
	// root role; note the placeholder is {{username}}, not {{name}}.
	defaultPostgresRotateRootCredentialsSQL = `
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
`

	// defaultPostgresRotateCredentialsSQL sets a new password for a named role.
	defaultPostgresRotateCredentialsSQL = `
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';
`
)
35
36var _ dbplugin.Database = &PostgreSQL{}
37
38// New implements builtinplugins.BuiltinFactory
39func New() (interface{}, error) {
40	db := new()
41	// Wrap the plugin with middleware to sanitize errors
42	dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
43	return dbType, nil
44}
45
46func new() *PostgreSQL {
47	connProducer := &connutil.SQLConnectionProducer{}
48	connProducer.Type = postgreSQLTypeName
49
50	credsProducer := &credsutil.SQLCredentialsProducer{
51		DisplayNameLen: 8,
52		RoleNameLen:    8,
53		UsernameLen:    63,
54		Separator:      "-",
55	}
56
57	db := &PostgreSQL{
58		SQLConnectionProducer: connProducer,
59		CredentialsProducer:   credsProducer,
60	}
61
62	return db
63}
64
65// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
66func Run(apiTLSConfig *api.TLSConfig) error {
67	dbType, err := New()
68	if err != nil {
69		return err
70	}
71
72	dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
73
74	return nil
75}
76
// PostgreSQL is the Vault database plugin for PostgreSQL. It composes an
// embedded *connutil.SQLConnectionProducer, from which the methods in this
// file use Lock/Unlock, Connection, Username, Password and RawConfig, and an
// embedded credsutil.CredentialsProducer, from which they use
// GenerateUsername, GeneratePassword and GenerateExpiration.
type PostgreSQL struct {
	*connutil.SQLConnectionProducer
	credsutil.CredentialsProducer
}
81
82func (p *PostgreSQL) Type() (string, error) {
83	return postgreSQLTypeName, nil
84}
85
86func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
87	db, err := p.Connection(ctx)
88	if err != nil {
89		return nil, err
90	}
91
92	return db.(*sql.DB), nil
93}
94
95// SetCredentials uses provided information to set/create a user in the
96// database. Unlike CreateUser, this method requires a username be provided and
97// uses the name given, instead of generating a name. This is used for creating
98// and setting the password of static accounts, as well as rolling back
99// passwords in the database in the event an updated database fails to save in
100// Vault's storage.
101func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
102	if len(statements.Rotation) == 0 {
103		return "", "", errors.New("empty rotation statements")
104	}
105
106	username = staticUser.Username
107	password = staticUser.Password
108	if username == "" || password == "" {
109		return "", "", errors.New("must provide both username and password")
110	}
111
112	// Grab the lock
113	p.Lock()
114	defer p.Unlock()
115
116	// Get the connection
117	db, err := p.getConnection(ctx)
118	if err != nil {
119		return "", "", err
120	}
121
122	// Check if the role exists
123	var exists bool
124	err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
125	if err != nil && err != sql.ErrNoRows {
126		return "", "", err
127	}
128
129	// Vault requires the database user already exist, and that the credentials
130	// used to execute the rotation statements has sufficient privileges.
131	stmts := statements.Rotation
132
133	// Start a transaction
134	tx, err := db.BeginTx(ctx, nil)
135	if err != nil {
136		return "", "", err
137	}
138	defer func() {
139		_ = tx.Rollback()
140	}()
141
142	// Execute each query
143	for _, stmt := range stmts {
144		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
145			query = strings.TrimSpace(query)
146			if len(query) == 0 {
147				continue
148			}
149
150			m := map[string]string{
151				"name":     staticUser.Username,
152				"password": password,
153			}
154			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
155				return "", "", err
156			}
157		}
158	}
159
160	// Commit the transaction
161	if err := tx.Commit(); err != nil {
162		return "", "", err
163	}
164
165	return username, password, nil
166}
167
168func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
169	statements = dbutil.StatementCompatibilityHelper(statements)
170
171	if len(statements.Creation) == 0 {
172		return "", "", dbutil.ErrEmptyCreationStatement
173	}
174
175	// Grab the lock
176	p.Lock()
177	defer p.Unlock()
178
179	username, err = p.GenerateUsername(usernameConfig)
180	if err != nil {
181		return "", "", err
182	}
183
184	password, err = p.GeneratePassword()
185	if err != nil {
186		return "", "", err
187	}
188
189	expirationStr, err := p.GenerateExpiration(expiration)
190	if err != nil {
191		return "", "", err
192	}
193
194	// Get the connection
195	db, err := p.getConnection(ctx)
196	if err != nil {
197		return "", "", err
198	}
199
200	// Start a transaction
201	tx, err := db.BeginTx(ctx, nil)
202	if err != nil {
203		return "", "", err
204
205	}
206	defer func() {
207		tx.Rollback()
208	}()
209
210	// Execute each query
211	for _, stmt := range statements.Creation {
212		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
213			query = strings.TrimSpace(query)
214			if len(query) == 0 {
215				continue
216			}
217
218			m := map[string]string{
219				"name":       username,
220				"password":   password,
221				"expiration": expirationStr,
222			}
223			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
224				return "", "", err
225			}
226		}
227	}
228
229	// Commit the transaction
230	if err := tx.Commit(); err != nil {
231		return "", "", err
232	}
233
234	return username, password, nil
235}
236
237func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
238	p.Lock()
239	defer p.Unlock()
240
241	statements = dbutil.StatementCompatibilityHelper(statements)
242
243	renewStmts := statements.Renewal
244	if len(renewStmts) == 0 {
245		renewStmts = []string{defaultPostgresRenewSQL}
246	}
247
248	db, err := p.getConnection(ctx)
249	if err != nil {
250		return err
251	}
252
253	tx, err := db.BeginTx(ctx, nil)
254	if err != nil {
255		return err
256	}
257	defer func() {
258		tx.Rollback()
259	}()
260
261	expirationStr, err := p.GenerateExpiration(expiration)
262	if err != nil {
263		return err
264	}
265
266	for _, stmt := range renewStmts {
267		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
268			query = strings.TrimSpace(query)
269			if len(query) == 0 {
270				continue
271			}
272
273			m := map[string]string{
274				"name":       username,
275				"expiration": expirationStr,
276			}
277			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
278				return err
279			}
280		}
281	}
282
283	return tx.Commit()
284}
285
286func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
287	// Grab the lock
288	p.Lock()
289	defer p.Unlock()
290
291	statements = dbutil.StatementCompatibilityHelper(statements)
292
293	if len(statements.Revocation) == 0 {
294		return p.defaultRevokeUser(ctx, username)
295	}
296
297	return p.customRevokeUser(ctx, username, statements.Revocation)
298}
299
300func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
301	db, err := p.getConnection(ctx)
302	if err != nil {
303		return err
304	}
305
306	tx, err := db.BeginTx(ctx, nil)
307	if err != nil {
308		return err
309	}
310	defer func() {
311		tx.Rollback()
312	}()
313
314	for _, stmt := range revocationStmts {
315		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
316			query = strings.TrimSpace(query)
317			if len(query) == 0 {
318				continue
319			}
320
321			m := map[string]string{
322				"name": username,
323			}
324			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
325				return err
326			}
327		}
328	}
329
330	return tx.Commit()
331}
332
// defaultRevokeUser is the built-in revocation used when no revocation
// statements are configured. It returns nil if the role does not exist.
// Otherwise it revokes, on a best-effort basis:
//   - table privileges and schema USAGE for every schema returned for the role
//     by information_schema.role_column_grants,
//   - table and sequence privileges and USAGE on schema public,
//   - CONNECT on the current database.
//
// It then drops the role. The DROP is attempted only when every REVOKE
// succeeded and the grant rows were iterated without error. The caller is
// expected to hold the producer lock (RevokeUser does).
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
	db, err := p.getConnection(ctx)
	if err != nil {
		return err
	}

	// Check if the role exists
	var exists bool
	err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
	if err != nil && err != sql.ErrNoRows {
		return err
	}

	// An absent role is treated as already revoked.
	if !exists {
		return nil
	}

	// Query for permissions; we need to revoke permissions before we can drop
	// the role
	// This isn't done in a transaction because even if we fail along the way,
	// we want to remove as much access as possible
	stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
	if err != nil {
		return err
	}
	defer stmt.Close()

	rows, err := stmt.QueryContext(ctx, username)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Two statements are generated per schema; 16 is just an initial capacity.
	const initialNumRevocations = 16
	revocationStmts := make([]string, 0, initialNumRevocations)
	for rows.Next() {
		var schema string
		err = rows.Scan(&schema)
		if err != nil {
			// keep going; remove as many permissions as possible right now
			// NOTE(review): a Scan failure is not reflected in rows.Err(), so
			// that schema's grants are skipped silently and the role may
			// still be dropped below.
			continue
		}
		// Identifiers are quoted with pq.QuoteIdentifier because they are
		// interpolated into SQL text rather than bound as parameters.
		revocationStmts = append(revocationStmts, fmt.Sprintf(
			`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
			pq.QuoteIdentifier(schema),
			pq.QuoteIdentifier(username)))

		revocationStmts = append(revocationStmts, fmt.Sprintf(
			`REVOKE USAGE ON SCHEMA %s FROM %s;`,
			pq.QuoteIdentifier(schema),
			pq.QuoteIdentifier(username)))
	}

	// for good measure, revoke all privileges and usage on schema public
	revocationStmts = append(revocationStmts, fmt.Sprintf(
		`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
		pq.QuoteIdentifier(username)))

	revocationStmts = append(revocationStmts, fmt.Sprintf(
		"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
		pq.QuoteIdentifier(username)))

	revocationStmts = append(revocationStmts, fmt.Sprintf(
		"REVOKE USAGE ON SCHEMA public FROM %s;",
		pq.QuoteIdentifier(username)))

	// get the current database name so we can issue a REVOKE CONNECT for
	// this username
	var dbname sql.NullString
	if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
		return err
	}

	// A NULL database name just skips the CONNECT revocation.
	if dbname.Valid {
		revocationStmts = append(revocationStmts, fmt.Sprintf(
			`REVOKE CONNECT ON DATABASE %s FROM %s;`,
			pq.QuoteIdentifier(dbname.String),
			pq.QuoteIdentifier(username)))
	}

	// again, here, we do not stop on error, as we want to remove as
	// many permissions as possible right now
	// Only the most recent failure is retained for the error message.
	var lastStmtError error
	for _, query := range revocationStmts {
		if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil {
			lastStmtError = err
		}
	}

	// can't drop if not all privileges are revoked
	// rows.Err is checked only now, deliberately after the revocations above
	// have run, so that partial results still remove as much access as
	// possible before reporting the iteration failure.
	if rows.Err() != nil {
		return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
	}
	if lastStmtError != nil {
		return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
	}

	// Drop this user
	// stmt is reused here; the earlier `defer stmt.Close()` still closes the
	// first prepared statement because the method value was bound when it
	// was deferred.
	stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
		`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
	if err != nil {
		return err
	}
	defer stmt.Close()
	if _, err := stmt.ExecContext(ctx); err != nil {
		return err
	}

	return nil
}
443
444func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
445	p.Lock()
446	defer p.Unlock()
447
448	if len(p.Username) == 0 || len(p.Password) == 0 {
449		return nil, errors.New("username and password are required to rotate")
450	}
451
452	rotateStatents := statements
453	if len(rotateStatents) == 0 {
454		rotateStatents = []string{defaultPostgresRotateRootCredentialsSQL}
455	}
456
457	db, err := p.getConnection(ctx)
458	if err != nil {
459		return nil, err
460	}
461
462	tx, err := db.BeginTx(ctx, nil)
463	if err != nil {
464		return nil, err
465	}
466	defer func() {
467		tx.Rollback()
468	}()
469
470	password, err := p.GeneratePassword()
471	if err != nil {
472		return nil, err
473	}
474
475	for _, stmt := range rotateStatents {
476		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
477			query = strings.TrimSpace(query)
478			if len(query) == 0 {
479				continue
480			}
481			m := map[string]string{
482				"username": p.Username,
483				"password": password,
484			}
485			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
486				return nil, err
487			}
488		}
489	}
490
491	if err := tx.Commit(); err != nil {
492		return nil, err
493	}
494
495	// Close the database connection to ensure no new connections come in
496	if err := db.Close(); err != nil {
497		return nil, err
498	}
499
500	p.RawConfig["password"] = password
501	return p.RawConfig, nil
502}
503