1package postgresql 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "strconv" 8 "strings" 9 "time" 10 11 "github.com/hashicorp/errwrap" 12 "github.com/hashicorp/vault/physical" 13 14 log "github.com/hashicorp/go-hclog" 15 16 metrics "github.com/armon/go-metrics" 17 "github.com/lib/pq" 18) 19 20// Verify PostgreSQLBackend satisfies the correct interfaces 21var _ physical.Backend = (*PostgreSQLBackend)(nil) 22 23// PostgreSQL Backend is a physical backend that stores data 24// within a PostgreSQL database. 25type PostgreSQLBackend struct { 26 table string 27 client *sql.DB 28 put_query string 29 get_query string 30 delete_query string 31 list_query string 32 logger log.Logger 33 permitPool *physical.PermitPool 34} 35 36// NewPostgreSQLBackend constructs a PostgreSQL backend using the given 37// API client, server address, credentials, and database. 38func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 39 // Get the PostgreSQL credentials to perform read/write operations. 40 connURL, ok := conf["connection_url"] 41 if !ok || connURL == "" { 42 return nil, fmt.Errorf("missing connection_url") 43 } 44 45 unquoted_table, ok := conf["table"] 46 if !ok { 47 unquoted_table = "vault_kv_store" 48 } 49 quoted_table := pq.QuoteIdentifier(unquoted_table) 50 51 maxParStr, ok := conf["max_parallel"] 52 var maxParInt int 53 var err error 54 if ok { 55 maxParInt, err = strconv.Atoi(maxParStr) 56 if err != nil { 57 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 58 } 59 if logger.IsDebug() { 60 logger.Debug("max_parallel set", "max_parallel", maxParInt) 61 } 62 } else { 63 maxParInt = physical.DefaultParallelOperations 64 } 65 66 // Create PostgreSQL handle for the database. 67 db, err := sql.Open("postgres", connURL) 68 if err != nil { 69 return nil, errwrap.Wrapf("failed to connect to postgres: {{err}}", err) 70 } 71 db.SetMaxOpenConns(maxParInt) 72 73 // Determine if we should use an upsert function (versions < 9.5) 74 var upsert_required bool 75 upsert_required_query := "SELECT current_setting('server_version_num')::int < 90500" 76 if err := db.QueryRow(upsert_required_query).Scan(&upsert_required); err != nil { 77 return nil, errwrap.Wrapf("failed to check for native upsert: {{err}}", err) 78 } 79 80 // Setup our put strategy based on the presence or absence of a native 81 // upsert. 82 var put_query string 83 if upsert_required { 84 put_query = "SELECT vault_kv_put($1, $2, $3, $4)" 85 } else { 86 put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + 87 " ON CONFLICT (path, key) DO " + 88 " UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)" 89 } 90 91 // Setup the backend. 92 m := &PostgreSQLBackend{ 93 table: quoted_table, 94 client: db, 95 put_query: put_query, 96 get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", 97 delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", 98 list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" + 99 "UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + 100 quoted_table + " WHERE parent_path LIKE $1 || '%'", 101 logger: logger, 102 permitPool: physical.NewPermitPool(maxParInt), 103 } 104 105 return m, nil 106} 107 108// splitKey is a helper to split a full path key into individual 109// parts: parentPath, path, key 110func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { 111 var parentPath string 112 var path string 113 114 pieces := strings.Split(fullPath, "/") 115 depth := len(pieces) 116 key := pieces[depth-1] 117 118 if depth == 1 { 119 parentPath = "" 120 path = "/" 121 } else if depth == 2 { 122 parentPath = "/" 123 path = "/" + pieces[0] + "/" 124 } else { 125 parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/" 126 path = "/" + strings.Join(pieces[:depth-1], "/") + "/" 127 } 128 129 return parentPath, path, key 130} 131 132// Put is used to insert or update an entry. 133func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error { 134 defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now()) 135 136 m.permitPool.Acquire() 137 defer m.permitPool.Release() 138 139 parentPath, path, key := m.splitKey(entry.Key) 140 141 _, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value) 142 if err != nil { 143 return err 144 } 145 return nil 146} 147 148// Get is used to fetch and entry. 149func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) { 150 defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now()) 151 152 m.permitPool.Acquire() 153 defer m.permitPool.Release() 154 155 _, path, key := m.splitKey(fullPath) 156 157 var result []byte 158 err := m.client.QueryRow(m.get_query, path, key).Scan(&result) 159 if err == sql.ErrNoRows { 160 return nil, nil 161 } 162 if err != nil { 163 return nil, err 164 } 165 166 ent := &physical.Entry{ 167 Key: fullPath, 168 Value: result, 169 } 170 return ent, nil 171} 172 173// Delete is used to permanently delete an entry 174func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error { 175 defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now()) 176 177 m.permitPool.Acquire() 178 defer m.permitPool.Release() 179 180 _, path, key := m.splitKey(fullPath) 181 182 _, err := m.client.Exec(m.delete_query, path, key) 183 if err != nil { 184 return err 185 } 186 return nil 187} 188 189// List is used to list all the keys under a given 190// prefix, up to the next prefix. 191func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) { 192 defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now()) 193 194 m.permitPool.Acquire() 195 defer m.permitPool.Release() 196 197 rows, err := m.client.Query(m.list_query, "/"+prefix) 198 if err != nil { 199 return nil, err 200 } 201 defer rows.Close() 202 203 var keys []string 204 for rows.Next() { 205 var key string 206 err = rows.Scan(&key) 207 if err != nil { 208 return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err) 209 } 210 211 keys = append(keys, key) 212 } 213 214 return keys, nil 215} 216