1// +build go1.9
2
3package mysql
4
5import (
6	"context"
7	"crypto/tls"
8	"crypto/x509"
9	"database/sql"
10	"fmt"
11	"io"
12	"io/ioutil"
13	nurl "net/url"
14	"strconv"
15	"strings"
16)
17
18import (
19	"github.com/go-sql-driver/mysql"
20	"github.com/hashicorp/go-multierror"
21)
22
23import (
24	"github.com/golang-migrate/migrate/v4/database"
25)
26
27func init() {
28	database.Register("mysql", &Mysql{})
29}
30
var (
	// DefaultMigrationsTable is the version table name WithInstance uses when
	// Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"

	// ErrDatabaseDirty is the error reported for a dirty database.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")

	// ErrNilConfig is returned when a required *Config / *mysql.Config is nil.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithInstance when no database name was
	// configured and SELECT DATABASE() yields none.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrAppendPEM is returned when the x-tls-ca file holds no usable PEM
	// certificates.
	ErrAppendPEM = fmt.Errorf("failed to append PEM")

	// ErrTLSCertKeyConfig is returned when only one of x-tls-cert / x-tls-key
	// is supplied.
	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
)
40
// Config holds the settings of the MySQL migrate driver.
type Config struct {
	// MigrationsTable names the table that stores the current version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string

	// DatabaseName is the schema being migrated. WithInstance fills an empty
	// value from SELECT DATABASE().
	DatabaseName string

	// NoLock makes Lock/Unlock only toggle the in-memory flag without issuing
	// GET_LOCK / RELEASE_LOCK.
	NoLock bool
}

