package postgresql

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/hashicorp/vault/sdk/physical"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-uuid"

	"github.com/armon/go-metrics"
	"github.com/lib/pq"
)

const (
	// PostgreSQLLockTTLSeconds is how long an HA lock row stays valid, in
	// seconds. It matches the default that the Consul API uses (15 seconds).
	// It is passed to SQL statements that set/extend the lock expiry relative
	// to the database clock (NOW()), not the Vault node's clock.
	PostgreSQLLockTTLSeconds = 15

	// PostgreSQLLockRenewInterval is the amount of time to wait between lock
	// renewals. It should be well below PostgreSQLLockTTLSeconds so a held
	// lock is renewed before it expires.
	PostgreSQLLockRenewInterval = 5 * time.Second

	// PostgreSQLLockRetryInterval is the amount of time to wait
	// if a lock fails before trying again.
	PostgreSQLLockRetryInterval = time.Second
)

// Verify PostgreSQLBackend satisfies the correct interfaces
var _ physical.Backend = (*PostgreSQLBackend)(nil)

// The HA backend was implemented based on the DynamoDB backend pattern, with
// the distinction that lock expiry is computed from the central PostgreSQL
// clock, thereby avoiding possible issues with multiple (skewed) node clocks.
var (
	_ physical.HABackend = (*PostgreSQLBackend)(nil)
	_ physical.Lock      = (*PostgreSQLLock)(nil)
)

// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
	// table is the already-quoted (pq.QuoteIdentifier) key/value table name.
	table string
	// client is the pooled database handle shared by all operations.
	client *sql.DB
	// put_query upserts a (parent_path, path, key, value) row, either with
	// INSERT ... ON CONFLICT or via the vault_kv_put() function on servers
	// older than 9.5.
	put_query string
	// get_query selects the value stored at (path, key).
	get_query string
	// delete_query removes the row at (path, key).
	delete_query string
	// list_query returns the keys directly under a path plus the distinct
	// next-level "sub/" folders below it.
	list_query string

	// ha_table is not populated by NewPostgreSQLBackend; the quoted HA table
	// name is baked into the ha* query strings below instead.
	ha_table string
	// haGetLockValueQuery reads the value of a non-expired lock by key.
	haGetLockValueQuery string
	// haUpsertLockIdentityExec takes an expired lock or extends a lock
	// already owned by the given identity.
	haUpsertLockIdentityExec string
	// haDeleteLockExec removes the lock row owned by the given identity.
	haDeleteLockExec string

	// haEnabled is true when the config sets ha_enabled to "true".
	haEnabled bool
	logger    log.Logger
	// permitPool bounds the number of concurrent database operations; it is
	// sized from max_parallel.
	permitPool *physical.PermitPool
}

