1//go:build go1.9
2// +build go1.9
3
4package postgres
5
6import (
7	"context"
8	"database/sql"
9	"fmt"
10	"go.uber.org/atomic"
11	"io"
12	"io/ioutil"
13	nurl "net/url"
14	"regexp"
15	"strconv"
16	"strings"
17	"time"
18
19	"github.com/golang-migrate/migrate/v4"
20	"github.com/golang-migrate/migrate/v4/database"
21	"github.com/golang-migrate/migrate/v4/database/multistmt"
22	multierror "github.com/hashicorp/go-multierror"
23	"github.com/lib/pq"
24)
25
26func init() {
27	db := Postgres{}
28	database.Register("postgres", &db)
29	database.Register("postgresql", &db)
30}
31
// multiStmtDelimiter is the statement separator handed to multistmt.Parse
// when a migration is run in multi-statement mode.
var multiStmtDelimiter = []byte{';'}

// DefaultMigrationsTable is the version table name WithInstance uses when
// Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"

// DefaultMultiStatementMaxSize is the value Open uses for
// Config.MultiStatementMaxSize when the x-multi-statement-max-size option is
// absent or not positive (10 MB).
var DefaultMultiStatementMaxSize = 10 << 20
38
// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
var ErrNilConfig = fmt.Errorf("no config")

// ErrNoDatabaseName is returned by WithInstance when CURRENT_DATABASE()
// yields an empty name.
var ErrNoDatabaseName = fmt.Errorf("no database name")

// ErrNoSchema is returned by WithInstance when CURRENT_SCHEMA() yields an
// empty name.
var ErrNoSchema = fmt.Errorf("no schema")

