1package mssql
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"sort"
8	"strconv"
9	"strings"
10	"time"
11
12	metrics "github.com/armon/go-metrics"
13	_ "github.com/denisenkom/go-mssqldb"
14	"github.com/hashicorp/errwrap"
15	log "github.com/hashicorp/go-hclog"
16	"github.com/hashicorp/vault/sdk/helper/strutil"
17	"github.com/hashicorp/vault/sdk/physical"
18)
19
// Compile-time assertion that *MSSQLBackend implements physical.Backend; the
// build fails on this line if a required method is missing.
var _ physical.Backend = (*MSSQLBackend)(nil)
22
// MSSQLBackend is a physical.Backend implementation that keeps entries in a
// single Microsoft SQL Server table and accesses it through statements
// prepared once at construction time.
type MSSQLBackend struct {
	// dbTable is the table reference in "database.schema.table" form.
	dbTable    string
	// client is the connection pool opened with the "mssql" driver.
	client     *sql.DB
	// statements holds the prepared statements, keyed by operation name
	// ("put", "get", "delete", "list").
	statements map[string]*sql.Stmt
	logger     log.Logger
	// permitPool is acquired around every storage operation; it is sized from
	// the same max_parallel value that caps the open connections.
	permitPool *physical.PermitPool
}
30
31func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
32	username, ok := conf["username"]
33	if !ok {
34		username = ""
35	}
36
37	password, ok := conf["password"]
38	if !ok {
39		password = ""
40	}
41
42	server, ok := conf["server"]
43	if !ok || server == "" {
44		return nil, fmt.Errorf("missing server")
45	}
46
47	port, ok := conf["port"]
48	if !ok {
49		port = ""
50	}
51
52	maxParStr, ok := conf["max_parallel"]
53	var maxParInt int
54	var err error
55	if ok {
56		maxParInt, err = strconv.Atoi(maxParStr)
57		if err != nil {
58			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
59		}
60		if logger.IsDebug() {
61			logger.Debug("max_parallel set", "max_parallel", maxParInt)
62		}
63	} else {
64		maxParInt = physical.DefaultParallelOperations
65	}
66
67	database, ok := conf["database"]
68	if !ok {
69		database = "Vault"
70	}
71
72	table, ok := conf["table"]
73	if !ok {
74		table = "Vault"
75	}
76
77	appname, ok := conf["appname"]
78	if !ok {
79		appname = "Vault"
80	}
81
82	connectionTimeout, ok := conf["connectiontimeout"]
83	if !ok {
84		connectionTimeout = "30"
85	}
86
87	logLevel, ok := conf["loglevel"]
88	if !ok {
89		logLevel = "0"
90	}
91
92	schema, ok := conf["schema"]
93	if !ok || schema == "" {
94		schema = "dbo"
95	}
96
97	connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
98	if username != "" {
99		connectionString += ";user id=" + username
100	}
101
102	if password != "" {
103		connectionString += ";password=" + password
104	}
105
106	if port != "" {
107		connectionString += ";port=" + port
108	}
109
110	db, err := sql.Open("mssql", connectionString)
111	if err != nil {
112		return nil, errwrap.Wrapf("failed to connect to mssql: {{err}}", err)
113	}
114
115	db.SetMaxOpenConns(maxParInt)
116
117	if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil {
118		return nil, errwrap.Wrapf("failed to create mssql database: {{err}}", err)
119	}
120
121	dbTable := database + "." + schema + "." + table
122	createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema +
123		"') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
124
125	if schema != "dbo" {
126
127		var num int
128		err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num)
129
130		switch {
131		case err == sql.ErrNoRows:
132			if _, err := db.Exec("USE " + database + "; EXEC ('CREATE SCHEMA " + schema + "')"); err != nil {
133				return nil, errwrap.Wrapf("failed to create mssql schema: {{err}}", err)
134			}
135
136		case err != nil:
137			return nil, errwrap.Wrapf("failed to check if mssql schema exists: {{err}}", err)
138		}
139	}
140
141	if _, err := db.Exec(createQuery); err != nil {
142		return nil, errwrap.Wrapf("failed to create mssql table: {{err}}", err)
143	}
144
145	m := &MSSQLBackend{
146		dbTable:    dbTable,
147		client:     db,
148		statements: make(map[string]*sql.Stmt),
149		logger:     logger,
150		permitPool: physical.NewPermitPool(maxParInt),
151	}
152
153	statements := map[string]string{
154		"put": "IF EXISTS(SELECT 1 FROM " + dbTable + " WHERE Path = ?) UPDATE " + dbTable + " SET Value = ? WHERE Path = ?" +
155			" ELSE INSERT INTO " + dbTable + " VALUES(?, ?)",
156		"get":    "SELECT Value FROM " + dbTable + " WHERE Path = ?",
157		"delete": "DELETE FROM " + dbTable + " WHERE Path = ?",
158		"list":   "SELECT Path FROM " + dbTable + " WHERE Path LIKE ?",
159	}
160
161	for name, query := range statements {
162		if err := m.prepare(name, query); err != nil {
163			return nil, err
164		}
165	}
166
167	return m, nil
168}
169
170func (m *MSSQLBackend) prepare(name, query string) error {
171	stmt, err := m.client.Prepare(query)
172	if err != nil {
173		return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err)
174	}
175
176	m.statements[name] = stmt
177
178	return nil
179}
180
181func (m *MSSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
182	defer metrics.MeasureSince([]string{"mssql", "put"}, time.Now())
183
184	m.permitPool.Acquire()
185	defer m.permitPool.Release()
186
187	_, err := m.statements["put"].Exec(entry.Key, entry.Value, entry.Key, entry.Key, entry.Value)
188	if err != nil {
189		return err
190	}
191
192	return nil
193}
194
195func (m *MSSQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
196	defer metrics.MeasureSince([]string{"mssql", "get"}, time.Now())
197
198	m.permitPool.Acquire()
199	defer m.permitPool.Release()
200
201	var result []byte
202	err := m.statements["get"].QueryRow(key).Scan(&result)
203	if err == sql.ErrNoRows {
204		return nil, nil
205	}
206
207	if err != nil {
208		return nil, err
209	}
210
211	ent := &physical.Entry{
212		Key:   key,
213		Value: result,
214	}
215
216	return ent, nil
217}
218
219func (m *MSSQLBackend) Delete(ctx context.Context, key string) error {
220	defer metrics.MeasureSince([]string{"mssql", "delete"}, time.Now())
221
222	m.permitPool.Acquire()
223	defer m.permitPool.Release()
224
225	_, err := m.statements["delete"].Exec(key)
226	if err != nil {
227		return err
228	}
229
230	return nil
231}
232
233func (m *MSSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
234	defer metrics.MeasureSince([]string{"mssql", "list"}, time.Now())
235
236	m.permitPool.Acquire()
237	defer m.permitPool.Release()
238
239	likePrefix := prefix + "%"
240	rows, err := m.statements["list"].Query(likePrefix)
241	if err != nil {
242		return nil, err
243	}
244	var keys []string
245	for rows.Next() {
246		var key string
247		err = rows.Scan(&key)
248		if err != nil {
249			return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err)
250		}
251
252		key = strings.TrimPrefix(key, prefix)
253		if i := strings.Index(key, "/"); i == -1 {
254			keys = append(keys, key)
255		} else if i != -1 {
256			keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
257		}
258	}
259
260	sort.Strings(keys)
261
262	return keys, nil
263}
264