1//go:build go1.9 2// +build go1.9 3 4package postgres 5 6import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "go.uber.org/atomic" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "regexp" 15 "strconv" 16 "strings" 17 "time" 18 19 "github.com/golang-migrate/migrate/v4" 20 "github.com/golang-migrate/migrate/v4/database" 21 "github.com/golang-migrate/migrate/v4/database/multistmt" 22 multierror "github.com/hashicorp/go-multierror" 23 "github.com/lib/pq" 24) 25 26func init() { 27 db := Postgres{} 28 database.Register("postgres", &db) 29 database.Register("postgresql", &db) 30} 31 32var ( 33 multiStmtDelimiter = []byte(";") 34 35 DefaultMigrationsTable = "schema_migrations" 36 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 37) 38 39var ( 40 ErrNilConfig = fmt.Errorf("no config") 41 ErrNoDatabaseName = fmt.Errorf("no database name") 42 ErrNoSchema = fmt.Errorf("no schema") 43 ErrDatabaseDirty = fmt.Errorf("database is dirty") 44) 45 46type Config struct { 47 MigrationsTable string 48 MigrationsTableQuoted bool 49 MultiStatementEnabled bool 50 DatabaseName string 51 SchemaName string 52 migrationsSchemaName string 53 migrationsTableName string 54 StatementTimeout time.Duration 55 MultiStatementMaxSize int 56} 57 58type Postgres struct { 59 // Locking and unlocking need to use the same connection 60 conn *sql.Conn 61 db *sql.DB 62 isLocked atomic.Bool 63 64 // Open and WithInstance need to guarantee that config is never nil 65 config *Config 66} 67 68func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 69 if config == nil { 70 return nil, ErrNilConfig 71 } 72 73 if err := instance.Ping(); err != nil { 74 return nil, err 75 } 76 77 if config.DatabaseName == "" { 78 query := `SELECT CURRENT_DATABASE()` 79 var databaseName string 80 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 81 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 82 } 83 84 if len(databaseName) == 0 { 85 return nil, ErrNoDatabaseName 86 } 87 88 config.DatabaseName = databaseName 89 } 90 91 if config.SchemaName == "" { 92 query := `SELECT CURRENT_SCHEMA()` 93 var schemaName string 94 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 95 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 96 } 97 98 if len(schemaName) == 0 { 99 return nil, ErrNoSchema 100 } 101 102 config.SchemaName = schemaName 103 } 104 105 if len(config.MigrationsTable) == 0 { 106 config.MigrationsTable = DefaultMigrationsTable 107 } 108 109 config.migrationsSchemaName = config.SchemaName 110 config.migrationsTableName = config.MigrationsTable 111 if config.MigrationsTableQuoted { 112 re := regexp.MustCompile(`"(.*?)"`) 113 result := re.FindAllStringSubmatch(config.MigrationsTable, -1) 114 config.migrationsTableName = result[len(result)-1][1] 115 if len(result) == 2 { 116 config.migrationsSchemaName = result[0][1] 117 } else if len(result) > 2 { 118 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable) 119 } 120 } 121 122 conn, err := instance.Conn(context.Background()) 123 124 if err != nil { 125 return nil, err 126 } 127 128 px := &Postgres{ 129 conn: conn, 130 db: instance, 131 config: config, 132 } 133 134 if err := px.ensureVersionTable(); err != nil { 135 return nil, err 136 } 137 138 return px, nil 139} 140 141func (p *Postgres) Open(url string) (database.Driver, error) { 142 purl, err := nurl.Parse(url) 143 if err != nil { 144 return nil, err 145 } 146 147 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String()) 148 if err != nil { 149 return nil, err 150 } 151 152 migrationsTable := purl.Query().Get("x-migrations-table") 153 migrationsTableQuoted := false 154 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { 155 migrationsTableQuoted, err = strconv.ParseBool(s) 156 if err != nil { 157 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) 158 } 159 } 160 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { 161 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) 162 } 163 164 statementTimeoutString := purl.Query().Get("x-statement-timeout") 165 statementTimeout := 0 166 if statementTimeoutString != "" { 167 statementTimeout, err = strconv.Atoi(statementTimeoutString) 168 if err != nil { 169 return nil, err 170 } 171 } 172 173 multiStatementMaxSize := DefaultMultiStatementMaxSize 174 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { 175 multiStatementMaxSize, err = strconv.Atoi(s) 176 if err != nil { 177 return nil, err 178 } 179 if multiStatementMaxSize <= 0 { 180 multiStatementMaxSize = DefaultMultiStatementMaxSize 181 } 182 } 183 184 multiStatementEnabled := false 185 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { 186 multiStatementEnabled, err = strconv.ParseBool(s) 187 if err != nil { 188 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err) 189 } 190 } 191 192 px, err := WithInstance(db, &Config{ 193 DatabaseName: purl.Path, 194 MigrationsTable: migrationsTable, 195 MigrationsTableQuoted: migrationsTableQuoted, 196 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 197 MultiStatementEnabled: multiStatementEnabled, 198 MultiStatementMaxSize: multiStatementMaxSize, 199 }) 200 201 if err != nil { 202 return nil, err 203 } 204 205 return px, nil 206} 207 208func (p *Postgres) Close() error { 209 connErr := p.conn.Close() 210 dbErr := p.db.Close() 211 if connErr != nil || dbErr != nil { 212 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 213 } 214 return nil 215} 216 217// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS 218func (p *Postgres) Lock() error { 219 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { 220 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 221 if err != nil { 222 return err 223 } 224 225 // This will wait indefinitely until the lock can be acquired. 226 query := `SELECT pg_advisory_lock($1)` 227 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 228 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 229 } 230 231 return nil 232 }) 233} 234 235func (p *Postgres) Unlock() error { 236 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { 237 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) 238 if err != nil { 239 return err 240 } 241 242 query := `SELECT pg_advisory_unlock($1)` 243 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { 244 return &database.Error{OrigErr: err, Query: []byte(query)} 245 } 246 return nil 247 }) 248} 249 250func (p *Postgres) Run(migration io.Reader) error { 251 if p.config.MultiStatementEnabled { 252 var err error 253 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { 254 if err = p.runStatement(m); err != nil { 255 return false 256 } 257 return true 258 }); e != nil { 259 return e 260 } 261 return err 262 } 263 migr, err := ioutil.ReadAll(migration) 264 if err != nil { 265 return err 266 } 267 return p.runStatement(migr) 268} 269 270func (p *Postgres) runStatement(statement []byte) error { 271 ctx := context.Background() 272 if p.config.StatementTimeout != 0 { 273 var cancel context.CancelFunc 274 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout) 275 defer cancel() 276 } 277 query := string(statement) 278 if strings.TrimSpace(query) == "" { 279 return nil 280 } 281 if _, err := p.conn.ExecContext(ctx, query); err != nil { 282 if pgErr, ok := err.(*pq.Error); ok { 283 var line uint 284 var col uint 285 var lineColOK bool 286 if pgErr.Position != "" { 287 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { 288 line, col, lineColOK = computeLineFromPos(query, int(pos)) 289 } 290 } 291 message := fmt.Sprintf("migration failed: %s", pgErr.Message) 292 if lineColOK { 293 message = fmt.Sprintf("%s (column %d)", message, col) 294 } 295 if pgErr.Detail != "" { 296 message = fmt.Sprintf("%s, %s", message, pgErr.Detail) 297 } 298 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line} 299 } 300 return database.Error{OrigErr: err, Err: "migration failed", Query: statement} 301 } 302 return nil 303} 304 305func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { 306 // replace crlf with lf 307 s = strings.Replace(s, "\r\n", "\n", -1) 308 // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes 309 runes := []rune(s) 310 if pos > len(runes) { 311 return 0, 0, false 312 } 313 sel := runes[:pos] 314 line = uint(runesCount(sel, newLine) + 1) 315 col = uint(pos - 1 - runesLastIndex(sel, newLine)) 316 return line, col, true 317} 318 319const newLine = '\n' 320 321func runesCount(input []rune, target rune) int { 322 var count int 323 for _, r := range input { 324 if r == target { 325 count++ 326 } 327 } 328 return count 329} 330 331func runesLastIndex(input []rune, target rune) int { 332 for i := len(input) - 1; i >= 0; i-- { 333 if input[i] == target { 334 return i 335 } 336 } 337 return -1 338} 339 340func (p *Postgres) SetVersion(version int, dirty bool) error { 341 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) 342 if err != nil { 343 return &database.Error{OrigErr: err, Err: "transaction start failed"} 344 } 345 346 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) 347 if _, err := tx.Exec(query); err != nil { 348 if errRollback := tx.Rollback(); errRollback != nil { 349 err = multierror.Append(err, errRollback) 350 } 351 return &database.Error{OrigErr: err, Query: []byte(query)} 352 } 353 354 // Also re-write the schema version for nil dirty versions to prevent 355 // empty schema version for failed down migration on the first migration 356 // See: https://github.com/golang-migrate/migrate/issues/330 357 if version >= 0 || (version == database.NilVersion && dirty) { 358 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)` 359 if _, err := tx.Exec(query, version, dirty); err != nil { 360 if errRollback := tx.Rollback(); errRollback != nil { 361 err = multierror.Append(err, errRollback) 362 } 363 return &database.Error{OrigErr: err, Query: []byte(query)} 364 } 365 } 366 367 if err := tx.Commit(); err != nil { 368 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 369 } 370 371 return nil 372} 373 374func (p *Postgres) Version() (version int, dirty bool, err error) { 375 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1` 376 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 377 switch { 378 case err == sql.ErrNoRows: 379 return database.NilVersion, false, nil 380 381 case err != nil: 382 if e, ok := err.(*pq.Error); ok { 383 if e.Code.Name() == "undefined_table" { 384 return database.NilVersion, false, nil 385 } 386 } 387 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 388 389 default: 390 return version, dirty, nil 391 } 392} 393 394func (p *Postgres) Drop() (err error) { 395 // select all tables in current schema 396 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` 397 tables, err := p.conn.QueryContext(context.Background(), query) 398 if err != nil { 399 return &database.Error{OrigErr: err, Query: []byte(query)} 400 } 401 defer func() { 402 if errClose := tables.Close(); errClose != nil { 403 err = multierror.Append(err, errClose) 404 } 405 }() 406 407 // delete one table after another 408 tableNames := make([]string, 0) 409 for tables.Next() { 410 var tableName string 411 if err := tables.Scan(&tableName); err != nil { 412 return err 413 } 414 if len(tableName) > 0 { 415 tableNames = append(tableNames, tableName) 416 } 417 } 418 if err := tables.Err(); err != nil { 419 return &database.Error{OrigErr: err, Query: []byte(query)} 420 } 421 422 if len(tableNames) > 0 { 423 // delete one by one ... 424 for _, t := range tableNames { 425 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` 426 if _, err := p.conn.ExecContext(context.Background(), query); err != nil { 427 return &database.Error{OrigErr: err, Query: []byte(query)} 428 } 429 } 430 } 431 432 return nil 433} 434 435// ensureVersionTable checks if versions table exists and, if not, creates it. 436// Note that this function locks the database, which deviates from the usual 437// convention of "caller locks" in the Postgres type. 438func (p *Postgres) ensureVersionTable() (err error) { 439 if err = p.Lock(); err != nil { 440 return err 441 } 442 443 defer func() { 444 if e := p.Unlock(); e != nil { 445 if err == nil { 446 err = e 447 } else { 448 err = multierror.Append(err, e) 449 } 450 } 451 }() 452 453 // This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres 454 // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the 455 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. 456 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 457 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1` 458 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName) 459 460 var count int 461 err = row.Scan(&count) 462 if err != nil { 463 return &database.Error{OrigErr: err, Query: []byte(query)} 464 } 465 466 if count == 1 { 467 return nil 468 } 469 470 query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` 471 if _, err = p.conn.ExecContext(context.Background(), query); err != nil { 472 return &database.Error{OrigErr: err, Query: []byte(query)} 473 } 474 475 return nil 476} 477