package postgresql

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/physical"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-uuid"

	"github.com/armon/go-metrics"
	"github.com/lib/pq"
)

const (

	// PostgreSQLLockTTLSeconds is the lock TTL; it matches the default that
	// the Consul API uses, 15 seconds. It is used as part of SQL commands to
	// set/extend the lock expiry time relative to the database clock.
	PostgreSQLLockTTLSeconds = 15

	// PostgreSQLLockRenewInterval is the amount of time to wait between lock
	// renewals. It must be well below PostgreSQLLockTTLSeconds so a healthy
	// holder renews before its lock expires.
	PostgreSQLLockRenewInterval = 5 * time.Second

	// PostgreSQLLockRetryInterval is the amount of time to wait
	// if a lock fails before trying again.
	PostgreSQLLockRetryInterval = time.Second
)

// Verify PostgreSQLBackend satisfies the correct interfaces
var _ physical.Backend = (*PostgreSQLBackend)(nil)

//
// The HA backend was implemented following the DynamoDB backend pattern,
// with the distinction that lock expiry is evaluated against the central
// PostgreSQL server clock (NOW()), avoiding possible issues with skew
// between multiple client clocks.
//
var _ physical.HABackend = (*PostgreSQLBackend)(nil)
var _ physical.Lock = (*PostgreSQLLock)(nil)

// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
	// table is the key/value table name, quoted with pq.QuoteIdentifier.
	table string
	// client is the database handle shared by all operations and locks.
	client *sql.DB
	// SQL text for the key/value operations, built once by
	// NewPostgreSQLBackend with the quoted table name embedded.
	put_query    string
	get_query    string
	delete_query string
	list_query   string

	// ha_table is meant to hold the HA lock table name; the ha* SQL strings
	// below embed that (quoted) name directly.
	ha_table                 string
	haGetLockValueQuery      string
	haUpsertLockIdentityExec string
	haDeleteLockExec         string

	// haEnabled is true when the config set ha_enabled to "true".
	haEnabled bool
	logger    log.Logger
	// permitPool bounds the number of concurrent database operations.
	permitPool *physical.PermitPool
}

