1package postgresql
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"strconv"
8	"strings"
9	"sync"
10	"time"
11
12	"github.com/hashicorp/errwrap"
13	"github.com/hashicorp/vault/sdk/physical"
14
15	log "github.com/hashicorp/go-hclog"
16	"github.com/hashicorp/go-uuid"
17
18	"github.com/armon/go-metrics"
19	"github.com/lib/pq"
20)
21
const (
	// PostgreSQLLockTTLSeconds is the lifetime of an HA lock in seconds. It is
	// passed to the SQL statements that set or extend a lock's expiry relative
	// to the database clock. The value matches the default that the Consul
	// API uses, 15 seconds.
	PostgreSQLLockTTLSeconds = 15

	// PostgreSQLLockRenewInterval is the amount of time to wait between
	// lock renewals.
	PostgreSQLLockRenewInterval = 5 * time.Second

	// PostgreSQLLockRetryInterval is the amount of time to wait after a
	// failed attempt to take the lock before trying again.
	PostgreSQLLockRetryInterval = 1 * time.Second
)
36
37// Verify PostgreSQLBackend satisfies the correct interfaces
38var _ physical.Backend = (*PostgreSQLBackend)(nil)
39
40//
41// HA backend was implemented based on the DynamoDB backend pattern
42// With distinction using central postgres clock, hereby avoiding
43// possible issues with multiple clocks
44//
45var _ physical.HABackend = (*PostgreSQLBackend)(nil)
46var _ physical.Lock = (*PostgreSQLLock)(nil)
47
48// PostgreSQL Backend is a physical backend that stores data
49// within a PostgreSQL database.
50type PostgreSQLBackend struct {
51	table        string
52	client       *sql.DB
53	put_query    string
54	get_query    string
55	delete_query string
56	list_query   string
57
58	ha_table                 string
59	haGetLockValueQuery      string
60	haUpsertLockIdentityExec string
61	haDeleteLockExec         string
62
63	haEnabled  bool
64	logger     log.Logger
65	permitPool *physical.PermitPool
66}
67
68// PostgreSQLLock implements a lock using an PostgreSQL client.
69type PostgreSQLLock struct {
70	backend    *PostgreSQLBackend
71	value, key string
72	identity   string
73	lock       sync.Mutex
74
75	renewTicker *time.Ticker
76
77	// ttlSeconds is how long a lock is valid for
78	ttlSeconds int
79
80	// renewInterval is how much time to wait between lock renewals.  must be << ttl
81	renewInterval time.Duration
82
83	// retryInterval is how much time to wait between attempts to grab the lock
84	retryInterval time.Duration
85}
86
87// NewPostgreSQLBackend constructs a PostgreSQL backend using the given
88// API client, server address, credentials, and database.
89func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
90	// Get the PostgreSQL credentials to perform read/write operations.
91	connURL, ok := conf["connection_url"]
92	if !ok || connURL == "" {
93		return nil, fmt.Errorf("missing connection_url")
94	}
95
96	unquoted_table, ok := conf["table"]
97	if !ok {
98		unquoted_table = "vault_kv_store"
99	}
100	quoted_table := pq.QuoteIdentifier(unquoted_table)
101
102	maxParStr, ok := conf["max_parallel"]
103	var maxParInt int
104	var err error
105	if ok {
106		maxParInt, err = strconv.Atoi(maxParStr)
107		if err != nil {
108			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
109		}
110		if logger.IsDebug() {
111			logger.Debug("max_parallel set", "max_parallel", maxParInt)
112		}
113	} else {
114		maxParInt = physical.DefaultParallelOperations
115	}
116
117	maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"]
118	var maxIdleConns int
119	if maxIdleConnsIsSet {
120		maxIdleConns, err = strconv.Atoi(maxIdleConnsStr)
121		if err != nil {
122			return nil, errwrap.Wrapf("failed parsing max_idle_connections parameter: {{err}}", err)
123		}
124		if logger.IsDebug() {
125			logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr)
126		}
127	}
128
129	// Create PostgreSQL handle for the database.
130	db, err := sql.Open("postgres", connURL)
131	if err != nil {
132		return nil, errwrap.Wrapf("failed to connect to postgres: {{err}}", err)
133	}
134	db.SetMaxOpenConns(maxParInt)
135
136	if maxIdleConnsIsSet {
137		db.SetMaxIdleConns(maxIdleConns)
138	}
139
140	// Determine if we should use a function to work around lack of upsert (versions < 9.5)
141	var upsertAvailable bool
142	upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
143	if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
144		return nil, errwrap.Wrapf("failed to check for native upsert: {{err}}", err)
145	}
146
147	if !upsertAvailable && conf["ha_enabled"] == "true" {
148		return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
149	}
150
151	// Setup our put strategy based on the presence or absence of a native
152	// upsert.
153	var put_query string
154	if !upsertAvailable {
155		put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
156	} else {
157		put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
158			" ON CONFLICT (path, key) DO " +
159			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
160	}
161
162	unquoted_ha_table, ok := conf["ha_table"]
163	if !ok {
164		unquoted_ha_table = "vault_ha_locks"
165	}
166	quoted_ha_table := pq.QuoteIdentifier(unquoted_ha_table)
167
168	// Setup the backend.
169	m := &PostgreSQLBackend{
170		table:        quoted_table,
171		client:       db,
172		put_query:    put_query,
173		get_query:    "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
174		delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
175		list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
176			" UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
177			" WHERE parent_path LIKE $1 || '%'",
178		haGetLockValueQuery:
179		// only read non expired data
180		" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
181		haUpsertLockIdentityExec:
182		// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
183		// update either steal expired lock OR update expiry for lock owned by me
184		" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds'  ) " +
185			" ON CONFLICT (ha_key) DO " +
186			" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
187			" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
188			" (t.ha_identity = $1 AND t.ha_key = $2)  ",
189		haDeleteLockExec:
190		// $1=ha_identity $2=ha_key
191		" DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
192		logger:     logger,
193		permitPool: physical.NewPermitPool(maxParInt),
194		haEnabled:  conf["ha_enabled"] == "true",
195	}
196
197	return m, nil
198}
199
200// splitKey is a helper to split a full path key into individual
201// parts: parentPath, path, key
202func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
203	var parentPath string
204	var path string
205
206	pieces := strings.Split(fullPath, "/")
207	depth := len(pieces)
208	key := pieces[depth-1]
209
210	if depth == 1 {
211		parentPath = ""
212		path = "/"
213	} else if depth == 2 {
214		parentPath = "/"
215		path = "/" + pieces[0] + "/"
216	} else {
217		parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
218		path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
219	}
220
221	return parentPath, path, key
222}
223
224// Put is used to insert or update an entry.
225func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
226	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
227
228	m.permitPool.Acquire()
229	defer m.permitPool.Release()
230
231	parentPath, path, key := m.splitKey(entry.Key)
232
233	_, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value)
234	if err != nil {
235		return err
236	}
237	return nil
238}
239
240// Get is used to fetch and entry.
241func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
242	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
243
244	m.permitPool.Acquire()
245	defer m.permitPool.Release()
246
247	_, path, key := m.splitKey(fullPath)
248
249	var result []byte
250	err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
251	if err == sql.ErrNoRows {
252		return nil, nil
253	}
254	if err != nil {
255		return nil, err
256	}
257
258	ent := &physical.Entry{
259		Key:   fullPath,
260		Value: result,
261	}
262	return ent, nil
263}
264
265// Delete is used to permanently delete an entry
266func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
267	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
268
269	m.permitPool.Acquire()
270	defer m.permitPool.Release()
271
272	_, path, key := m.splitKey(fullPath)
273
274	_, err := m.client.Exec(m.delete_query, path, key)
275	if err != nil {
276		return err
277	}
278	return nil
279}
280
281// List is used to list all the keys under a given
282// prefix, up to the next prefix.
283func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
284	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
285
286	m.permitPool.Acquire()
287	defer m.permitPool.Release()
288
289	rows, err := m.client.Query(m.list_query, "/"+prefix)
290	if err != nil {
291		return nil, err
292	}
293	defer rows.Close()
294
295	var keys []string
296	for rows.Next() {
297		var key string
298		err = rows.Scan(&key)
299		if err != nil {
300			return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err)
301		}
302
303		keys = append(keys, key)
304	}
305
306	return keys, nil
307}
308
309// LockWith is used for mutual exclusion based on the given key.
310func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
311	identity, err := uuid.GenerateUUID()
312	if err != nil {
313		return nil, err
314	}
315	return &PostgreSQLLock{
316		backend:       p,
317		key:           key,
318		value:         value,
319		identity:      identity,
320		ttlSeconds:    PostgreSQLLockTTLSeconds,
321		renewInterval: PostgreSQLLockRenewInterval,
322		retryInterval: PostgreSQLLockRetryInterval,
323	}, nil
324}
325
// HAEnabled reports whether high-availability mode is turned on for this
// backend, as recorded in its haEnabled field.
func (p *PostgreSQLBackend) HAEnabled() bool {
	return p.haEnabled
}
329
330// Lock tries to acquire the lock by repeatedly trying to create a record in the
331// PostgreSQL table. It will block until either the stop channel is closed or
332// the lock could be acquired successfully. The returned channel will be closed
333// once the lock in the PostgreSQL table cannot be renewed, either due to an
334// error speaking to PostgreSQL or because someone else has taken it.
335func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
336	l.lock.Lock()
337	defer l.lock.Unlock()
338
339	var (
340		success = make(chan struct{})
341		errors  = make(chan error)
342		leader  = make(chan struct{})
343	)
344	// try to acquire the lock asynchronously
345	go l.tryToLock(stopCh, success, errors)
346
347	select {
348	case <-success:
349		// after acquiring it successfully, we must renew the lock periodically
350		l.renewTicker = time.NewTicker(l.renewInterval)
351		go l.periodicallyRenewLock(leader)
352	case err := <-errors:
353		return nil, err
354	case <-stopCh:
355		return nil, nil
356	}
357
358	return leader, nil
359}
360
361// Unlock releases the lock by deleting the lock record from the
362// PostgreSQL table.
363func (l *PostgreSQLLock) Unlock() error {
364	pg := l.backend
365	pg.permitPool.Acquire()
366	defer pg.permitPool.Release()
367
368	if l.renewTicker != nil {
369		l.renewTicker.Stop()
370	}
371
372	// Delete lock owned by me
373	_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
374	return err
375}
376
377// Value checks whether or not the lock is held by any instance of PostgreSQLLock,
378// including this one, and returns the current value.
379func (l *PostgreSQLLock) Value() (bool, string, error) {
380	pg := l.backend
381	pg.permitPool.Acquire()
382	defer pg.permitPool.Release()
383	var result string
384	err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)
385
386	switch err {
387	case nil:
388		return true, result, nil
389	case sql.ErrNoRows:
390		return false, "", nil
391	default:
392		return false, "", err
393
394	}
395}
396
397// tryToLock tries to create a new item in PostgreSQL every `retryInterval`.
398// As long as the item cannot be created (because it already exists), it will
399// be retried. If the operation fails due to an error, it is sent to the errors
400// channel. When the lock could be acquired successfully, the success channel
401// is closed.
402func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
403	ticker := time.NewTicker(l.retryInterval)
404	defer ticker.Stop()
405
406	for {
407		select {
408		case <-stop:
409			return
410		case <-ticker.C:
411			gotlock, err := l.writeItem()
412			switch {
413			case err != nil:
414				errors <- err
415				return
416			case gotlock:
417				close(success)
418				return
419			}
420		}
421	}
422}
423
424func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
425	for range l.renewTicker.C {
426		gotlock, err := l.writeItem()
427		if err != nil || !gotlock {
428			close(done)
429			l.renewTicker.Stop()
430			return
431		}
432	}
433}
434
435// Attempts to put/update the PostgreSQL item using condition expressions to
436// evaluate the TTL.  Returns true if the lock was obtained, false if not.
437// If false error may be nil or non-nil: nil indicates simply that someone
438// else has the lock, whereas non-nil means that something unexpected happened.
439func (l *PostgreSQLLock) writeItem() (bool, error) {
440	pg := l.backend
441	pg.permitPool.Acquire()
442	defer pg.permitPool.Release()
443
444	// Try steal lock or update expiry on my lock
445
446	sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds)
447	if err != nil {
448		return false, err
449	}
450	if sqlResult == nil {
451		return false, fmt.Errorf("empty SQL response received")
452	}
453
454	ar, err := sqlResult.RowsAffected()
455	if err != nil {
456		return false, err
457	}
458	return ar == 1, nil
459}
460