package mysql

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"math"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	log "github.com/hashicorp/go-hclog"

	metrics "github.com/armon/go-metrics"
	mysql "github.com/go-sql-driver/mysql"
	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/helper/strutil"
	"github.com/hashicorp/vault/sdk/physical"
)

// Verify MySQLBackend satisfies the correct interfaces
// (compile-time assertions; they cost nothing at runtime).
var _ physical.Backend = (*MySQLBackend)(nil)
var _ physical.HABackend = (*MySQLBackend)(nil)
var _ physical.Lock = (*MySQLHALock)(nil)

// mysqlTLSKey is the name under which setupMySQLTLSConfig registers the
// custom tls.Config with the MySQL driver; NewMySQLClient selects it by adding
// tls=<mysqlTLSKey> to the DSN when tls_ca_file is configured.
// Unreserved tls key
// Reserved values are "true", "false", "skip-verify"
const mysqlTLSKey = "default"

// MySQLBackend is a physical backend that stores data
// within MySQL database.
type MySQLBackend struct {
	// dbTable is the backtick-quoted `database`.`table` holding the
	// vault_key / vault_value rows.
	dbTable string
	// dbLockTable is the backtick-quoted `database`.`lock_table` holding the
	// node_job / current_leader rows used for HA leader tracking.
	dbLockTable string
	// client is the connection pool returned by NewMySQLClient.
	client *sql.DB
	// statements holds prepared statements keyed by name: "put", "get",
	// "delete", "list", plus "get_lock" and "used_lock" when HA is enabled.
	statements map[string]*sql.Stmt
	logger     log.Logger
	// permitPool bounds the number of concurrent Put/Get/Delete/List calls
	// (sized from max_parallel).
	permitPool *physical.PermitPool
	// conf is the raw backend configuration; NewMySQLLock reuses it to open a
	// separate connection pool for each HA lock.
	conf map[string]string
	// NOTE(review): redirectHost and redirectPort are not read or written
	// anywhere in this file; possibly unused — confirm before relying on them.
	redirectHost string
	redirectPort int64
	// haEnabled is the parsed ha_enabled option; when false no lock table is
	// created and the HA statements are not prepared.
	haEnabled bool
}