// PostgreSQLLock implements a lock using a PostgreSQL client.
71type PostgreSQLLock struct { 72 backend *PostgreSQLBackend 73 value, key string 74 identity string 75 lock sync.Mutex 76 77 renewTicker *time.Ticker 78 79 // ttlSeconds is how long a lock is valid for 80 ttlSeconds int 81 82 // renewInterval is how much time to wait between lock renewals. must be << ttl 83 renewInterval time.Duration 84 85 // retryInterval is how much time to wait between attempts to grab the lock 86 retryInterval time.Duration 87} 88 89// NewPostgreSQLBackend constructs a PostgreSQL backend using the given 90// API client, server address, credentials, and database. 91func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 92 // Get the PostgreSQL credentials to perform read/write operations. 93 connURL := connectionURL(conf) 94 if connURL == "" { 95 return nil, fmt.Errorf("missing connection_url") 96 } 97 98 unquoted_table, ok := conf["table"] 99 if !ok { 100 unquoted_table = "vault_kv_store" 101 } 102 quoted_table := pq.QuoteIdentifier(unquoted_table) 103 104 maxParStr, ok := conf["max_parallel"] 105 var maxParInt int 106 var err error 107 if ok { 108 maxParInt, err = strconv.Atoi(maxParStr) 109 if err != nil { 110 return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err) 111 } 112 if logger.IsDebug() { 113 logger.Debug("max_parallel set", "max_parallel", maxParInt) 114 } 115 } else { 116 maxParInt = physical.DefaultParallelOperations 117 } 118 119 maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"] 120 var maxIdleConns int 121 if maxIdleConnsIsSet { 122 maxIdleConns, err = strconv.Atoi(maxIdleConnsStr) 123 if err != nil { 124 return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err) 125 } 126 if logger.IsDebug() { 127 logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr) 128 } 129 } 130 131 // Create PostgreSQL handle for the database. 
132 db, err := sql.Open("postgres", connURL) 133 if err != nil { 134 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 135 } 136 db.SetMaxOpenConns(maxParInt) 137 138 if maxIdleConnsIsSet { 139 db.SetMaxIdleConns(maxIdleConns) 140 } 141 142 // Determine if we should use a function to work around lack of upsert (versions < 9.5) 143 var upsertAvailable bool 144 upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500" 145 if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil { 146 return nil, fmt.Errorf("failed to check for native upsert: %w", err) 147 } 148 149 if !upsertAvailable && conf["ha_enabled"] == "true" { 150 return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5") 151 } 152 153 // Setup our put strategy based on the presence or absence of a native 154 // upsert. 155 var put_query string 156 if !upsertAvailable { 157 put_query = "SELECT vault_kv_put($1, $2, $3, $4)" 158 } else { 159 put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + 160 " ON CONFLICT (path, key) DO " + 161 " UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)" 162 } 163 164 unquoted_ha_table, ok := conf["ha_table"] 165 if !ok { 166 unquoted_ha_table = "vault_ha_locks" 167 } 168 quoted_ha_table := pq.QuoteIdentifier(unquoted_ha_table) 169 170 // Setup the backend. 
171 m := &PostgreSQLBackend{ 172 table: quoted_table, 173 client: db, 174 put_query: put_query, 175 get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", 176 delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", 177 list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" + 178 " UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table + 179 " WHERE parent_path LIKE $1 || '%'", 180 haGetLockValueQuery: 181 // only read non expired data 182 " SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ", 183 haUpsertLockIdentityExec: 184 // $1=identity $2=ha_key $3=ha_value $4=TTL in seconds 185 // update either steal expired lock OR update expiry for lock owned by me 186 " INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " + 187 " ON CONFLICT (ha_key) DO " + 188 " UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " + 189 " WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " + 190 " (t.ha_identity = $1 AND t.ha_key = $2) ", 191 haDeleteLockExec: 192 // $1=ha_identity $2=ha_key 193 " DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ", 194 logger: logger, 195 permitPool: physical.NewPermitPool(maxParInt), 196 haEnabled: conf["ha_enabled"] == "true", 197 } 198 199 return m, nil 200} 201 202// connectionURL first check the environment variables for a connection URL. If 203// no connection URL exists in the environment variable, the Vault config file is 204// checked. If neither the environment variables or the config file set the connection 205// URL for the Postgres backend, because it is a required field, an error is returned. 
206func connectionURL(conf map[string]string) string { 207 connURL := conf["connection_url"] 208 if envURL := os.Getenv("VAULT_PG_CONNECTION_URL"); envURL != "" { 209 connURL = envURL 210 } 211 212 return connURL 213} 214 215// splitKey is a helper to split a full path key into individual 216// parts: parentPath, path, key 217func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { 218 var parentPath string 219 var path string 220 221 pieces := strings.Split(fullPath, "/") 222 depth := len(pieces) 223 key := pieces[depth-1] 224 225 if depth == 1 { 226 parentPath = "" 227 path = "/" 228 } else if depth == 2 { 229 parentPath = "/" 230 path = "/" + pieces[0] + "/" 231 } else { 232 parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/" 233 path = "/" + strings.Join(pieces[:depth-1], "/") + "/" 234 } 235 236 return parentPath, path, key 237} 238 239// Put is used to insert or update an entry. 240func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error { 241 defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now()) 242 243 m.permitPool.Acquire() 244 defer m.permitPool.Release() 245 246 parentPath, path, key := m.splitKey(entry.Key) 247 248 _, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value) 249 if err != nil { 250 return err 251 } 252 return nil 253} 254 255// Get is used to fetch and entry. 
256func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) { 257 defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now()) 258 259 m.permitPool.Acquire() 260 defer m.permitPool.Release() 261 262 _, path, key := m.splitKey(fullPath) 263 264 var result []byte 265 err := m.client.QueryRow(m.get_query, path, key).Scan(&result) 266 if err == sql.ErrNoRows { 267 return nil, nil 268 } 269 if err != nil { 270 return nil, err 271 } 272 273 ent := &physical.Entry{ 274 Key: fullPath, 275 Value: result, 276 } 277 return ent, nil 278} 279 280// Delete is used to permanently delete an entry 281func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error { 282 defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now()) 283 284 m.permitPool.Acquire() 285 defer m.permitPool.Release() 286 287 _, path, key := m.splitKey(fullPath) 288 289 _, err := m.client.Exec(m.delete_query, path, key) 290 if err != nil { 291 return err 292 } 293 return nil 294} 295 296// List is used to list all the keys under a given 297// prefix, up to the next prefix. 298func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) { 299 defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now()) 300 301 m.permitPool.Acquire() 302 defer m.permitPool.Release() 303 304 rows, err := m.client.Query(m.list_query, "/"+prefix) 305 if err != nil { 306 return nil, err 307 } 308 defer rows.Close() 309 310 var keys []string 311 for rows.Next() { 312 var key string 313 err = rows.Scan(&key) 314 if err != nil { 315 return nil, fmt.Errorf("failed to scan rows: %w", err) 316 } 317 318 keys = append(keys, key) 319 } 320 321 return keys, nil 322} 323 324// LockWith is used for mutual exclusion based on the given key. 
325func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) { 326 identity, err := uuid.GenerateUUID() 327 if err != nil { 328 return nil, err 329 } 330 return &PostgreSQLLock{ 331 backend: p, 332 key: key, 333 value: value, 334 identity: identity, 335 ttlSeconds: PostgreSQLLockTTLSeconds, 336 renewInterval: PostgreSQLLockRenewInterval, 337 retryInterval: PostgreSQLLockRetryInterval, 338 }, nil 339} 340 341func (p *PostgreSQLBackend) HAEnabled() bool { 342 return p.haEnabled 343} 344 345// Lock tries to acquire the lock by repeatedly trying to create a record in the 346// PostgreSQL table. It will block until either the stop channel is closed or 347// the lock could be acquired successfully. The returned channel will be closed 348// once the lock in the PostgreSQL table cannot be renewed, either due to an 349// error speaking to PostgreSQL or because someone else has taken it. 350func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { 351 l.lock.Lock() 352 defer l.lock.Unlock() 353 354 var ( 355 success = make(chan struct{}) 356 errors = make(chan error) 357 leader = make(chan struct{}) 358 ) 359 // try to acquire the lock asynchronously 360 go l.tryToLock(stopCh, success, errors) 361 362 select { 363 case <-success: 364 // after acquiring it successfully, we must renew the lock periodically 365 l.renewTicker = time.NewTicker(l.renewInterval) 366 go l.periodicallyRenewLock(leader) 367 case err := <-errors: 368 return nil, err 369 case <-stopCh: 370 return nil, nil 371 } 372 373 return leader, nil 374} 375 376// Unlock releases the lock by deleting the lock record from the 377// PostgreSQL table. 
378func (l *PostgreSQLLock) Unlock() error { 379 pg := l.backend 380 pg.permitPool.Acquire() 381 defer pg.permitPool.Release() 382 383 if l.renewTicker != nil { 384 l.renewTicker.Stop() 385 } 386 387 // Delete lock owned by me 388 _, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key) 389 return err 390} 391 392// Value checks whether or not the lock is held by any instance of PostgreSQLLock, 393// including this one, and returns the current value. 394func (l *PostgreSQLLock) Value() (bool, string, error) { 395 pg := l.backend 396 pg.permitPool.Acquire() 397 defer pg.permitPool.Release() 398 var result string 399 err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result) 400 401 switch err { 402 case nil: 403 return true, result, nil 404 case sql.ErrNoRows: 405 return false, "", nil 406 default: 407 return false, "", err 408 409 } 410} 411 412// tryToLock tries to create a new item in PostgreSQL every `retryInterval`. 413// As long as the item cannot be created (because it already exists), it will 414// be retried. If the operation fails due to an error, it is sent to the errors 415// channel. When the lock could be acquired successfully, the success channel 416// is closed. 417func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) { 418 ticker := time.NewTicker(l.retryInterval) 419 defer ticker.Stop() 420 421 for { 422 select { 423 case <-stop: 424 return 425 case <-ticker.C: 426 gotlock, err := l.writeItem() 427 switch { 428 case err != nil: 429 errors <- err 430 return 431 case gotlock: 432 close(success) 433 return 434 } 435 } 436 } 437} 438 439func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) { 440 for range l.renewTicker.C { 441 gotlock, err := l.writeItem() 442 if err != nil || !gotlock { 443 close(done) 444 l.renewTicker.Stop() 445 return 446 } 447 } 448} 449 450// Attempts to put/update the PostgreSQL item using condition expressions to 451// evaluate the TTL. 
Returns true if the lock was obtained, false if not. 452// If false error may be nil or non-nil: nil indicates simply that someone 453// else has the lock, whereas non-nil means that something unexpected happened. 454func (l *PostgreSQLLock) writeItem() (bool, error) { 455 pg := l.backend 456 pg.permitPool.Acquire() 457 defer pg.permitPool.Release() 458 459 // Try steal lock or update expiry on my lock 460 461 sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds) 462 if err != nil { 463 return false, err 464 } 465 if sqlResult == nil { 466 return false, fmt.Errorf("empty SQL response received") 467 } 468 469 ar, err := sqlResult.RowsAffected() 470 if err != nil { 471 return false, err 472 } 473 return ar == 1, nil 474} 475