1package mysql
2
3import (
4	"context"
5	"crypto/tls"
6	"crypto/x509"
7	"database/sql"
8	"errors"
9	"fmt"
10	"io/ioutil"
11	"math"
12	"net/url"
13	"sort"
14	"strconv"
15	"strings"
16	"sync"
17	"time"
18
19	log "github.com/hashicorp/go-hclog"
20
21	metrics "github.com/armon/go-metrics"
22	mysql "github.com/go-sql-driver/mysql"
23	"github.com/hashicorp/errwrap"
24	"github.com/hashicorp/vault/sdk/helper/strutil"
25	"github.com/hashicorp/vault/sdk/physical"
26)
27
28// Verify MySQLBackend satisfies the correct interfaces
29var _ physical.Backend = (*MySQLBackend)(nil)
30var _ physical.HABackend = (*MySQLBackend)(nil)
31var _ physical.Lock = (*MySQLHALock)(nil)
32
33// Unreserved tls key
34// Reserved values are "true", "false", "skip-verify"
35const mysqlTLSKey = "default"
36
// MySQLBackend is a physical backend that stores data
// within MySQL database.
//
// Fields:
//   - dbTable / dbLockTable: backtick-quoted `database`.`table` names of the
//     key/value table and of the HA lock table (built in NewMySQLBackend).
//   - client: connection pool shared by all data operations.
//   - statements: prepared statements keyed by name ("put", "get", "delete",
//     "list", plus "get_lock" and "used_lock" when HA is enabled).
//   - permitPool: caps concurrent operations (sized from max_parallel).
//   - conf: the raw configuration; NewMySQLLock reuses it to open a
//     dedicated connection pool per lock.
//   - haEnabled: parsed from the "ha_enabled" configuration key.
//
// NOTE(review): redirectHost and redirectPort are not read or written
// anywhere in this file.
type MySQLBackend struct {
	dbTable      string
	dbLockTable  string
	client       *sql.DB
	statements   map[string]*sql.Stmt
	logger       log.Logger
	permitPool   *physical.PermitPool
	conf         map[string]string
	redirectHost string
	redirectPort int64
	haEnabled    bool
}
51
52// NewMySQLBackend constructs a MySQL backend using the given API client and
53// server address and credential for accessing mysql database.
54func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
55	var err error
56
57	db, err := NewMySQLClient(conf, logger)
58	if err != nil {
59		return nil, err
60	}
61
62	database, ok := conf["database"]
63	if !ok {
64		database = "vault"
65	}
66	table, ok := conf["table"]
67	if !ok {
68		table = "vault"
69	}
70	dbTable := "`" + database + "`.`" + table + "`"
71
72	maxParStr, ok := conf["max_parallel"]
73	var maxParInt int
74	if ok {
75		maxParInt, err = strconv.Atoi(maxParStr)
76		if err != nil {
77			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
78		}
79		if logger.IsDebug() {
80			logger.Debug("max_parallel set", "max_parallel", maxParInt)
81		}
82	} else {
83		maxParInt = physical.DefaultParallelOperations
84	}
85
86	// Check schema exists
87	var schemaExist bool
88	schemaRows, err := db.Query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
89	if err != nil {
90		return nil, errwrap.Wrapf("failed to check mysql schema exist: {{err}}", err)
91	}
92	defer schemaRows.Close()
93	schemaExist = schemaRows.Next()
94
95	// Check table exists
96	var tableExist bool
97	tableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
98
99	if err != nil {
100		return nil, errwrap.Wrapf("failed to check mysql table exist: {{err}}", err)
101	}
102	defer tableRows.Close()
103	tableExist = tableRows.Next()
104
105	// Create the required database if it doesn't exists.
106	if !schemaExist {
107		if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
108			return nil, errwrap.Wrapf("failed to create mysql database: {{err}}", err)
109		}
110	}
111
112	// Create the required table if it doesn't exists.
113	if !tableExist {
114		create_query := "CREATE TABLE IF NOT EXISTS " + dbTable +
115			" (vault_key varbinary(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
116		if _, err := db.Exec(create_query); err != nil {
117			return nil, errwrap.Wrapf("failed to create mysql table: {{err}}", err)
118		}
119	}
120
121	// Default value for ha_enabled
122	haEnabledStr, ok := conf["ha_enabled"]
123	if !ok {
124		haEnabledStr = "false"
125	}
126	haEnabled, err := strconv.ParseBool(haEnabledStr)
127	if err != nil {
128		return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
129	}
130
131	locktable, ok := conf["lock_table"]
132	if !ok {
133		locktable = table + "_lock"
134	}
135
136	dbLockTable := "`" + database + "`.`" + locktable + "`"
137
138	// Only create lock table if ha_enabled is true
139	if haEnabled {
140		// Check table exists
141		var lockTableExist bool
142		lockTableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
143
144		if err != nil {
145			return nil, errwrap.Wrapf("failed to check mysql table exist: {{err}}", err)
146		}
147		defer lockTableRows.Close()
148		lockTableExist = lockTableRows.Next()
149
150		// Create the required table if it doesn't exists.
151		if !lockTableExist {
152			create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
153				" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
154			if _, err := db.Exec(create_query); err != nil {
155				return nil, errwrap.Wrapf("failed to create mysql table: {{err}}", err)
156			}
157		}
158	}
159
160	// Setup the backend.
161	m := &MySQLBackend{
162		dbTable:     dbTable,
163		dbLockTable: dbLockTable,
164		client:      db,
165		statements:  make(map[string]*sql.Stmt),
166		logger:      logger,
167		permitPool:  physical.NewPermitPool(maxParInt),
168		conf:        conf,
169		haEnabled:   haEnabled,
170	}
171
172	// Prepare all the statements required
173	statements := map[string]string{
174		"put": "INSERT INTO " + dbTable +
175			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
176		"get":    "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
177		"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
178		"list":   "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
179	}
180
181	// Only prepare ha-related statements if we need them
182	if haEnabled {
183		statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
184		statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
185	}
186
187	for name, query := range statements {
188		if err := m.prepare(name, query); err != nil {
189			return nil, err
190		}
191	}
192
193	return m, nil
194}
195
196func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
197	var err error
198
199	// Get the MySQL credentials to perform read/write operations.
200	username, ok := conf["username"]
201	if !ok || username == "" {
202		return nil, fmt.Errorf("missing username")
203	}
204	password, ok := conf["password"]
205	if !ok || password == "" {
206		return nil, fmt.Errorf("missing password")
207	}
208
209	// Get or set MySQL server address. Defaults to localhost and default port(3306)
210	address, ok := conf["address"]
211	if !ok {
212		address = "127.0.0.1:3306"
213	}
214
215	maxIdleConnStr, ok := conf["max_idle_connections"]
216	var maxIdleConnInt int
217	if ok {
218		maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr)
219		if err != nil {
220			return nil, errwrap.Wrapf("failed parsing max_idle_connections parameter: {{err}}", err)
221		}
222		if logger.IsDebug() {
223			logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt)
224		}
225	}
226
227	maxConnLifeStr, ok := conf["max_connection_lifetime"]
228	var maxConnLifeInt int
229	if ok {
230		maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr)
231		if err != nil {
232			return nil, errwrap.Wrapf("failed parsing max_connection_lifetime parameter: {{err}}", err)
233		}
234		if logger.IsDebug() {
235			logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt)
236		}
237	}
238
239	maxParStr, ok := conf["max_parallel"]
240	var maxParInt int
241	if ok {
242		maxParInt, err = strconv.Atoi(maxParStr)
243		if err != nil {
244			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
245		}
246		if logger.IsDebug() {
247			logger.Debug("max_parallel set", "max_parallel", maxParInt)
248		}
249	} else {
250		maxParInt = physical.DefaultParallelOperations
251	}
252
253	dsnParams := url.Values{}
254	tlsCaFile, ok := conf["tls_ca_file"]
255	if ok {
256		if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
257			return nil, errwrap.Wrapf("failed register TLS config: {{err}}", err)
258		}
259
260		dsnParams.Add("tls", mysqlTLSKey)
261	}
262
263	// Create MySQL handle for the database.
264	dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
265	db, err := sql.Open("mysql", dsn)
266	if err != nil {
267		return nil, errwrap.Wrapf("failed to connect to mysql: {{err}}", err)
268	}
269	db.SetMaxOpenConns(maxParInt)
270	if maxIdleConnInt != 0 {
271		db.SetMaxIdleConns(maxIdleConnInt)
272	}
273	if maxConnLifeInt != 0 {
274		db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second)
275	}
276
277	return db, err
278}
279
280// prepare is a helper to prepare a query for future execution
281func (m *MySQLBackend) prepare(name, query string) error {
282	stmt, err := m.client.Prepare(query)
283	if err != nil {
284		return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err)
285	}
286	m.statements[name] = stmt
287	return nil
288}
289
290// Put is used to insert or update an entry.
291func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
292	defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
293
294	m.permitPool.Acquire()
295	defer m.permitPool.Release()
296
297	_, err := m.statements["put"].Exec(entry.Key, entry.Value)
298	if err != nil {
299		return err
300	}
301	return nil
302}
303
304// Get is used to fetch an entry.
305func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
306	defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
307
308	m.permitPool.Acquire()
309	defer m.permitPool.Release()
310
311	var result []byte
312	err := m.statements["get"].QueryRow(key).Scan(&result)
313	if err == sql.ErrNoRows {
314		return nil, nil
315	}
316	if err != nil {
317		return nil, err
318	}
319
320	ent := &physical.Entry{
321		Key:   key,
322		Value: result,
323	}
324	return ent, nil
325}
326
327// Delete is used to permanently delete an entry
328func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
329	defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
330
331	m.permitPool.Acquire()
332	defer m.permitPool.Release()
333
334	_, err := m.statements["delete"].Exec(key)
335	if err != nil {
336		return err
337	}
338	return nil
339}
340
341// List is used to list all the keys under a given
342// prefix, up to the next prefix.
343func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
344	defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
345
346	m.permitPool.Acquire()
347	defer m.permitPool.Release()
348
349	// Add the % wildcard to the prefix to do the prefix search
350	likePrefix := prefix + "%"
351	rows, err := m.statements["list"].Query(likePrefix)
352	if err != nil {
353		return nil, errwrap.Wrapf("failed to execute statement: {{err}}", err)
354	}
355
356	var keys []string
357	for rows.Next() {
358		var key string
359		err = rows.Scan(&key)
360		if err != nil {
361			return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err)
362		}
363
364		key = strings.TrimPrefix(key, prefix)
365		if i := strings.Index(key, "/"); i == -1 {
366			// Add objects only from the current 'folder'
367			keys = append(keys, key)
368		} else if i != -1 {
369			// Add truncated 'folder' paths
370			keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
371		}
372	}
373
374	sort.Strings(keys)
375	return keys, nil
376}
377
378// LockWith is used for mutual exclusion based on the given key.
379func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
380	l := &MySQLHALock{
381		in:     m,
382		key:    key,
383		value:  value,
384		logger: m.logger,
385	}
386	return l, nil
387}
388
389func (m *MySQLBackend) HAEnabled() bool {
390	return m.haEnabled
391}
392
// MySQLHALock is a MySQL Lock implementation for the HABackend
//
// in is the owning backend, whose "used_lock" and "get_lock" statements are
// used for ownership checks and leader lookups. key names the MySQL named
// lock, and value is what gets recorded as leader once it is acquired.
// localLock serialises Lock and Unlock, which read and write held. leaderCh
// is the channel handed out by the most recent successful Lock. lock is the
// MySQLLock (with its own connection pool) created by the latest acquisition
// attempt.
// NOTE(review): stopCh is assigned by Lock but never read in this file.
type MySQLHALock struct {
	in     *MySQLBackend
	key    string
	value  string
	logger log.Logger

	held      bool
	localLock sync.Mutex
	leaderCh  chan struct{}
	stopCh    <-chan struct{}
	lock      *MySQLLock
}
406
// Lock tries to make this node leader by acquiring the MySQL named lock for
// i.key. The acquisition runs on a separate goroutine (attemptLock), and Lock
// waits for one of three outcomes:
//   - the lock is acquired: held is set, a fresh leader channel is created
//     and watched by monitorLock, and that channel is returned;
//   - the attempt fails: its error is returned;
//   - stopCh is closed first: (nil, nil) is returned and attemptLock is told
//     to release the lock should it still obtain it.
//
// localLock is held for the whole wait. Calling Lock while the lock is
// already held returns an error.
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	i.localLock.Lock()
	defer i.localLock.Unlock()
	if i.held {
		return nil, fmt.Errorf("lock already held")
	}

	// Attempt an async acquisition
	didLock := make(chan struct{})
	// Capacity 1 lets attemptLock report a failure without blocking, even
	// after this function has already returned via stopCh.
	failLock := make(chan error, 1)
	// Capacity 1 lets the keep/release verdict be sent without waiting for
	// attemptLock to receive it.
	releaseCh := make(chan bool, 1)
	go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)

	// Wait for lock acquisition, failure, or shutdown
	select {
	case <-didLock:
		// Keep the lock: tell attemptLock not to release it.
		releaseCh <- false
	case err := <-failLock:
		return nil, err
	case <-stopCh:
		// Caller gave up; attemptLock releases the lock if it acquires it later.
		releaseCh <- true
		return nil, nil
	}

	// Create the leader channel
	i.held = true
	i.leaderCh = make(chan struct{})

	// monitorLock closes leaderCh when hasLock reports that the lock is lost.
	go i.monitorLock(i.leaderCh)

	// NOTE(review): stopCh is stored here but not read anywhere in this file.
	i.stopCh = stopCh

	return i.leaderCh, nil
}
441
442func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
443	lock, err := NewMySQLLock(i.in, i.logger, key, value)
444
445	// Set node value
446	i.lock = lock
447
448	if err != nil {
449		failLock <- err
450	}
451
452	err = lock.Lock()
453	if err != nil {
454		failLock <- err
455		return
456	}
457
458	// Signal that lock is held
459	close(didLock)
460
461	// Handle an early abort
462	release := <-releaseCh
463	if release {
464		lock.Unlock()
465	}
466}
467
468func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
469	for {
470		// The only way to lose this lock is if someone is
471		// logging into the DB and altering system tables or you lose a connection in
472		// which case you will lose the lock anyway.
473		err := i.hasLock(i.key)
474		if err != nil {
475			// Somehow we lost the lock.... likely because the connection holding
476			// the lock was closed or someone was playing around with the locks in the DB.
477			close(leaderCh)
478			return
479		}
480
481		time.Sleep(5 * time.Second)
482	}
483}
484
485func (i *MySQLHALock) Unlock() error {
486	i.localLock.Lock()
487	defer i.localLock.Unlock()
488	if !i.held {
489		return nil
490	}
491
492	err := i.lock.Unlock()
493
494	if err == nil {
495		i.held = false
496		return nil
497	}
498
499	return err
500}
501
502// hasLock will check if a lock is held by checking the current lock id against our known ID.
503func (i *MySQLHALock) hasLock(key string) error {
504	var result sql.NullInt64
505	err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
506	if err == sql.ErrNoRows || !result.Valid {
507		// This is not an error to us since it just means the lock isn't held
508		return nil
509	}
510
511	if err != nil {
512		return err
513	}
514
515	// IS_USED_LOCK will return the ID of the connection that created the lock.
516	if result.Int64 != GlobalLockID {
517		return ErrLockHeld
518	}
519
520	return nil
521}
522
523func (i *MySQLHALock) GetLeader() (string, error) {
524	defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())
525	var result string
526	err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result)
527	if err == sql.ErrNoRows {
528		return "", err
529	}
530
531	return result, nil
532}
533
534func (i *MySQLHALock) Value() (bool, string, error) {
535	leaderkey, err := i.GetLeader()
536	if err != nil {
537		return false, "", err
538	}
539
540	return true, leaderkey, err
541}
542
// MySQLLock provides an easy way to grab and release mysql
// locks using the built in GET_LOCK function. Note that these
// locks are released when you lose connection to the server.
//
// in is a connection pool dedicated to this lock (opened by NewMySQLLock),
// so closing it in Unlock drops the session that owns the lock. parentConn
// is the backend the lock was created from. statements holds the "put"
// statement used to record the leader. key is the named lock, and value is
// written to the lock table's "leader" row by becomeLeader.
type MySQLLock struct {
	parentConn *MySQLBackend
	in         *sql.DB
	logger     log.Logger
	statements map[string]*sql.Stmt
	key        string
	value      string
}
554
// Package-level state and sentinel errors used by the MySQL HA lock.
var (
	// GlobalLockID is the MySQL connection ID (from IS_USED_LOCK) that held
	// the named lock when this process last acquired it; hasLock compares the
	// lock's current owner against it. Being package-level, it is shared by
	// every lock in the process.
	GlobalLockID int64
	// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
	ErrLockHeld = errors.New("mysql: lock already held")
	// ErrUnlockFailed is returned by MySQLLock.Unlock when closing the lock's
	// dedicated connection pool fails.
	ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
	// ErrClaimFailed describes failing to record this node as the new leader
	// in the DB. NOTE(review): not returned by any code in this file.
	ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
	// ErrSettingGlobalID is returned when IS_USED_LOCK yields NULL right after
	// GET_LOCK succeeded, so the owning connection ID cannot be recorded.
	ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
)
568
569// NewMySQLLock helper function
570func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
571	// Create a new MySQL connection so we can close this and have no effect on
572	// the rest of the MySQL backend and any cleanup that might need to be done.
573	conn, _ := NewMySQLClient(in.conf, in.logger)
574
575	m := &MySQLLock{
576		parentConn: in,
577		in:         conn,
578		logger:     l,
579		statements: make(map[string]*sql.Stmt),
580		key:        key,
581		value:      value,
582	}
583
584	statements := map[string]string{
585		"put": "INSERT INTO " + in.dbLockTable +
586			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
587	}
588
589	for name, query := range statements {
590		if err := m.prepare(name, query); err != nil {
591			return nil, err
592		}
593	}
594
595	return m, nil
596}
597
598// prepare is a helper to prepare a query for future execution
599func (m *MySQLLock) prepare(name, query string) error {
600	stmt, err := m.in.Prepare(query)
601	if err != nil {
602		return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err)
603	}
604	m.statements[name] = stmt
605	return nil
606}
607
608// update the current cluster leader in the DB. This is used so
609// we can tell the servers in standby who the active leader is.
610func (i *MySQLLock) becomeLeader() error {
611	_, err := i.statements["put"].Exec("leader", i.value)
612	if err != nil {
613		return err
614	}
615
616	return nil
617}
618
619// Lock will try to get a lock for an indefinite amount of time
620// based on the given key that has been requested.
621func (i *MySQLLock) Lock() error {
622	defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())
623
624	// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
625	// different MySQL flavours i.e. MariaDB
626	rows, err := i.in.Query("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key)
627	if err != nil {
628		return err
629	}
630
631	defer rows.Close()
632	rows.Next()
633	var lock sql.NullInt64
634	var connectionID sql.NullInt64
635	rows.Scan(&lock, &connectionID)
636
637	if rows.Err() != nil {
638		return rows.Err()
639	}
640
641	// 1 is returned from GET_LOCK if it was able to get the lock
642	// 0 if it failed and NULL if some strange error happened.
643	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
644	if !lock.Valid || lock.Int64 != 1 {
645		return ErrLockHeld
646	}
647
648	// Since we have the lock alert the rest of the cluster
649	// that we are now the active leader.
650	err = i.becomeLeader()
651	if err != nil {
652		return ErrLockHeld
653	}
654
655	// This will return the connection ID of NULL if an error happens
656	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
657	if !connectionID.Valid {
658		return ErrSettingGlobalID
659	}
660
661	GlobalLockID = connectionID.Int64
662
663	return nil
664}
665
666// Unlock just closes the connection. This is because closing the MySQL connection
667// is a 100% reliable way to close the lock. If you just release the lock you must
668// do it from the same mysql connection_id that you originally created it from. This
669// is a huge hastle and I actually couldn't find a clean way to do this although one
670// likely does exist. Closing the connection however ensures we don't ever get into a
671// state where we try to release the lock and it hangs it is also much less code.
672func (i *MySQLLock) Unlock() error {
673	err := i.in.Close()
674	if err != nil {
675		return ErrUnlockFailed
676	}
677
678	return nil
679}
680
681// Establish a TLS connection with a given CA certificate
682// Register a tsl.Config associated with the same key as the dns param from sql.Open
683// foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
684func setupMySQLTLSConfig(tlsCaFile string) error {
685	rootCertPool := x509.NewCertPool()
686
687	pem, err := ioutil.ReadFile(tlsCaFile)
688	if err != nil {
689		return err
690	}
691
692	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
693		return err
694	}
695
696	err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
697		RootCAs: rootCertPool,
698	})
699	if err != nil {
700		return err
701	}
702
703	return nil
704}
705