// Mysql implements database.Driver on top of one dedicated connection.
type Mysql struct {
	// RELEASE_LOCK has to run on the same connection that called GET_LOCK,
	// so every statement of this driver is sent over this single conn.
	conn *sql.Conn

	// db is the pool conn was taken from; Close closes it.
	db *sql.DB

	// isLocked records whether Lock succeeded without a matching Unlock yet.
	isLocked bool

	config *Config
}
56
57// instance must have `multiStatements` set to true
58func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
59	if config == nil {
60		return nil, ErrNilConfig
61	}
62
63	if err := instance.Ping(); err != nil {
64		return nil, err
65	}
66
67	if config.DatabaseName == "" {
68		query := `SELECT DATABASE()`
69		var databaseName sql.NullString
70		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
71			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
72		}
73
74		if len(databaseName.String) == 0 {
75			return nil, ErrNoDatabaseName
76		}
77
78		config.DatabaseName = databaseName.String
79	}
80
81	if len(config.MigrationsTable) == 0 {
82		config.MigrationsTable = DefaultMigrationsTable
83	}
84
85	conn, err := instance.Conn(context.Background())
86	if err != nil {
87		return nil, err
88	}
89
90	mx := &Mysql{
91		conn:   conn,
92		db:     instance,
93		config: config,
94	}
95
96	if err := mx.ensureVersionTable(); err != nil {
97		return nil, err
98	}
99
100	return mx, nil
101}
102
103// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
104// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
105func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
106	if c == nil {
107		return nil, ErrNilConfig
108	}
109	customQueryParams := map[string]string{}
110
111	for k, v := range c.Params {
112		if strings.HasPrefix(k, "x-") {
113			customQueryParams[k] = v
114			delete(c.Params, k)
115		}
116	}
117	return customQueryParams, nil
118}
119
120func urlToMySQLConfig(url string) (*mysql.Config, error) {
121	// Need to parse out custom TLS parameters and call
122	// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
123	// which consumes the registered tls.Config
124	// Fixes: https://github.com/golang-migrate/migrate/issues/411
125	//
126	// Can't use url.Parse() since it fails to parse MySQL DSNs
127	// mysql.ParseDSN() also searches for "?" to find query parameters:
128	// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
129	if idx := strings.LastIndex(url, "?"); idx > 0 {
130		rawParams := url[idx+1:]
131		parsedParams, err := nurl.ParseQuery(rawParams)
132		if err != nil {
133			return nil, err
134		}
135
136		ctls := parsedParams.Get("tls")
137		if len(ctls) > 0 {
138			if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
139				rootCertPool := x509.NewCertPool()
140				pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
141				if err != nil {
142					return nil, err
143				}
144
145				if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
146					return nil, ErrAppendPEM
147				}
148
149				clientCert := make([]tls.Certificate, 0, 1)
150				if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
151					if ccert == "" || ckey == "" {
152						return nil, ErrTLSCertKeyConfig
153					}
154					certs, err := tls.LoadX509KeyPair(ccert, ckey)
155					if err != nil {
156						return nil, err
157					}
158					clientCert = append(clientCert, certs)
159				}
160
161				insecureSkipVerify := false
162				insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
163				if len(insecureSkipVerifyStr) > 0 {
164					x, err := strconv.ParseBool(insecureSkipVerifyStr)
165					if err != nil {
166						return nil, err
167					}
168					insecureSkipVerify = x
169				}
170
171				err = mysql.RegisterTLSConfig(ctls, &tls.Config{
172					RootCAs:            rootCertPool,
173					Certificates:       clientCert,
174					InsecureSkipVerify: insecureSkipVerify,
175				})
176				if err != nil {
177					return nil, err
178				}
179			}
180		}
181	}
182
183	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
184	if err != nil {
185		return nil, err
186	}
187
188	config.MultiStatements = true
189
190	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
191	// net/url.Parse() would automatically unescape it for us.
192	// See: https://play.golang.org/p/q9j1io-YICQ
193	user, err := nurl.QueryUnescape(config.User)
194	if err != nil {
195		return nil, err
196	}
197	config.User = user
198
199	password, err := nurl.QueryUnescape(config.Passwd)
200	if err != nil {
201		return nil, err
202	}
203	config.Passwd = password
204
205	return config, nil
206}
207
208func (m *Mysql) Open(url string) (database.Driver, error) {
209	config, err := urlToMySQLConfig(url)
210	if err != nil {
211		return nil, err
212	}
213
214	customParams, err := extractCustomQueryParams(config)
215	if err != nil {
216		return nil, err
217	}
218
219	noLockParam, noLock := customParams["x-no-lock"], false
220	if noLockParam != "" {
221		noLock, err = strconv.ParseBool(noLockParam)
222		if err != nil {
223			return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
224		}
225	}
226
227	db, err := sql.Open("mysql", config.FormatDSN())
228	if err != nil {
229		return nil, err
230	}
231
232	mx, err := WithInstance(db, &Config{
233		DatabaseName:    config.DBName,
234		MigrationsTable: customParams["x-migrations-table"],
235		NoLock:          noLock,
236	})
237	if err != nil {
238		return nil, err
239	}
240
241	return mx, nil
242}
243
244func (m *Mysql) Close() error {
245	connErr := m.conn.Close()
246	dbErr := m.db.Close()
247	if connErr != nil || dbErr != nil {
248		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
249	}
250	return nil
251}
252
253func (m *Mysql) Lock() error {
254	if m.isLocked {
255		return database.ErrLocked
256	}
257
258	if m.config.NoLock {
259		m.isLocked = true
260		return nil
261	}
262
263	aid, err := database.GenerateAdvisoryLockId(
264		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
265	if err != nil {
266		return err
267	}
268
269	query := "SELECT GET_LOCK(?, 10)"
270	var success bool
271	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
272		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
273	}
274
275	if success {
276		m.isLocked = true
277		return nil
278	}
279
280	return database.ErrLocked
281}
282
283func (m *Mysql) Unlock() error {
284	if !m.isLocked {
285		return nil
286	}
287
288	if m.config.NoLock {
289		m.isLocked = false
290		return nil
291	}
292
293	aid, err := database.GenerateAdvisoryLockId(
294		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
295	if err != nil {
296		return err
297	}
298
299	query := `SELECT RELEASE_LOCK(?)`
300	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
301		return &database.Error{OrigErr: err, Query: []byte(query)}
302	}
303
304	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
305	// in which case isLocked should be true until the timeout expires -- synchronizing
306	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
307
308	m.isLocked = false
309	return nil
310}
311
312func (m *Mysql) Run(migration io.Reader) error {
313	migr, err := ioutil.ReadAll(migration)
314	if err != nil {
315		return err
316	}
317
318	query := string(migr[:])
319	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
320		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
321	}
322
323	return nil
324}
325
326func (m *Mysql) SetVersion(version int, dirty bool) error {
327	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
328	if err != nil {
329		return &database.Error{OrigErr: err, Err: "transaction start failed"}
330	}
331
332	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
333	if _, err := tx.ExecContext(context.Background(), query); err != nil {
334		if errRollback := tx.Rollback(); errRollback != nil {
335			err = multierror.Append(err, errRollback)
336		}
337		return &database.Error{OrigErr: err, Query: []byte(query)}
338	}
339
340	// Also re-write the schema version for nil dirty versions to prevent
341	// empty schema version for failed down migration on the first migration
342	// See: https://github.com/golang-migrate/migrate/issues/330
343	if version >= 0 || (version == database.NilVersion && dirty) {
344		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
345		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
346			if errRollback := tx.Rollback(); errRollback != nil {
347				err = multierror.Append(err, errRollback)
348			}
349			return &database.Error{OrigErr: err, Query: []byte(query)}
350		}
351	}
352
353	if err := tx.Commit(); err != nil {
354		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
355	}
356
357	return nil
358}
359
360func (m *Mysql) Version() (version int, dirty bool, err error) {
361	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
362	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
363	switch {
364	case err == sql.ErrNoRows:
365		return database.NilVersion, false, nil
366
367	case err != nil:
368		if e, ok := err.(*mysql.MySQLError); ok {
369			if e.Number == 0 {
370				return database.NilVersion, false, nil
371			}
372		}
373		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
374
375	default:
376		return version, dirty, nil
377	}
378}
379
380func (m *Mysql) Drop() (err error) {
381	// select all tables
382	query := `SHOW TABLES LIKE '%'`
383	tables, err := m.conn.QueryContext(context.Background(), query)
384	if err != nil {
385		return &database.Error{OrigErr: err, Query: []byte(query)}
386	}
387	defer func() {
388		if errClose := tables.Close(); errClose != nil {
389			err = multierror.Append(err, errClose)
390		}
391	}()
392
393	// delete one table after another
394	tableNames := make([]string, 0)
395	for tables.Next() {
396		var tableName string
397		if err := tables.Scan(&tableName); err != nil {
398			return err
399		}
400		if len(tableName) > 0 {
401			tableNames = append(tableNames, tableName)
402		}
403	}
404	if err := tables.Err(); err != nil {
405		return &database.Error{OrigErr: err, Query: []byte(query)}
406	}
407
408	if len(tableNames) > 0 {
409		// disable checking foreign key constraints until finished
410		query = `SET foreign_key_checks = 0`
411		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
412			return &database.Error{OrigErr: err, Query: []byte(query)}
413		}
414
415		defer func() {
416			// enable foreign key checks
417			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
418		}()
419
420		// delete one by one ...
421		for _, t := range tableNames {
422			query = "DROP TABLE IF EXISTS `" + t + "`"
423			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
424				return &database.Error{OrigErr: err, Query: []byte(query)}
425			}
426		}
427	}
428
429	return nil
430}
431
432// ensureVersionTable checks if versions table exists and, if not, creates it.
433// Note that this function locks the database, which deviates from the usual
434// convention of "caller locks" in the Mysql type.
435func (m *Mysql) ensureVersionTable() (err error) {
436	if err = m.Lock(); err != nil {
437		return err
438	}
439
440	defer func() {
441		if e := m.Unlock(); e != nil {
442			if err == nil {
443				err = e
444			} else {
445				err = multierror.Append(err, e)
446			}
447		}
448	}()
449
450	// check if migration table exists
451	var result string
452	query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
453	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
454		if err != sql.ErrNoRows {
455			return &database.Error{OrigErr: err, Query: []byte(query)}
456		}
457	} else {
458		return nil
459	}
460
461	// if not, create the empty migration table
462	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
463	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
464		return &database.Error{OrigErr: err, Query: []byte(query)}
465	}
466	return nil
467}
468
// readBool parses input as one of the fixed boolean spellings "1", "true",
// "TRUE", "True" or "0", "false", "FALSE", "False". value is the parsed
// boolean and valid reports whether input was one of those spellings; for any
// other input both results are false. It is meant to mirror the MySQL driver's
// own helper:
// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
func readBool(input string) (value bool, valid bool) {
	if input == "1" || input == "true" || input == "TRUE" || input == "True" {
		return true, true
	}
	if input == "0" || input == "false" || input == "FALSE" || input == "False" {
		return false, true
	}
	return false, false
}
483