// NewMySQLBackend constructs a MySQL physical backend from conf. It opens a
// connection pool with NewMySQLClient, creates the database (conf "database",
// default "vault") and the data table (conf "table", default "vault") if they
// are missing, and when ha_enabled is true also the lock table (conf
// "lock_table", default "<table>_lock"). It then prepares the statements used
// by Put/Get/Delete/List and by the HA lock.
54func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 55 var err error 56 57 db, err := NewMySQLClient(conf, logger) 58 if err != nil { 59 return nil, err 60 } 61 62 database, ok := conf["database"] 63 if !ok { 64 database = "vault" 65 } 66 table, ok := conf["table"] 67 if !ok { 68 table = "vault" 69 } 70 dbTable := "`" + database + "`.`" + table + "`" 71 72 maxParStr, ok := conf["max_parallel"] 73 var maxParInt int 74 if ok { 75 maxParInt, err = strconv.Atoi(maxParStr) 76 if err != nil { 77 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 78 } 79 if logger.IsDebug() { 80 logger.Debug("max_parallel set", "max_parallel", maxParInt) 81 } 82 } else { 83 maxParInt = physical.DefaultParallelOperations 84 } 85 86 // Check schema exists 87 var schemaExist bool 88 schemaRows, err := db.Query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database) 89 if err != nil { 90 return nil, errwrap.Wrapf("failed to check mysql schema exist: {{err}}", err) 91 } 92 defer schemaRows.Close() 93 schemaExist = schemaRows.Next() 94 95 // Check table exists 96 var tableExist bool 97 tableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database) 98 99 if err != nil { 100 return nil, errwrap.Wrapf("failed to check mysql table exist: {{err}}", err) 101 } 102 defer tableRows.Close() 103 tableExist = tableRows.Next() 104 105 // Create the required database if it doesn't exists. 106 if !schemaExist { 107 if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil { 108 return nil, errwrap.Wrapf("failed to create mysql database: {{err}}", err) 109 } 110 } 111 112 // Create the required table if it doesn't exists. 
113 if !tableExist { 114 create_query := "CREATE TABLE IF NOT EXISTS " + dbTable + 115 " (vault_key varbinary(512), vault_value mediumblob, PRIMARY KEY (vault_key))" 116 if _, err := db.Exec(create_query); err != nil { 117 return nil, errwrap.Wrapf("failed to create mysql table: {{err}}", err) 118 } 119 } 120 121 // Default value for ha_enabled 122 haEnabledStr, ok := conf["ha_enabled"] 123 if !ok { 124 haEnabledStr = "false" 125 } 126 haEnabled, err := strconv.ParseBool(haEnabledStr) 127 if err != nil { 128 return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr) 129 } 130 131 locktable, ok := conf["lock_table"] 132 if !ok { 133 locktable = table + "_lock" 134 } 135 136 dbLockTable := "`" + database + "`.`" + locktable + "`" 137 138 // Only create lock table if ha_enabled is true 139 if haEnabled { 140 // Check table exists 141 var lockTableExist bool 142 lockTableRows, err := db.Query("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database) 143 144 if err != nil { 145 return nil, errwrap.Wrapf("failed to check mysql table exist: {{err}}", err) 146 } 147 defer lockTableRows.Close() 148 lockTableExist = lockTableRows.Next() 149 150 // Create the required table if it doesn't exists. 151 if !lockTableExist { 152 create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable + 153 " (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))" 154 if _, err := db.Exec(create_query); err != nil { 155 return nil, errwrap.Wrapf("failed to create mysql table: {{err}}", err) 156 } 157 } 158 } 159 160 // Setup the backend. 
161 m := &MySQLBackend{ 162 dbTable: dbTable, 163 dbLockTable: dbLockTable, 164 client: db, 165 statements: make(map[string]*sql.Stmt), 166 logger: logger, 167 permitPool: physical.NewPermitPool(maxParInt), 168 conf: conf, 169 haEnabled: haEnabled, 170 } 171 172 // Prepare all the statements required 173 statements := map[string]string{ 174 "put": "INSERT INTO " + dbTable + 175 " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)", 176 "get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?", 177 "delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?", 178 "list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?", 179 } 180 181 // Only prepare ha-related statements if we need them 182 if haEnabled { 183 statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?" 184 statements["used_lock"] = "SELECT IS_USED_LOCK(?)" 185 } 186 187 for name, query := range statements { 188 if err := m.prepare(name, query); err != nil { 189 return nil, err 190 } 191 } 192 193 return m, nil 194} 195 196func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) { 197 var err error 198 199 // Get the MySQL credentials to perform read/write operations. 200 username, ok := conf["username"] 201 if !ok || username == "" { 202 return nil, fmt.Errorf("missing username") 203 } 204 password, ok := conf["password"] 205 if !ok || password == "" { 206 return nil, fmt.Errorf("missing password") 207 } 208 209 // Get or set MySQL server address. 
Defaults to localhost and default port(3306) 210 address, ok := conf["address"] 211 if !ok { 212 address = "127.0.0.1:3306" 213 } 214 215 maxIdleConnStr, ok := conf["max_idle_connections"] 216 var maxIdleConnInt int 217 if ok { 218 maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr) 219 if err != nil { 220 return nil, errwrap.Wrapf("failed parsing max_idle_connections parameter: {{err}}", err) 221 } 222 if logger.IsDebug() { 223 logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt) 224 } 225 } 226 227 maxConnLifeStr, ok := conf["max_connection_lifetime"] 228 var maxConnLifeInt int 229 if ok { 230 maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr) 231 if err != nil { 232 return nil, errwrap.Wrapf("failed parsing max_connection_lifetime parameter: {{err}}", err) 233 } 234 if logger.IsDebug() { 235 logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt) 236 } 237 } 238 239 maxParStr, ok := conf["max_parallel"] 240 var maxParInt int 241 if ok { 242 maxParInt, err = strconv.Atoi(maxParStr) 243 if err != nil { 244 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 245 } 246 if logger.IsDebug() { 247 logger.Debug("max_parallel set", "max_parallel", maxParInt) 248 } 249 } else { 250 maxParInt = physical.DefaultParallelOperations 251 } 252 253 dsnParams := url.Values{} 254 tlsCaFile, ok := conf["tls_ca_file"] 255 if ok { 256 if err := setupMySQLTLSConfig(tlsCaFile); err != nil { 257 return nil, errwrap.Wrapf("failed register TLS config: {{err}}", err) 258 } 259 260 dsnParams.Add("tls", mysqlTLSKey) 261 } 262 263 // Create MySQL handle for the database. 264 dsn := username + ":" + password + "@tcp(" + address + ")/?" 
+ dsnParams.Encode() 265 db, err := sql.Open("mysql", dsn) 266 if err != nil { 267 return nil, errwrap.Wrapf("failed to connect to mysql: {{err}}", err) 268 } 269 db.SetMaxOpenConns(maxParInt) 270 if maxIdleConnInt != 0 { 271 db.SetMaxIdleConns(maxIdleConnInt) 272 } 273 if maxConnLifeInt != 0 { 274 db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second) 275 } 276 277 return db, err 278} 279 280// prepare is a helper to prepare a query for future execution 281func (m *MySQLBackend) prepare(name, query string) error { 282 stmt, err := m.client.Prepare(query) 283 if err != nil { 284 return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err) 285 } 286 m.statements[name] = stmt 287 return nil 288} 289 290// Put is used to insert or update an entry. 291func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error { 292 defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) 293 294 m.permitPool.Acquire() 295 defer m.permitPool.Release() 296 297 _, err := m.statements["put"].Exec(entry.Key, entry.Value) 298 if err != nil { 299 return err 300 } 301 return nil 302} 303 304// Get is used to fetch an entry. 
305func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { 306 defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) 307 308 m.permitPool.Acquire() 309 defer m.permitPool.Release() 310 311 var result []byte 312 err := m.statements["get"].QueryRow(key).Scan(&result) 313 if err == sql.ErrNoRows { 314 return nil, nil 315 } 316 if err != nil { 317 return nil, err 318 } 319 320 ent := &physical.Entry{ 321 Key: key, 322 Value: result, 323 } 324 return ent, nil 325} 326 327// Delete is used to permanently delete an entry 328func (m *MySQLBackend) Delete(ctx context.Context, key string) error { 329 defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) 330 331 m.permitPool.Acquire() 332 defer m.permitPool.Release() 333 334 _, err := m.statements["delete"].Exec(key) 335 if err != nil { 336 return err 337 } 338 return nil 339} 340 341// List is used to list all the keys under a given 342// prefix, up to the next prefix. 343func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) { 344 defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) 345 346 m.permitPool.Acquire() 347 defer m.permitPool.Release() 348 349 // Add the % wildcard to the prefix to do the prefix search 350 likePrefix := prefix + "%" 351 rows, err := m.statements["list"].Query(likePrefix) 352 if err != nil { 353 return nil, errwrap.Wrapf("failed to execute statement: {{err}}", err) 354 } 355 356 var keys []string 357 for rows.Next() { 358 var key string 359 err = rows.Scan(&key) 360 if err != nil { 361 return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err) 362 } 363 364 key = strings.TrimPrefix(key, prefix) 365 if i := strings.Index(key, "/"); i == -1 { 366 // Add objects only from the current 'folder' 367 keys = append(keys, key) 368 } else if i != -1 { 369 // Add truncated 'folder' paths 370 keys = strutil.AppendIfMissing(keys, string(key[:i+1])) 371 } 372 } 373 374 sort.Strings(keys) 375 return keys, nil 376} 
377 378// LockWith is used for mutual exclusion based on the given key. 379func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) { 380 l := &MySQLHALock{ 381 in: m, 382 key: key, 383 value: value, 384 logger: m.logger, 385 } 386 return l, nil 387} 388 389func (m *MySQLBackend) HAEnabled() bool { 390 return m.haEnabled 391} 392 393// MySQLHALock is a MySQL Lock implementation for the HABackend 394type MySQLHALock struct { 395 in *MySQLBackend 396 key string 397 value string 398 logger log.Logger 399 400 held bool 401 localLock sync.Mutex 402 leaderCh chan struct{} 403 stopCh <-chan struct{} 404 lock *MySQLLock 405} 406 407func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { 408 i.localLock.Lock() 409 defer i.localLock.Unlock() 410 if i.held { 411 return nil, fmt.Errorf("lock already held") 412 } 413 414 // Attempt an async acquisition 415 didLock := make(chan struct{}) 416 failLock := make(chan error, 1) 417 releaseCh := make(chan bool, 1) 418 go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh) 419 420 // Wait for lock acquisition, failure, or shutdown 421 select { 422 case <-didLock: 423 releaseCh <- false 424 case err := <-failLock: 425 return nil, err 426 case <-stopCh: 427 releaseCh <- true 428 return nil, nil 429 } 430 431 // Create the leader channel 432 i.held = true 433 i.leaderCh = make(chan struct{}) 434 435 go i.monitorLock(i.leaderCh) 436 437 i.stopCh = stopCh 438 439 return i.leaderCh, nil 440} 441 442func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) { 443 lock, err := NewMySQLLock(i.in, i.logger, key, value) 444 445 // Set node value 446 i.lock = lock 447 448 if err != nil { 449 failLock <- err 450 } 451 452 err = lock.Lock() 453 if err != nil { 454 failLock <- err 455 return 456 } 457 458 // Signal that lock is held 459 close(didLock) 460 461 // Handle an early abort 462 release := <-releaseCh 463 if release { 464 lock.Unlock() 
465 } 466} 467 468func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) { 469 for { 470 // The only way to lose this lock is if someone is 471 // logging into the DB and altering system tables or you lose a connection in 472 // which case you will lose the lock anyway. 473 err := i.hasLock(i.key) 474 if err != nil { 475 // Somehow we lost the lock.... likely because the connection holding 476 // the lock was closed or someone was playing around with the locks in the DB. 477 close(leaderCh) 478 return 479 } 480 481 time.Sleep(5 * time.Second) 482 } 483} 484 485func (i *MySQLHALock) Unlock() error { 486 i.localLock.Lock() 487 defer i.localLock.Unlock() 488 if !i.held { 489 return nil 490 } 491 492 err := i.lock.Unlock() 493 494 if err == nil { 495 i.held = false 496 return nil 497 } 498 499 return err 500} 501 502// hasLock will check if a lock is held by checking the current lock id against our known ID. 503func (i *MySQLHALock) hasLock(key string) error { 504 var result sql.NullInt64 505 err := i.in.statements["used_lock"].QueryRow(key).Scan(&result) 506 if err == sql.ErrNoRows || !result.Valid { 507 // This is not an error to us since it just means the lock isn't held 508 return nil 509 } 510 511 if err != nil { 512 return err 513 } 514 515 // IS_USED_LOCK will return the ID of the connection that created the lock. 
516 if result.Int64 != GlobalLockID { 517 return ErrLockHeld 518 } 519 520 return nil 521} 522 523func (i *MySQLHALock) GetLeader() (string, error) { 524 defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now()) 525 var result string 526 err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result) 527 if err == sql.ErrNoRows { 528 return "", err 529 } 530 531 return result, nil 532} 533 534func (i *MySQLHALock) Value() (bool, string, error) { 535 leaderkey, err := i.GetLeader() 536 if err != nil { 537 return false, "", err 538 } 539 540 return true, leaderkey, err 541} 542 543// MySQLLock provides an easy way to grab and release mysql 544// locks using the built in GET_LOCK function. Note that these 545// locks are released when you lose connection to the server. 546type MySQLLock struct { 547 parentConn *MySQLBackend 548 in *sql.DB 549 logger log.Logger 550 statements map[string]*sql.Stmt 551 key string 552 value string 553} 554 555// Errors specific to trying to grab a lock in MySQL 556var ( 557 // This is the GlobalLockID for checking if the lock we got is still the current lock 558 GlobalLockID int64 559 // ErrLockHeld is returned when another vault instance already has a lock held for the given key. 560 ErrLockHeld = errors.New("mysql: lock already held") 561 // ErrUnlockFailed 562 ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session") 563 // You were unable to update that you are the new leader in the DB 564 ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information") 565 // Error to throw if between getting the lock and checking the ID of it we lost it. 
566 ErrSettingGlobalID = errors.New("mysql: getting global lock id failed") 567) 568 569// NewMySQLLock helper function 570func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) { 571 // Create a new MySQL connection so we can close this and have no effect on 572 // the rest of the MySQL backend and any cleanup that might need to be done. 573 conn, _ := NewMySQLClient(in.conf, in.logger) 574 575 m := &MySQLLock{ 576 parentConn: in, 577 in: conn, 578 logger: l, 579 statements: make(map[string]*sql.Stmt), 580 key: key, 581 value: value, 582 } 583 584 statements := map[string]string{ 585 "put": "INSERT INTO " + in.dbLockTable + 586 " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)", 587 } 588 589 for name, query := range statements { 590 if err := m.prepare(name, query); err != nil { 591 return nil, err 592 } 593 } 594 595 return m, nil 596} 597 598// prepare is a helper to prepare a query for future execution 599func (m *MySQLLock) prepare(name, query string) error { 600 stmt, err := m.in.Prepare(query) 601 if err != nil { 602 return errwrap.Wrapf(fmt.Sprintf("failed to prepare %q: {{err}}", name), err) 603 } 604 m.statements[name] = stmt 605 return nil 606} 607 608// update the current cluster leader in the DB. This is used so 609// we can tell the servers in standby who the active leader is. 610func (i *MySQLLock) becomeLeader() error { 611 _, err := i.statements["put"].Exec("leader", i.value) 612 if err != nil { 613 return err 614 } 615 616 return nil 617} 618 619// Lock will try to get a lock for an indefinite amount of time 620// based on the given key that has been requested. 621func (i *MySQLLock) Lock() error { 622 defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now()) 623 624 // Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with 625 // different MySQL flavours i.e. 
MariaDB 626 rows, err := i.in.Query("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key) 627 if err != nil { 628 return err 629 } 630 631 defer rows.Close() 632 rows.Next() 633 var lock sql.NullInt64 634 var connectionID sql.NullInt64 635 rows.Scan(&lock, &connectionID) 636 637 if rows.Err() != nil { 638 return rows.Err() 639 } 640 641 // 1 is returned from GET_LOCK if it was able to get the lock 642 // 0 if it failed and NULL if some strange error happened. 643 // https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock 644 if !lock.Valid || lock.Int64 != 1 { 645 return ErrLockHeld 646 } 647 648 // Since we have the lock alert the rest of the cluster 649 // that we are now the active leader. 650 err = i.becomeLeader() 651 if err != nil { 652 return ErrLockHeld 653 } 654 655 // This will return the connection ID of NULL if an error happens 656 // https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock 657 if !connectionID.Valid { 658 return ErrSettingGlobalID 659 } 660 661 GlobalLockID = connectionID.Int64 662 663 return nil 664} 665 666// Unlock just closes the connection. This is because closing the MySQL connection 667// is a 100% reliable way to close the lock. If you just release the lock you must 668// do it from the same mysql connection_id that you originally created it from. This 669// is a huge hastle and I actually couldn't find a clean way to do this although one 670// likely does exist. Closing the connection however ensures we don't ever get into a 671// state where we try to release the lock and it hangs it is also much less code. 
672func (i *MySQLLock) Unlock() error { 673 err := i.in.Close() 674 if err != nil { 675 return ErrUnlockFailed 676 } 677 678 return nil 679} 680 681// Establish a TLS connection with a given CA certificate 682// Register a tsl.Config associated with the same key as the dns param from sql.Open 683// foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default 684func setupMySQLTLSConfig(tlsCaFile string) error { 685 rootCertPool := x509.NewCertPool() 686 687 pem, err := ioutil.ReadFile(tlsCaFile) 688 if err != nil { 689 return err 690 } 691 692 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 693 return err 694 } 695 696 err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{ 697 RootCAs: rootCertPool, 698 }) 699 if err != nil { 700 return err 701 } 702 703 return nil 704} 705