1package postgresql
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"os"
8	"strconv"
9	"strings"
10	"sync"
11	"time"
12
13	"github.com/hashicorp/vault/sdk/physical"
14
15	log "github.com/hashicorp/go-hclog"
16	"github.com/hashicorp/go-uuid"
17
18	"github.com/armon/go-metrics"
19	"github.com/lib/pq"
20)
21
const (

	// PostgreSQLLockTTLSeconds is the lifetime of an HA lock, in seconds. It
	// matches the default that the Consul API uses, 15 seconds. The value is
	// passed as a parameter to the lock upsert statement, which sets/extends
	// the lock expiry relative to the database clock.
	PostgreSQLLockTTLSeconds = 15

	// PostgreSQLLockRenewInterval is the amount of time to wait between
	// lock renewals. It must be well below the lock TTL.
	PostgreSQLLockRenewInterval = 5 * time.Second

	// PostgreSQLLockRetryInterval is the amount of time to wait
	// if a lock fails before trying again.
	PostgreSQLLockRetryInterval = time.Second
)
36
// Verify that PostgreSQLBackend satisfies the physical.Backend interface.
var _ physical.Backend = (*PostgreSQLBackend)(nil)

// The HA backend was implemented following the DynamoDB backend pattern,
// with the distinction that it uses the central PostgreSQL clock, thereby
// avoiding possible issues with multiple clocks.
//
// Verify that the HA-related interfaces are satisfied as well.
var (
	_ physical.HABackend = (*PostgreSQLBackend)(nil)
	_ physical.Lock      = (*PostgreSQLLock)(nil)
)
49
// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
	// table is the quoted name of the key/value table.
	table string
	// client is the database handle every query is issued through.
	client *sql.DB
	// Pre-built SQL statements for the key/value operations. They are
	// assembled once by NewPostgreSQLBackend with the table name embedded.
	put_query    string
	get_query    string
	delete_query string
	list_query   string

	// NOTE(review): ha_table is not assigned anywhere in the visible code;
	// the HA statements below embed the HA table name directly.
	ha_table string
	// haGetLockValueQuery reads the value of a non-expired lock by key.
	haGetLockValueQuery string
	// haUpsertLockIdentityExec creates a lock, takes over an expired one, or
	// extends the expiry of a lock already owned by the given identity.
	haUpsertLockIdentityExec string
	// haDeleteLockExec deletes the lock row owned by the given identity.
	haDeleteLockExec string

	// haEnabled reports whether the configuration set ha_enabled to "true".
	haEnabled bool
	logger    log.Logger
	// permitPool bounds the number of concurrent database operations.
	permitPool *physical.PermitPool
}
69
// PostgreSQLLock implements a lock using a PostgreSQL client.
type PostgreSQLLock struct {
	// backend provides the database handle, the HA statements and the
	// permit pool used by every lock operation.
	backend *PostgreSQLBackend
	// key identifies the lock row; value is the payload stored with it.
	value, key string
	// identity distinguishes this lock holder from others; it is generated
	// as a UUID by LockWith.
	identity string
	// lock is held by Lock for its whole duration.
	lock sync.Mutex

	// renewTicker drives periodic renewal. It is created by Lock once the
	// lock has been acquired and stopped by Unlock or on a failed renewal.
	renewTicker *time.Ticker

	// ttlSeconds is how long a lock is valid for
	ttlSeconds int

	// renewInterval is how much time to wait between lock renewals.  must be << ttl
	renewInterval time.Duration

	// retryInterval is how much time to wait between attempts to grab the lock
	retryInterval time.Duration
}
88
89// NewPostgreSQLBackend constructs a PostgreSQL backend using the given
90// API client, server address, credentials, and database.
91func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
92	// Get the PostgreSQL credentials to perform read/write operations.
93	connURL := connectionURL(conf)
94	if connURL == "" {
95		return nil, fmt.Errorf("missing connection_url")
96	}
97
98	unquoted_table, ok := conf["table"]
99	if !ok {
100		unquoted_table = "vault_kv_store"
101	}
102	quoted_table := pq.QuoteIdentifier(unquoted_table)
103
104	maxParStr, ok := conf["max_parallel"]
105	var maxParInt int
106	var err error
107	if ok {
108		maxParInt, err = strconv.Atoi(maxParStr)
109		if err != nil {
110			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
111		}
112		if logger.IsDebug() {
113			logger.Debug("max_parallel set", "max_parallel", maxParInt)
114		}
115	} else {
116		maxParInt = physical.DefaultParallelOperations
117	}
118
119	maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"]
120	var maxIdleConns int
121	if maxIdleConnsIsSet {
122		maxIdleConns, err = strconv.Atoi(maxIdleConnsStr)
123		if err != nil {
124			return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
125		}
126		if logger.IsDebug() {
127			logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr)
128		}
129	}
130
131	// Create PostgreSQL handle for the database.
132	db, err := sql.Open("postgres", connURL)
133	if err != nil {
134		return nil, fmt.Errorf("failed to connect to postgres: %w", err)
135	}
136	db.SetMaxOpenConns(maxParInt)
137
138	if maxIdleConnsIsSet {
139		db.SetMaxIdleConns(maxIdleConns)
140	}
141
142	// Determine if we should use a function to work around lack of upsert (versions < 9.5)
143	var upsertAvailable bool
144	upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
145	if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
146		return nil, fmt.Errorf("failed to check for native upsert: %w", err)
147	}
148
149	if !upsertAvailable && conf["ha_enabled"] == "true" {
150		return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
151	}
152
153	// Setup our put strategy based on the presence or absence of a native
154	// upsert.
155	var put_query string
156	if !upsertAvailable {
157		put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
158	} else {
159		put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
160			" ON CONFLICT (path, key) DO " +
161			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
162	}
163
164	unquoted_ha_table, ok := conf["ha_table"]
165	if !ok {
166		unquoted_ha_table = "vault_ha_locks"
167	}
168	quoted_ha_table := pq.QuoteIdentifier(unquoted_ha_table)
169
170	// Setup the backend.
171	m := &PostgreSQLBackend{
172		table:        quoted_table,
173		client:       db,
174		put_query:    put_query,
175		get_query:    "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
176		delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
177		list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
178			" UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
179			" WHERE parent_path LIKE $1 || '%'",
180		haGetLockValueQuery:
181		// only read non expired data
182		" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
183		haUpsertLockIdentityExec:
184		// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
185		// update either steal expired lock OR update expiry for lock owned by me
186		" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds'  ) " +
187			" ON CONFLICT (ha_key) DO " +
188			" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
189			" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
190			" (t.ha_identity = $1 AND t.ha_key = $2)  ",
191		haDeleteLockExec:
192		// $1=ha_identity $2=ha_key
193		" DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
194		logger:     logger,
195		permitPool: physical.NewPermitPool(maxParInt),
196		haEnabled:  conf["ha_enabled"] == "true",
197	}
198
199	return m, nil
200}
201
// connectionURL resolves the connection URL for the Postgres backend. A
// non-empty VAULT_PG_CONNECTION_URL environment variable takes precedence;
// otherwise the "connection_url" entry of the Vault config is used. When
// neither is set the empty string is returned, which the caller treats as a
// missing required field.
func connectionURL(conf map[string]string) string {
	if fromEnv := os.Getenv("VAULT_PG_CONNECTION_URL"); fromEnv != "" {
		return fromEnv
	}
	return conf["connection_url"]
}
214
215// splitKey is a helper to split a full path key into individual
216// parts: parentPath, path, key
217func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
218	var parentPath string
219	var path string
220
221	pieces := strings.Split(fullPath, "/")
222	depth := len(pieces)
223	key := pieces[depth-1]
224
225	if depth == 1 {
226		parentPath = ""
227		path = "/"
228	} else if depth == 2 {
229		parentPath = "/"
230		path = "/" + pieces[0] + "/"
231	} else {
232		parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
233		path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
234	}
235
236	return parentPath, path, key
237}
238
239// Put is used to insert or update an entry.
240func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
241	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
242
243	m.permitPool.Acquire()
244	defer m.permitPool.Release()
245
246	parentPath, path, key := m.splitKey(entry.Key)
247
248	_, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value)
249	if err != nil {
250		return err
251	}
252	return nil
253}
254
255// Get is used to fetch and entry.
256func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
257	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
258
259	m.permitPool.Acquire()
260	defer m.permitPool.Release()
261
262	_, path, key := m.splitKey(fullPath)
263
264	var result []byte
265	err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
266	if err == sql.ErrNoRows {
267		return nil, nil
268	}
269	if err != nil {
270		return nil, err
271	}
272
273	ent := &physical.Entry{
274		Key:   fullPath,
275		Value: result,
276	}
277	return ent, nil
278}
279
280// Delete is used to permanently delete an entry
281func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
282	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
283
284	m.permitPool.Acquire()
285	defer m.permitPool.Release()
286
287	_, path, key := m.splitKey(fullPath)
288
289	_, err := m.client.Exec(m.delete_query, path, key)
290	if err != nil {
291		return err
292	}
293	return nil
294}
295
296// List is used to list all the keys under a given
297// prefix, up to the next prefix.
298func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
299	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
300
301	m.permitPool.Acquire()
302	defer m.permitPool.Release()
303
304	rows, err := m.client.Query(m.list_query, "/"+prefix)
305	if err != nil {
306		return nil, err
307	}
308	defer rows.Close()
309
310	var keys []string
311	for rows.Next() {
312		var key string
313		err = rows.Scan(&key)
314		if err != nil {
315			return nil, fmt.Errorf("failed to scan rows: %w", err)
316		}
317
318		keys = append(keys, key)
319	}
320
321	return keys, nil
322}
323
324// LockWith is used for mutual exclusion based on the given key.
325func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
326	identity, err := uuid.GenerateUUID()
327	if err != nil {
328		return nil, err
329	}
330	return &PostgreSQLLock{
331		backend:       p,
332		key:           key,
333		value:         value,
334		identity:      identity,
335		ttlSeconds:    PostgreSQLLockTTLSeconds,
336		renewInterval: PostgreSQLLockRenewInterval,
337		retryInterval: PostgreSQLLockRetryInterval,
338	}, nil
339}
340
341func (p *PostgreSQLBackend) HAEnabled() bool {
342	return p.haEnabled
343}
344
345// Lock tries to acquire the lock by repeatedly trying to create a record in the
346// PostgreSQL table. It will block until either the stop channel is closed or
347// the lock could be acquired successfully. The returned channel will be closed
348// once the lock in the PostgreSQL table cannot be renewed, either due to an
349// error speaking to PostgreSQL or because someone else has taken it.
350func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
351	l.lock.Lock()
352	defer l.lock.Unlock()
353
354	var (
355		success = make(chan struct{})
356		errors  = make(chan error)
357		leader  = make(chan struct{})
358	)
359	// try to acquire the lock asynchronously
360	go l.tryToLock(stopCh, success, errors)
361
362	select {
363	case <-success:
364		// after acquiring it successfully, we must renew the lock periodically
365		l.renewTicker = time.NewTicker(l.renewInterval)
366		go l.periodicallyRenewLock(leader)
367	case err := <-errors:
368		return nil, err
369	case <-stopCh:
370		return nil, nil
371	}
372
373	return leader, nil
374}
375
376// Unlock releases the lock by deleting the lock record from the
377// PostgreSQL table.
378func (l *PostgreSQLLock) Unlock() error {
379	pg := l.backend
380	pg.permitPool.Acquire()
381	defer pg.permitPool.Release()
382
383	if l.renewTicker != nil {
384		l.renewTicker.Stop()
385	}
386
387	// Delete lock owned by me
388	_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
389	return err
390}
391
392// Value checks whether or not the lock is held by any instance of PostgreSQLLock,
393// including this one, and returns the current value.
394func (l *PostgreSQLLock) Value() (bool, string, error) {
395	pg := l.backend
396	pg.permitPool.Acquire()
397	defer pg.permitPool.Release()
398	var result string
399	err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)
400
401	switch err {
402	case nil:
403		return true, result, nil
404	case sql.ErrNoRows:
405		return false, "", nil
406	default:
407		return false, "", err
408
409	}
410}
411
412// tryToLock tries to create a new item in PostgreSQL every `retryInterval`.
413// As long as the item cannot be created (because it already exists), it will
414// be retried. If the operation fails due to an error, it is sent to the errors
415// channel. When the lock could be acquired successfully, the success channel
416// is closed.
417func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
418	ticker := time.NewTicker(l.retryInterval)
419	defer ticker.Stop()
420
421	for {
422		select {
423		case <-stop:
424			return
425		case <-ticker.C:
426			gotlock, err := l.writeItem()
427			switch {
428			case err != nil:
429				errors <- err
430				return
431			case gotlock:
432				close(success)
433				return
434			}
435		}
436	}
437}
438
439func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
440	for range l.renewTicker.C {
441		gotlock, err := l.writeItem()
442		if err != nil || !gotlock {
443			close(done)
444			l.renewTicker.Stop()
445			return
446		}
447	}
448}
449
450// Attempts to put/update the PostgreSQL item using condition expressions to
451// evaluate the TTL.  Returns true if the lock was obtained, false if not.
452// If false error may be nil or non-nil: nil indicates simply that someone
453// else has the lock, whereas non-nil means that something unexpected happened.
454func (l *PostgreSQLLock) writeItem() (bool, error) {
455	pg := l.backend
456	pg.permitPool.Acquire()
457	defer pg.permitPool.Release()
458
459	// Try steal lock or update expiry on my lock
460
461	sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds)
462	if err != nil {
463		return false, err
464	}
465	if sqlResult == nil {
466		return false, fmt.Errorf("empty SQL response received")
467	}
468
469	ar, err := sqlResult.RowsAffected()
470	if err != nil {
471		return false, err
472	}
473	return ar == 1, nil
474}
475