// ErrDatabaseDirty reports that the database is in a dirty state.
var ErrDatabaseDirty = fmt.Errorf("database is dirty")
45
// Config holds the settings of a Postgres driver instance. WithInstance
// fills in empty fields and the unexported resolved names in place.
type Config struct {
	// MigrationsTable is the name of the version table. When empty,
	// WithInstance replaces it with DefaultMigrationsTable.
	MigrationsTable string
	// MigrationsTableQuoted makes WithInstance read MigrationsTable as one or
	// two double-quoted identifiers (table, or schema and table).
	MigrationsTableQuoted bool
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// the statements one at a time instead of sending it as a single query.
	MultiStatementEnabled bool
	// DatabaseName is filled from CURRENT_DATABASE() by WithInstance when empty.
	DatabaseName string
	// SchemaName is filled from CURRENT_SCHEMA() by WithInstance when empty.
	SchemaName string
	// migrationsSchemaName and migrationsTableName are the resolved, unquoted
	// schema and table of the version table, set by WithInstance.
	migrationsSchemaName string
	migrationsTableName  string
	// StatementTimeout bounds each executed statement; zero means no timeout.
	StatementTimeout time.Duration
	// MultiStatementMaxSize is passed to multistmt.Parse as the size limit
	// when MultiStatementEnabled is set.
	MultiStatementMaxSize int
}
57
// Postgres is the PostgreSQL implementation of the migration database driver.
type Postgres struct {
	// Locking and unlocking need to use the same connection
	conn *sql.Conn
	// db is the pool conn was taken from; Close closes it as well.
	db *sql.DB
	// isLocked records whether this instance currently holds the advisory
	// lock; it is flipped by Lock and Unlock.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
67
68func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
69	if config == nil {
70		return nil, ErrNilConfig
71	}
72
73	if err := instance.Ping(); err != nil {
74		return nil, err
75	}
76
77	if config.DatabaseName == "" {
78		query := `SELECT CURRENT_DATABASE()`
79		var databaseName string
80		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
81			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
82		}
83
84		if len(databaseName) == 0 {
85			return nil, ErrNoDatabaseName
86		}
87
88		config.DatabaseName = databaseName
89	}
90
91	if config.SchemaName == "" {
92		query := `SELECT CURRENT_SCHEMA()`
93		var schemaName string
94		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
95			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
96		}
97
98		if len(schemaName) == 0 {
99			return nil, ErrNoSchema
100		}
101
102		config.SchemaName = schemaName
103	}
104
105	if len(config.MigrationsTable) == 0 {
106		config.MigrationsTable = DefaultMigrationsTable
107	}
108
109	config.migrationsSchemaName = config.SchemaName
110	config.migrationsTableName = config.MigrationsTable
111	if config.MigrationsTableQuoted {
112		re := regexp.MustCompile(`"(.*?)"`)
113		result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
114		config.migrationsTableName = result[len(result)-1][1]
115		if len(result) == 2 {
116			config.migrationsSchemaName = result[0][1]
117		} else if len(result) > 2 {
118			return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
119		}
120	}
121
122	conn, err := instance.Conn(context.Background())
123
124	if err != nil {
125		return nil, err
126	}
127
128	px := &Postgres{
129		conn:   conn,
130		db:     instance,
131		config: config,
132	}
133
134	if err := px.ensureVersionTable(); err != nil {
135		return nil, err
136	}
137
138	return px, nil
139}
140
141func (p *Postgres) Open(url string) (database.Driver, error) {
142	purl, err := nurl.Parse(url)
143	if err != nil {
144		return nil, err
145	}
146
147	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
148	if err != nil {
149		return nil, err
150	}
151
152	migrationsTable := purl.Query().Get("x-migrations-table")
153	migrationsTableQuoted := false
154	if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
155		migrationsTableQuoted, err = strconv.ParseBool(s)
156		if err != nil {
157			return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
158		}
159	}
160	if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
161		return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
162	}
163
164	statementTimeoutString := purl.Query().Get("x-statement-timeout")
165	statementTimeout := 0
166	if statementTimeoutString != "" {
167		statementTimeout, err = strconv.Atoi(statementTimeoutString)
168		if err != nil {
169			return nil, err
170		}
171	}
172
173	multiStatementMaxSize := DefaultMultiStatementMaxSize
174	if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
175		multiStatementMaxSize, err = strconv.Atoi(s)
176		if err != nil {
177			return nil, err
178		}
179		if multiStatementMaxSize <= 0 {
180			multiStatementMaxSize = DefaultMultiStatementMaxSize
181		}
182	}
183
184	multiStatementEnabled := false
185	if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
186		multiStatementEnabled, err = strconv.ParseBool(s)
187		if err != nil {
188			return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
189		}
190	}
191
192	px, err := WithInstance(db, &Config{
193		DatabaseName:          purl.Path,
194		MigrationsTable:       migrationsTable,
195		MigrationsTableQuoted: migrationsTableQuoted,
196		StatementTimeout:      time.Duration(statementTimeout) * time.Millisecond,
197		MultiStatementEnabled: multiStatementEnabled,
198		MultiStatementMaxSize: multiStatementMaxSize,
199	})
200
201	if err != nil {
202		return nil, err
203	}
204
205	return px, nil
206}
207
208func (p *Postgres) Close() error {
209	connErr := p.conn.Close()
210	dbErr := p.db.Close()
211	if connErr != nil || dbErr != nil {
212		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
213	}
214	return nil
215}
216
217// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
218func (p *Postgres) Lock() error {
219	return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
220		aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
221		if err != nil {
222			return err
223		}
224
225		// This will wait indefinitely until the lock can be acquired.
226		query := `SELECT pg_advisory_lock($1)`
227		if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
228			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
229		}
230
231		return nil
232	})
233}
234
235func (p *Postgres) Unlock() error {
236	return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
237		aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
238		if err != nil {
239			return err
240		}
241
242		query := `SELECT pg_advisory_unlock($1)`
243		if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
244			return &database.Error{OrigErr: err, Query: []byte(query)}
245		}
246		return nil
247	})
248}
249
250func (p *Postgres) Run(migration io.Reader) error {
251	if p.config.MultiStatementEnabled {
252		var err error
253		if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
254			if err = p.runStatement(m); err != nil {
255				return false
256			}
257			return true
258		}); e != nil {
259			return e
260		}
261		return err
262	}
263	migr, err := ioutil.ReadAll(migration)
264	if err != nil {
265		return err
266	}
267	return p.runStatement(migr)
268}
269
270func (p *Postgres) runStatement(statement []byte) error {
271	ctx := context.Background()
272	if p.config.StatementTimeout != 0 {
273		var cancel context.CancelFunc
274		ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
275		defer cancel()
276	}
277	query := string(statement)
278	if strings.TrimSpace(query) == "" {
279		return nil
280	}
281	if _, err := p.conn.ExecContext(ctx, query); err != nil {
282		if pgErr, ok := err.(*pq.Error); ok {
283			var line uint
284			var col uint
285			var lineColOK bool
286			if pgErr.Position != "" {
287				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
288					line, col, lineColOK = computeLineFromPos(query, int(pos))
289				}
290			}
291			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
292			if lineColOK {
293				message = fmt.Sprintf("%s (column %d)", message, col)
294			}
295			if pgErr.Detail != "" {
296				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
297			}
298			return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
299		}
300		return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
301	}
302	return nil
303}
304
305func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
306	// replace crlf with lf
307	s = strings.Replace(s, "\r\n", "\n", -1)
308	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
309	runes := []rune(s)
310	if pos > len(runes) {
311		return 0, 0, false
312	}
313	sel := runes[:pos]
314	line = uint(runesCount(sel, newLine) + 1)
315	col = uint(pos - 1 - runesLastIndex(sel, newLine))
316	return line, col, true
317}
318
// newLine is the line-feed character used as the line separator when a
// character position is translated into a line and column.
const newLine = '\n'
320
// runesCount returns how many elements of input are equal to target.
func runesCount(input []rune, target rune) int {
	total := 0
	for i := 0; i < len(input); i++ {
		if input[i] == target {
			total++
		}
	}
	return total
}
330
// runesLastIndex returns the index of the last element of input equal to
// target, or -1 when target does not occur.
func runesLastIndex(input []rune, target rune) int {
	idx := len(input)
	for idx > 0 {
		idx--
		if input[idx] == target {
			return idx
		}
	}
	return -1
}
339
340func (p *Postgres) SetVersion(version int, dirty bool) error {
341	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
342	if err != nil {
343		return &database.Error{OrigErr: err, Err: "transaction start failed"}
344	}
345
346	query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName)
347	if _, err := tx.Exec(query); err != nil {
348		if errRollback := tx.Rollback(); errRollback != nil {
349			err = multierror.Append(err, errRollback)
350		}
351		return &database.Error{OrigErr: err, Query: []byte(query)}
352	}
353
354	// Also re-write the schema version for nil dirty versions to prevent
355	// empty schema version for failed down migration on the first migration
356	// See: https://github.com/golang-migrate/migrate/issues/330
357	if version >= 0 || (version == database.NilVersion && dirty) {
358		query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
359		if _, err := tx.Exec(query, version, dirty); err != nil {
360			if errRollback := tx.Rollback(); errRollback != nil {
361				err = multierror.Append(err, errRollback)
362			}
363			return &database.Error{OrigErr: err, Query: []byte(query)}
364		}
365	}
366
367	if err := tx.Commit(); err != nil {
368		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
369	}
370
371	return nil
372}
373
374func (p *Postgres) Version() (version int, dirty bool, err error) {
375	query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
376	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
377	switch {
378	case err == sql.ErrNoRows:
379		return database.NilVersion, false, nil
380
381	case err != nil:
382		if e, ok := err.(*pq.Error); ok {
383			if e.Code.Name() == "undefined_table" {
384				return database.NilVersion, false, nil
385			}
386		}
387		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
388
389	default:
390		return version, dirty, nil
391	}
392}
393
394func (p *Postgres) Drop() (err error) {
395	// select all tables in current schema
396	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
397	tables, err := p.conn.QueryContext(context.Background(), query)
398	if err != nil {
399		return &database.Error{OrigErr: err, Query: []byte(query)}
400	}
401	defer func() {
402		if errClose := tables.Close(); errClose != nil {
403			err = multierror.Append(err, errClose)
404		}
405	}()
406
407	// delete one table after another
408	tableNames := make([]string, 0)
409	for tables.Next() {
410		var tableName string
411		if err := tables.Scan(&tableName); err != nil {
412			return err
413		}
414		if len(tableName) > 0 {
415			tableNames = append(tableNames, tableName)
416		}
417	}
418	if err := tables.Err(); err != nil {
419		return &database.Error{OrigErr: err, Query: []byte(query)}
420	}
421
422	if len(tableNames) > 0 {
423		// delete one by one ...
424		for _, t := range tableNames {
425			query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
426			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
427				return &database.Error{OrigErr: err, Query: []byte(query)}
428			}
429		}
430	}
431
432	return nil
433}
434
435// ensureVersionTable checks if versions table exists and, if not, creates it.
436// Note that this function locks the database, which deviates from the usual
437// convention of "caller locks" in the Postgres type.
438func (p *Postgres) ensureVersionTable() (err error) {
439	if err = p.Lock(); err != nil {
440		return err
441	}
442
443	defer func() {
444		if e := p.Unlock(); e != nil {
445			if err == nil {
446				err = e
447			} else {
448				err = multierror.Append(err, e)
449			}
450		}
451	}()
452
453	// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
454	// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
455	// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
456	// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
457	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
458	row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
459
460	var count int
461	err = row.Scan(&count)
462	if err != nil {
463		return &database.Error{OrigErr: err, Query: []byte(query)}
464	}
465
466	if count == 1 {
467		return nil
468	}
469
470	query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
471	if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
472		return &database.Error{OrigErr: err, Query: []byte(query)}
473	}
474
475	return nil
476}
477