1package spanner
2
3import (
4	"errors"
5	"fmt"
6	"io"
7	"io/ioutil"
8	"log"
9	nurl "net/url"
10	"regexp"
11	"strconv"
12	"strings"
13
14	"context"
15
16	"cloud.google.com/go/spanner"
17	sdb "cloud.google.com/go/spanner/admin/database/apiv1"
18	"cloud.google.com/go/spanner/spansql"
19
20	"github.com/golang-migrate/migrate/v4"
21	"github.com/golang-migrate/migrate/v4/database"
22
23	"github.com/hashicorp/go-multierror"
24	uatomic "go.uber.org/atomic"
25	"google.golang.org/api/iterator"
26	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
27)
28
29func init() {
30	db := Spanner{}
31	database.Register("spanner", &db)
32}
33
// DefaultMigrationsTable is the migrations table name WithInstance falls back
// to when Config.MigrationsTable is empty.
const DefaultMigrationsTable = "SchemaMigrations"

// States of the driver's in-process lock flag, flipped by Lock and Unlock.
const (
	unlockedVal = 0 // lock is free
	lockedVal   = 1 // lock is held
)

// Driver errors.
var (
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = errors.New("no config")

	// ErrNoDatabaseName is returned by WithInstance when Config.DatabaseName
	// is empty.
	ErrNoDatabaseName = errors.New("no database name")

	// ErrNoSchema is exported for callers; no code in this file returns it.
	ErrNoSchema = errors.New("no schema")

	// ErrDatabaseDirty is exported for callers; no code in this file returns it.
	ErrDatabaseDirty = errors.New("database is dirty")

	// ErrLockHeld is returned by Lock when the lock is already taken.
	ErrLockHeld = errors.New("unable to obtain lock")

	// ErrLockNotHeld is returned by Unlock when the lock is not taken.
	ErrLockNotHeld = errors.New("unable to release already released lock")
)
51
// Config holds the settings of a Spanner migration driver.
type Config struct {
	// MigrationsTable names the table that stores the schema version.
	// WithInstance substitutes DefaultMigrationsTable when it is empty.
	MigrationsTable string

	// DatabaseName identifies the target database. It is sent as the
	// Database field of every DDL request and must not be empty.
	DatabaseName string

	// CleanStatements makes Run parse each migration with spansql before
	// sending it to Spanner, so the statements are sent reformatted and with
	// comments removed instead of as the raw migration text.
	CleanStatements bool
}
62
63// Spanner implements database.Driver for Google Cloud Spanner
64type Spanner struct {
65	db *DB
66
67	config *Config
68
69	lock *uatomic.Uint32
70}
71
72type DB struct {
73	admin *sdb.DatabaseAdminClient
74	data  *spanner.Client
75}
76
77func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
78	return &DB{
79		admin: &admin,
80		data:  &data,
81	}
82}
83
84// WithInstance implements database.Driver
85func WithInstance(instance *DB, config *Config) (database.Driver, error) {
86	if config == nil {
87		return nil, ErrNilConfig
88	}
89
90	if len(config.DatabaseName) == 0 {
91		return nil, ErrNoDatabaseName
92	}
93
94	if len(config.MigrationsTable) == 0 {
95		config.MigrationsTable = DefaultMigrationsTable
96	}
97
98	sx := &Spanner{
99		db:     instance,
100		config: config,
101		lock:   uatomic.NewUint32(unlockedVal),
102	}
103
104	if err := sx.ensureVersionTable(); err != nil {
105		return nil, err
106	}
107
108	return sx, nil
109}
110
111// Open implements database.Driver
112func (s *Spanner) Open(url string) (database.Driver, error) {
113	purl, err := nurl.Parse(url)
114	if err != nil {
115		return nil, err
116	}
117
118	ctx := context.Background()
119
120	adminClient, err := sdb.NewDatabaseAdminClient(ctx)
121	if err != nil {
122		return nil, err
123	}
124	dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
125	dataClient, err := spanner.NewClient(ctx, dbname)
126	if err != nil {
127		log.Fatal(err)
128	}
129
130	migrationsTable := purl.Query().Get("x-migrations-table")
131
132	cleanQuery := purl.Query().Get("x-clean-statements")
133	clean := false
134	if cleanQuery != "" {
135		clean, err = strconv.ParseBool(cleanQuery)
136		if err != nil {
137			return nil, err
138		}
139	}
140
141	db := &DB{admin: adminClient, data: dataClient}
142	return WithInstance(db, &Config{
143		DatabaseName:    dbname,
144		MigrationsTable: migrationsTable,
145		CleanStatements: clean,
146	})
147}
148
149// Close implements database.Driver
150func (s *Spanner) Close() error {
151	s.db.data.Close()
152	return s.db.admin.Close()
153}
154
155// Lock implements database.Driver but doesn't do anything because Spanner only
156// enqueues the UpdateDatabaseDdlRequest.
157func (s *Spanner) Lock() error {
158	if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped {
159		return nil
160	}
161	return ErrLockHeld
162}
163
164// Unlock implements database.Driver but no action required, see Lock.
165func (s *Spanner) Unlock() error {
166	if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped {
167		return nil
168	}
169	return ErrLockNotHeld
170}
171
172// Run implements database.Driver
173func (s *Spanner) Run(migration io.Reader) error {
174	migr, err := ioutil.ReadAll(migration)
175	if err != nil {
176		return err
177	}
178
179	stmts := []string{string(migr)}
180	if s.config.CleanStatements {
181		stmts, err = cleanStatements(migr)
182		if err != nil {
183			return err
184		}
185	}
186
187	ctx := context.Background()
188	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
189		Database:   s.config.DatabaseName,
190		Statements: stmts,
191	})
192
193	if err != nil {
194		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
195	}
196
197	if err := op.Wait(ctx); err != nil {
198		return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
199	}
200
201	return nil
202}
203
204// SetVersion implements database.Driver
205func (s *Spanner) SetVersion(version int, dirty bool) error {
206	ctx := context.Background()
207
208	_, err := s.db.data.ReadWriteTransaction(ctx,
209		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
210			m := []*spanner.Mutation{
211				spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
212				spanner.Insert(s.config.MigrationsTable,
213					[]string{"Version", "Dirty"},
214					[]interface{}{version, dirty},
215				)}
216			return txn.BufferWrite(m)
217		})
218	if err != nil {
219		return &database.Error{OrigErr: err}
220	}
221
222	return nil
223}
224
225// Version implements database.Driver
226func (s *Spanner) Version() (version int, dirty bool, err error) {
227	ctx := context.Background()
228
229	stmt := spanner.Statement{
230		SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
231	}
232	iter := s.db.data.Single().Query(ctx, stmt)
233	defer iter.Stop()
234
235	row, err := iter.Next()
236	switch err {
237	case iterator.Done:
238		return database.NilVersion, false, nil
239	case nil:
240		var v int64
241		if err = row.Columns(&v, &dirty); err != nil {
242			return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
243		}
244		version = int(v)
245	default:
246		return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
247	}
248
249	return version, dirty, nil
250}
251
252var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
253
254// Drop implements database.Driver. Retrieves the database schema first and
255// creates statements to drop the indexes and tables accordingly.
256// Note: The drop statements are created in reverse order to how they're
257// provided in the schema. Assuming the schema describes how the database can
258// be "build up", it seems logical to "unbuild" the database simply by going the
259// opposite direction. More testing
260func (s *Spanner) Drop() error {
261	ctx := context.Background()
262	res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
263		Database: s.config.DatabaseName,
264	})
265	if err != nil {
266		return &database.Error{OrigErr: err, Err: "drop failed"}
267	}
268	if len(res.Statements) == 0 {
269		return nil
270	}
271
272	stmts := make([]string, 0)
273	for i := len(res.Statements) - 1; i >= 0; i-- {
274		s := res.Statements[i]
275		m := nameMatcher.FindSubmatch([]byte(s))
276
277		if len(m) == 0 {
278			continue
279		} else if tbl := m[2]; len(tbl) > 0 {
280			stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
281		} else if idx := m[4]; len(idx) > 0 {
282			stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
283		}
284	}
285
286	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
287		Database:   s.config.DatabaseName,
288		Statements: stmts,
289	})
290	if err != nil {
291		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
292	}
293	if err := op.Wait(ctx); err != nil {
294		return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
295	}
296
297	return nil
298}
299
300// ensureVersionTable checks if versions table exists and, if not, creates it.
301// Note that this function locks the database, which deviates from the usual
302// convention of "caller locks" in the Spanner type.
303func (s *Spanner) ensureVersionTable() (err error) {
304	if err = s.Lock(); err != nil {
305		return err
306	}
307
308	defer func() {
309		if e := s.Unlock(); e != nil {
310			if err == nil {
311				err = e
312			} else {
313				err = multierror.Append(err, e)
314			}
315		}
316	}()
317
318	ctx := context.Background()
319	tbl := s.config.MigrationsTable
320	iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
321	if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
322		return nil
323	}
324
325	stmt := fmt.Sprintf(`CREATE TABLE %s (
326    Version INT64 NOT NULL,
327    Dirty    BOOL NOT NULL
328	) PRIMARY KEY(Version)`, tbl)
329
330	op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
331		Database:   s.config.DatabaseName,
332		Statements: []string{stmt},
333	})
334
335	if err != nil {
336		return &database.Error{OrigErr: err, Query: []byte(stmt)}
337	}
338	if err := op.Wait(ctx); err != nil {
339		return &database.Error{OrigErr: err, Query: []byte(stmt)}
340	}
341
342	return nil
343}
344
345func cleanStatements(migration []byte) ([]string, error) {
346	// The Spanner GCP backend does not yet support comments for the UpdateDatabaseDdl RPC
347	// (see https://issuetracker.google.com/issues/159730604) we use
348	// spansql to parse the DDL and output valid stamements without comments
349	ddl, err := spansql.ParseDDL("", string(migration))
350	if err != nil {
351		return nil, err
352	}
353	stmts := make([]string, 0, len(ddl.List))
354	for _, stmt := range ddl.List {
355		stmts = append(stmts, stmt.SQL())
356	}
357	return stmts, nil
358}
359