1package postgresql
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"strconv"
8	"strings"
9	"time"
10
11	"github.com/hashicorp/errwrap"
12	"github.com/hashicorp/vault/physical"
13
14	log "github.com/hashicorp/go-hclog"
15
16	metrics "github.com/armon/go-metrics"
17	"github.com/lib/pq"
18)
19
// Verify PostgreSQLBackend satisfies the correct interfaces
var _ physical.Backend = (*PostgreSQLBackend)(nil)

// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
//
// All SQL statements are built once by NewPostgreSQLBackend, with the table
// name already quoted, and are executed with positional parameters.
type PostgreSQLBackend struct {
	table        string               // table name, quoted via pq.QuoteIdentifier
	client       *sql.DB              // database handle from sql.Open
	put_query    string               // native upsert or vault_kv_put call, chosen by server version
	get_query    string               // selects value by (path, key)
	delete_query string               // deletes the row identified by (path, key)
	list_query   string               // keys at a path plus the names of its child directories
	logger       log.Logger           // backend logger
	permitPool   *physical.PermitPool // bounds concurrent operations; sized by max_parallel
}
35
36// NewPostgreSQLBackend constructs a PostgreSQL backend using the given
37// API client, server address, credentials, and database.
38func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
39	// Get the PostgreSQL credentials to perform read/write operations.
40	connURL, ok := conf["connection_url"]
41	if !ok || connURL == "" {
42		return nil, fmt.Errorf("missing connection_url")
43	}
44
45	unquoted_table, ok := conf["table"]
46	if !ok {
47		unquoted_table = "vault_kv_store"
48	}
49	quoted_table := pq.QuoteIdentifier(unquoted_table)
50
51	maxParStr, ok := conf["max_parallel"]
52	var maxParInt int
53	var err error
54	if ok {
55		maxParInt, err = strconv.Atoi(maxParStr)
56		if err != nil {
57			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
58		}
59		if logger.IsDebug() {
60			logger.Debug("max_parallel set", "max_parallel", maxParInt)
61		}
62	} else {
63		maxParInt = physical.DefaultParallelOperations
64	}
65
66	// Create PostgreSQL handle for the database.
67	db, err := sql.Open("postgres", connURL)
68	if err != nil {
69		return nil, errwrap.Wrapf("failed to connect to postgres: {{err}}", err)
70	}
71	db.SetMaxOpenConns(maxParInt)
72
73	// Determine if we should use an upsert function (versions < 9.5)
74	var upsert_required bool
75	upsert_required_query := "SELECT current_setting('server_version_num')::int < 90500"
76	if err := db.QueryRow(upsert_required_query).Scan(&upsert_required); err != nil {
77		return nil, errwrap.Wrapf("failed to check for native upsert: {{err}}", err)
78	}
79
80	// Setup our put strategy based on the presence or absence of a native
81	// upsert.
82	var put_query string
83	if upsert_required {
84		put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
85	} else {
86		put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
87			" ON CONFLICT (path, key) DO " +
88			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
89	}
90
91	// Setup the backend.
92	m := &PostgreSQLBackend{
93		table:        quoted_table,
94		client:       db,
95		put_query:    put_query,
96		get_query:    "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
97		delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
98		list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
99			"UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
100			quoted_table + " WHERE parent_path LIKE $1 || '%'",
101		logger:     logger,
102		permitPool: physical.NewPermitPool(maxParInt),
103	}
104
105	return m, nil
106}
107
108// splitKey is a helper to split a full path key into individual
109// parts: parentPath, path, key
110func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
111	var parentPath string
112	var path string
113
114	pieces := strings.Split(fullPath, "/")
115	depth := len(pieces)
116	key := pieces[depth-1]
117
118	if depth == 1 {
119		parentPath = ""
120		path = "/"
121	} else if depth == 2 {
122		parentPath = "/"
123		path = "/" + pieces[0] + "/"
124	} else {
125		parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
126		path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
127	}
128
129	return parentPath, path, key
130}
131
132// Put is used to insert or update an entry.
133func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
134	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
135
136	m.permitPool.Acquire()
137	defer m.permitPool.Release()
138
139	parentPath, path, key := m.splitKey(entry.Key)
140
141	_, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value)
142	if err != nil {
143		return err
144	}
145	return nil
146}
147
148// Get is used to fetch and entry.
149func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
150	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
151
152	m.permitPool.Acquire()
153	defer m.permitPool.Release()
154
155	_, path, key := m.splitKey(fullPath)
156
157	var result []byte
158	err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
159	if err == sql.ErrNoRows {
160		return nil, nil
161	}
162	if err != nil {
163		return nil, err
164	}
165
166	ent := &physical.Entry{
167		Key:   fullPath,
168		Value: result,
169	}
170	return ent, nil
171}
172
173// Delete is used to permanently delete an entry
174func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
175	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
176
177	m.permitPool.Acquire()
178	defer m.permitPool.Release()
179
180	_, path, key := m.splitKey(fullPath)
181
182	_, err := m.client.Exec(m.delete_query, path, key)
183	if err != nil {
184		return err
185	}
186	return nil
187}
188
189// List is used to list all the keys under a given
190// prefix, up to the next prefix.
191func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
192	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
193
194	m.permitPool.Acquire()
195	defer m.permitPool.Release()
196
197	rows, err := m.client.Query(m.list_query, "/"+prefix)
198	if err != nil {
199		return nil, err
200	}
201	defer rows.Close()
202
203	var keys []string
204	for rows.Next() {
205		var key string
206		err = rows.Scan(&key)
207		if err != nil {
208			return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err)
209		}
210
211		keys = append(keys, key)
212	}
213
214	return keys, nil
215}
216