// PostgreSQLLock implements a lock using a PostgreSQL client.
69type PostgreSQLLock struct { 70 backend *PostgreSQLBackend 71 value, key string 72 identity string 73 lock sync.Mutex 74 75 renewTicker *time.Ticker 76 77 // ttlSeconds is how long a lock is valid for 78 ttlSeconds int 79 80 // renewInterval is how much time to wait between lock renewals. must be << ttl 81 renewInterval time.Duration 82 83 // retryInterval is how much time to wait between attempts to grab the lock 84 retryInterval time.Duration 85} 86 87// NewPostgreSQLBackend constructs a PostgreSQL backend using the given 88// API client, server address, credentials, and database. 89func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 90 // Get the PostgreSQL credentials to perform read/write operations. 91 connURL, ok := conf["connection_url"] 92 if !ok || connURL == "" { 93 return nil, fmt.Errorf("missing connection_url") 94 } 95 96 unquoted_table, ok := conf["table"] 97 if !ok { 98 unquoted_table = "vault_kv_store" 99 } 100 quoted_table := pq.QuoteIdentifier(unquoted_table) 101 102 maxParStr, ok := conf["max_parallel"] 103 var maxParInt int 104 var err error 105 if ok { 106 maxParInt, err = strconv.Atoi(maxParStr) 107 if err != nil { 108 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 109 } 110 if logger.IsDebug() { 111 logger.Debug("max_parallel set", "max_parallel", maxParInt) 112 } 113 } else { 114 maxParInt = physical.DefaultParallelOperations 115 } 116 117 maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"] 118 var maxIdleConns int 119 if maxIdleConnsIsSet { 120 maxIdleConns, err = strconv.Atoi(maxIdleConnsStr) 121 if err != nil { 122 return nil, errwrap.Wrapf("failed parsing max_idle_connections parameter: {{err}}", err) 123 } 124 if logger.IsDebug() { 125 logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr) 126 } 127 } 128 129 // Create PostgreSQL handle for the database. 
130 db, err := sql.Open("postgres", connURL) 131 if err != nil { 132 return nil, errwrap.Wrapf("failed to connect to postgres: {{err}}", err) 133 } 134 db.SetMaxOpenConns(maxParInt) 135 136 if maxIdleConnsIsSet { 137 db.SetMaxIdleConns(maxIdleConns) 138 } 139 140 // Determine if we should use a function to work around lack of upsert (versions < 9.5) 141 var upsertAvailable bool 142 upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500" 143 if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil { 144 return nil, errwrap.Wrapf("failed to check for native upsert: {{err}}", err) 145 } 146 147 if !upsertAvailable && conf["ha_enabled"] == "true" { 148 return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5") 149 } 150 151 // Setup our put strategy based on the presence or absence of a native 152 // upsert. 153 var put_query string 154 if !upsertAvailable { 155 put_query = "SELECT vault_kv_put($1, $2, $3, $4)" 156 } else { 157 put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + 158 " ON CONFLICT (path, key) DO " + 159 " UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)" 160 } 161 162 unquoted_ha_table, ok := conf["ha_table"] 163 if !ok { 164 unquoted_ha_table = "vault_ha_locks" 165 } 166 quoted_ha_table := pq.QuoteIdentifier(unquoted_ha_table) 167 168 // Setup the backend. 
169 m := &PostgreSQLBackend{ 170 table: quoted_table, 171 client: db, 172 put_query: put_query, 173 get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", 174 delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", 175 list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" + 176 " UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table + 177 " WHERE parent_path LIKE $1 || '%'", 178 haGetLockValueQuery: 179 // only read non expired data 180 " SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ", 181 haUpsertLockIdentityExec: 182 // $1=identity $2=ha_key $3=ha_value $4=TTL in seconds 183 // update either steal expired lock OR update expiry for lock owned by me 184 " INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " + 185 " ON CONFLICT (ha_key) DO " + 186 " UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " + 187 " WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " + 188 " (t.ha_identity = $1 AND t.ha_key = $2) ", 189 haDeleteLockExec: 190 // $1=ha_identity $2=ha_key 191 " DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ", 192 logger: logger, 193 permitPool: physical.NewPermitPool(maxParInt), 194 haEnabled: conf["ha_enabled"] == "true", 195 } 196 197 return m, nil 198} 199 200// splitKey is a helper to split a full path key into individual 201// parts: parentPath, path, key 202func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { 203 var parentPath string 204 var path string 205 206 pieces := strings.Split(fullPath, "/") 207 depth := len(pieces) 208 key := pieces[depth-1] 209 210 if depth == 1 { 211 parentPath = "" 212 path = "/" 213 } else if depth == 2 { 214 parentPath = "/" 215 path = "/" + pieces[0] + "/" 216 } else { 217 
parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/" 218 path = "/" + strings.Join(pieces[:depth-1], "/") + "/" 219 } 220 221 return parentPath, path, key 222} 223 224// Put is used to insert or update an entry. 225func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error { 226 defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now()) 227 228 m.permitPool.Acquire() 229 defer m.permitPool.Release() 230 231 parentPath, path, key := m.splitKey(entry.Key) 232 233 _, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value) 234 if err != nil { 235 return err 236 } 237 return nil 238} 239 240// Get is used to fetch and entry. 241func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) { 242 defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now()) 243 244 m.permitPool.Acquire() 245 defer m.permitPool.Release() 246 247 _, path, key := m.splitKey(fullPath) 248 249 var result []byte 250 err := m.client.QueryRow(m.get_query, path, key).Scan(&result) 251 if err == sql.ErrNoRows { 252 return nil, nil 253 } 254 if err != nil { 255 return nil, err 256 } 257 258 ent := &physical.Entry{ 259 Key: fullPath, 260 Value: result, 261 } 262 return ent, nil 263} 264 265// Delete is used to permanently delete an entry 266func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error { 267 defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now()) 268 269 m.permitPool.Acquire() 270 defer m.permitPool.Release() 271 272 _, path, key := m.splitKey(fullPath) 273 274 _, err := m.client.Exec(m.delete_query, path, key) 275 if err != nil { 276 return err 277 } 278 return nil 279} 280 281// List is used to list all the keys under a given 282// prefix, up to the next prefix. 
283func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) { 284 defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now()) 285 286 m.permitPool.Acquire() 287 defer m.permitPool.Release() 288 289 rows, err := m.client.Query(m.list_query, "/"+prefix) 290 if err != nil { 291 return nil, err 292 } 293 defer rows.Close() 294 295 var keys []string 296 for rows.Next() { 297 var key string 298 err = rows.Scan(&key) 299 if err != nil { 300 return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err) 301 } 302 303 keys = append(keys, key) 304 } 305 306 return keys, nil 307} 308 309// LockWith is used for mutual exclusion based on the given key. 310func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) { 311 identity, err := uuid.GenerateUUID() 312 if err != nil { 313 return nil, err 314 } 315 return &PostgreSQLLock{ 316 backend: p, 317 key: key, 318 value: value, 319 identity: identity, 320 ttlSeconds: PostgreSQLLockTTLSeconds, 321 renewInterval: PostgreSQLLockRenewInterval, 322 retryInterval: PostgreSQLLockRetryInterval, 323 }, nil 324} 325 326func (p *PostgreSQLBackend) HAEnabled() bool { 327 return p.haEnabled 328} 329 330// Lock tries to acquire the lock by repeatedly trying to create a record in the 331// PostgreSQL table. It will block until either the stop channel is closed or 332// the lock could be acquired successfully. The returned channel will be closed 333// once the lock in the PostgreSQL table cannot be renewed, either due to an 334// error speaking to PostgreSQL or because someone else has taken it. 
335func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { 336 l.lock.Lock() 337 defer l.lock.Unlock() 338 339 var ( 340 success = make(chan struct{}) 341 errors = make(chan error) 342 leader = make(chan struct{}) 343 ) 344 // try to acquire the lock asynchronously 345 go l.tryToLock(stopCh, success, errors) 346 347 select { 348 case <-success: 349 // after acquiring it successfully, we must renew the lock periodically 350 l.renewTicker = time.NewTicker(l.renewInterval) 351 go l.periodicallyRenewLock(leader) 352 case err := <-errors: 353 return nil, err 354 case <-stopCh: 355 return nil, nil 356 } 357 358 return leader, nil 359} 360 361// Unlock releases the lock by deleting the lock record from the 362// PostgreSQL table. 363func (l *PostgreSQLLock) Unlock() error { 364 pg := l.backend 365 pg.permitPool.Acquire() 366 defer pg.permitPool.Release() 367 368 if l.renewTicker != nil { 369 l.renewTicker.Stop() 370 } 371 372 // Delete lock owned by me 373 _, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key) 374 return err 375} 376 377// Value checks whether or not the lock is held by any instance of PostgreSQLLock, 378// including this one, and returns the current value. 379func (l *PostgreSQLLock) Value() (bool, string, error) { 380 pg := l.backend 381 pg.permitPool.Acquire() 382 defer pg.permitPool.Release() 383 var result string 384 err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result) 385 386 switch err { 387 case nil: 388 return true, result, nil 389 case sql.ErrNoRows: 390 return false, "", nil 391 default: 392 return false, "", err 393 394 } 395} 396 397// tryToLock tries to create a new item in PostgreSQL every `retryInterval`. 398// As long as the item cannot be created (because it already exists), it will 399// be retried. If the operation fails due to an error, it is sent to the errors 400// channel. When the lock could be acquired successfully, the success channel 401// is closed. 
402func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) { 403 ticker := time.NewTicker(l.retryInterval) 404 defer ticker.Stop() 405 406 for { 407 select { 408 case <-stop: 409 return 410 case <-ticker.C: 411 gotlock, err := l.writeItem() 412 switch { 413 case err != nil: 414 errors <- err 415 return 416 case gotlock: 417 close(success) 418 return 419 } 420 } 421 } 422} 423 424func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) { 425 for range l.renewTicker.C { 426 gotlock, err := l.writeItem() 427 if err != nil || !gotlock { 428 close(done) 429 l.renewTicker.Stop() 430 return 431 } 432 } 433} 434 435// Attempts to put/update the PostgreSQL item using condition expressions to 436// evaluate the TTL. Returns true if the lock was obtained, false if not. 437// If false error may be nil or non-nil: nil indicates simply that someone 438// else has the lock, whereas non-nil means that something unexpected happened. 439func (l *PostgreSQLLock) writeItem() (bool, error) { 440 pg := l.backend 441 pg.permitPool.Acquire() 442 defer pg.permitPool.Release() 443 444 // Try steal lock or update expiry on my lock 445 446 sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds) 447 if err != nil { 448 return false, err 449 } 450 if sqlResult == nil { 451 return false, fmt.Errorf("empty SQL response received") 452 } 453 454 ar, err := sqlResult.RowsAffected() 455 if err != nil { 456 return false, err 457 } 458 return ar == 1, nil 459} 460