1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19// This file contains the implementation of the Spanner fake itself,
20// namely the part behind the RPC interface.
21
22// TODO: missing transactionality in a serious way!
23
24import (
25	"bytes"
26	"encoding/base64"
27	"fmt"
28	"sort"
29	"strconv"
30	"strings"
31	"sync"
32	"time"
33
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/status"
36
37	structpb "github.com/golang/protobuf/ptypes/struct"
38
39	"cloud.google.com/go/civil"
40	"cloud.google.com/go/spanner/spansql"
41)
42
// database is the in-memory state of one fake Spanner database: its tables,
// plus the names (only) of its indexes and views. The maps are created
// lazily by ApplyDDL.
type database struct {
	// mu guards lastTS, tables, indexes and views.
	mu      sync.Mutex
	lastTS  time.Time // last commit timestamp
	tables  map[spansql.ID]*table
	indexes map[spansql.ID]struct{} // only record their existence
	views   map[spansql.ID]struct{} // only record their existence

	// rwMu serializes read-write transactions: transaction.Start acquires it
	// and the unlock func it records (called by Commit/Rollback) releases it.
	rwMu sync.Mutex // held by read-write transactions
}
52
// table holds the schema and the rows of a single table.
type table struct {
	// mu is locked by the operations that read or modify this table's rows
	// and column metadata (writes, deletes, reads, and column DDL).
	mu sync.Mutex

	// Information about the table columns.
	// They are reordered on table creation so the primary key columns come first.
	cols      []colInfo
	colIndex  map[spansql.ID]int         // col name to index
	origIndex map[spansql.ID]int         // original index of each column upon construction
	pkCols    int                        // number of primary key columns (may be 0)
	pkDesc    []bool                     // whether each primary key column is in descending order
	rdw       *spansql.RowDeletionPolicy // RowDeletionPolicy of this table (may be nil)

	// Rows are stored in primary key order (as defined by rowCmp with pkDesc).
	// Each row has one element per entry in cols, in the same order.
	rows []row
}
68
// colInfo represents information about a column in a table or result set.
// Generated holds the generation expression of a generated column and is nil
// for ordinary columns; values cannot be written to generated columns.
type colInfo struct {
	Name      spansql.ID
	Type      spansql.Type
	Generated spansql.Expr
	NotNull   bool            // only set for table columns
	AggIndex  int             // Index+1 of SELECT list for which this is an aggregate value.
	Alias     spansql.PathExp // an alternate name for this column (result sets only)
}
78
// commitTimestampSentinel is a sentinel value for TIMESTAMP fields with allow_commit_timestamp=true.
// It is accepted, but never stored.
//
// valForType returns it when a TIMESTAMP value is the string
// "spanner.commit_timestamp()" (compared case-insensitively), and writeValues
// replaces it with the writing transaction's commit timestamp. Only its
// pointer identity is meaningful.
// NOTE(review): the allow_commit_timestamp column option is not checked
// before the sentinel is accepted.
var commitTimestampSentinel = &struct{}{}
82
// transaction records information about a running transaction.
// This is not safe for concurrent use.
type transaction struct {
	// readOnly is whether this transaction was constructed
	// for read-only use, and should yield errors if used
	// to perform a mutation.
	readOnly bool

	// d is the database the transaction belongs to; it may be nil for a
	// read-only transaction, so check before dereferencing.
	// commitTimestamp is chosen by Start.
	// unlock is set by Start to release d.rwMu and is invoked by Commit/Rollback.
	d               *database
	commitTimestamp time.Time // not set if readOnly
	unlock          func()    // may be nil
}
95
96func (d *database) NewReadOnlyTransaction() *transaction {
97	return &transaction{
98		readOnly: true,
99	}
100}
101
102func (d *database) NewTransaction() *transaction {
103	return &transaction{
104		d: d,
105	}
106}
107
108// Start starts the transaction and commits to a specific commit timestamp.
109// This also locks out any other read-write transaction on this database
110// until Commit/Rollback are called.
111func (tx *transaction) Start() {
112	// Commit timestamps are only guaranteed to be unique
113	// when transactions write to overlapping sets of fields.
114	// This simulated database exceeds that guarantee.
115
116	// Grab rwMu for the duration of this transaction.
117	// Take it before d.mu so we don't hold that lock
118	// while waiting for d.rwMu, which is held for longer.
119	tx.d.rwMu.Lock()
120
121	tx.d.mu.Lock()
122	const tsRes = 1 * time.Microsecond
123	now := time.Now().UTC().Truncate(tsRes)
124	if !now.After(tx.d.lastTS) {
125		now = tx.d.lastTS.Add(tsRes)
126	}
127	tx.d.lastTS = now
128	tx.d.mu.Unlock()
129
130	tx.commitTimestamp = now
131	tx.unlock = tx.d.rwMu.Unlock
132}
133
134func (tx *transaction) checkMutable() error {
135	if tx.readOnly {
136		// TODO: is this the right status?
137		return status.Errorf(codes.InvalidArgument, "transaction is read-only")
138	}
139	return nil
140}
141
142func (tx *transaction) Commit() (time.Time, error) {
143	if tx.unlock != nil {
144		tx.unlock()
145	}
146	return tx.commitTimestamp, nil
147}
148
149func (tx *transaction) Rollback() {
150	if tx.unlock != nil {
151		tx.unlock()
152	}
153	// TODO: actually rollback
154}
155
/*
row represents a list of data elements.

The mapping between Spanner types and Go types internal to this package is:
	BOOL		bool
	INT64		int64
	FLOAT64		float64
	STRING		string
	BYTES		[]byte
	DATE		civil.Date
	TIMESTAMP	time.Time (location set to UTC)
	ARRAY<T>	[]interface{}
	STRUCT		TODO
*/
type row []interface{}

// copyDataElem returns a copy of r[index] that does not alias r's storage
// for BYTES ([]byte) or ARRAY ([]interface{}) values; []byte elements inside
// an array are duplicated as well. Other element kinds are returned as-is.
// A nil slice stays nil and an empty slice stays empty but non-nil, so the
// copy is indistinguishable from the original.
func (r row) copyDataElem(index int) interface{} {
	v := r[index]
	switch x := v.(type) {
	case []byte:
		if x != nil {
			v = append(make([]byte, 0, len(x)), x...)
		}
	case []interface{}:
		if x != nil {
			arr := make([]interface{}, len(x))
			for i, e := range x {
				if b, ok := e.([]byte); ok && b != nil {
					e = append(make([]byte, 0, len(b)), b...)
				}
				arr[i] = e
			}
			v = arr
		}
	}
	return v
}

// copyAllData returns a copy of the whole row; see copyDataElem for how
// individual elements are copied.
func (r row) copyAllData() row {
	dst := make(row, 0, len(r))
	for i := range r {
		dst = append(dst, r.copyDataElem(i))
	}
	return dst
}

// copyData returns a copy of the elements of r at the given indexes, in that
// order. It returns nil if indexes is empty.
func (r row) copyData(indexes []int) row {
	if len(indexes) == 0 {
		return nil
	}
	dst := make(row, 0, len(indexes))
	for _, i := range indexes {
		dst = append(dst, r.copyDataElem(i))
	}
	return dst
}
201
202func (d *database) LastCommitTimestamp() time.Time {
203	d.mu.Lock()
204	defer d.mu.Unlock()
205	return d.lastTS
206}
207
208func (d *database) GetDDL() []spansql.DDLStmt {
209	// This lacks fidelity, but captures the details we support.
210	d.mu.Lock()
211	defer d.mu.Unlock()
212
213	var stmts []spansql.DDLStmt
214
215	for name, t := range d.tables {
216		ct := &spansql.CreateTable{
217			Name: name,
218		}
219
220		t.mu.Lock()
221		for i, col := range t.cols {
222			ct.Columns = append(ct.Columns, spansql.ColumnDef{
223				Name:    col.Name,
224				Type:    col.Type,
225				NotNull: col.NotNull,
226				// TODO: AllowCommitTimestamp
227			})
228			if i < t.pkCols {
229				ct.PrimaryKey = append(ct.PrimaryKey, spansql.KeyPart{
230					Column: col.Name,
231					Desc:   t.pkDesc[i],
232				})
233			}
234		}
235		t.mu.Unlock()
236
237		stmts = append(stmts, ct)
238	}
239
240	return stmts
241}
242
243func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status {
244	d.mu.Lock()
245	defer d.mu.Unlock()
246
247	// Lazy init.
248	if d.tables == nil {
249		d.tables = make(map[spansql.ID]*table)
250	}
251	if d.indexes == nil {
252		d.indexes = make(map[spansql.ID]struct{})
253	}
254	if d.views == nil {
255		d.views = make(map[spansql.ID]struct{})
256	}
257
258	switch stmt := stmt.(type) {
259	default:
260		return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt)
261	case *spansql.CreateTable:
262		if _, ok := d.tables[stmt.Name]; ok {
263			return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name)
264		}
265		if len(stmt.PrimaryKey) == 0 {
266			return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name)
267		}
268
269		// TODO: check stmt.Interleave details.
270
271		// Record original column ordering.
272		orig := make(map[spansql.ID]int)
273		for i, col := range stmt.Columns {
274			orig[col.Name] = i
275		}
276
277		// Move primary keys first, preserving their order.
278		pk := make(map[spansql.ID]int)
279		var pkDesc []bool
280		for i, kp := range stmt.PrimaryKey {
281			pk[kp.Column] = -1000 + i
282			pkDesc = append(pkDesc, kp.Desc)
283		}
284		sort.SliceStable(stmt.Columns, func(i, j int) bool {
285			a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name]
286			return a < b
287		})
288
289		t := &table{
290			colIndex:  make(map[spansql.ID]int),
291			origIndex: orig,
292			pkCols:    len(pk),
293			pkDesc:    pkDesc,
294		}
295		for _, cd := range stmt.Columns {
296			if st := t.addColumn(cd, true); st.Code() != codes.OK {
297				return st
298			}
299		}
300		for col := range pk {
301			if _, ok := t.colIndex[col]; !ok {
302				return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col)
303			}
304		}
305		t.rdw = stmt.RowDeletionPolicy
306		d.tables[stmt.Name] = t
307		return nil
308	case *spansql.CreateIndex:
309		if _, ok := d.indexes[stmt.Name]; ok {
310			return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name)
311		}
312		d.indexes[stmt.Name] = struct{}{}
313		return nil
314	case *spansql.CreateView:
315		if !stmt.OrReplace {
316			if _, ok := d.views[stmt.Name]; ok {
317				return status.Newf(codes.AlreadyExists, "view %s already exists", stmt.Name)
318			}
319		}
320		d.views[stmt.Name] = struct{}{}
321		return nil
322	case *spansql.DropTable:
323		if _, ok := d.tables[stmt.Name]; !ok {
324			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
325		}
326		// TODO: check for indexes on this table.
327		delete(d.tables, stmt.Name)
328		return nil
329	case *spansql.DropIndex:
330		if _, ok := d.indexes[stmt.Name]; !ok {
331			return status.Newf(codes.NotFound, "no index named %s", stmt.Name)
332		}
333		delete(d.indexes, stmt.Name)
334		return nil
335	case *spansql.DropView:
336		if _, ok := d.views[stmt.Name]; !ok {
337			return status.Newf(codes.NotFound, "no view named %s", stmt.Name)
338		}
339		delete(d.views, stmt.Name)
340		return nil
341	case *spansql.AlterTable:
342		t, ok := d.tables[stmt.Name]
343		if !ok {
344			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
345		}
346		switch alt := stmt.Alteration.(type) {
347		default:
348			return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt)
349		case spansql.AddColumn:
350			if st := t.addColumn(alt.Def, false); st.Code() != codes.OK {
351				return st
352			}
353			return nil
354		case spansql.DropColumn:
355			if st := t.dropColumn(alt.Name); st.Code() != codes.OK {
356				return st
357			}
358			return nil
359		case spansql.AlterColumn:
360			if st := t.alterColumn(alt); st.Code() != codes.OK {
361				return st
362			}
363			return nil
364		case spansql.AddRowDeletionPolicy:
365			if st := t.addRowDeletionPolicy(alt); st.Code() != codes.OK {
366				return st
367			}
368			return nil
369		case spansql.ReplaceRowDeletionPolicy:
370			if st := t.replaceRowDeletionPolicy(alt); st.Code() != codes.OK {
371				return st
372			}
373			return nil
374		case spansql.DropRowDeletionPolicy:
375			if st := t.dropRowDeletionPolicy(alt); st.Code() != codes.OK {
376				return st
377			}
378			return nil
379		}
380	}
381
382}
383
384func (d *database) table(tbl spansql.ID) (*table, error) {
385	d.mu.Lock()
386	defer d.mu.Unlock()
387
388	t, ok := d.tables[tbl]
389	if !ok {
390		return nil, status.Errorf(codes.NotFound, "no table named %s", tbl)
391	}
392	return t, nil
393}
394
// writeValues executes a write operation (Insert, Update, etc.).
//
// Each entry of values is one row of values given in the order of cols. It is
// converted into a full-width table row r (columns absent from cols stay nil),
// and f is called to apply r to table t. colIndexes[j] is the table column
// index of cols[j]. After f returns, the stored row is looked up again and its
// generated columns are recomputed. t.mu is held for the whole call.
//
// Every primary key column must appear in cols. Rows applied before an error
// is returned are not undone.
func (d *database) writeValues(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error {
	if err := tx.checkMutable(); err != nil {
		return err
	}

	t, err := d.table(tbl)
	if err != nil {
		return err
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	colIndexes, err := t.colIndexes(cols)
	if err != nil {
		return err
	}
	revIndex := make(map[int]int) // table index to col index
	for j, i := range colIndexes {
		revIndex[i] = j
	}

	// Primary key columns occupy table indexes [0, pkCols); each must be written.
	for pki := 0; pki < t.pkCols; pki++ {
		_, ok := revIndex[pki]
		if !ok {
			return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name)
		}
	}

	for _, vs := range values {
		if len(vs.Values) != len(colIndexes) {
			return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes))
		}

		r := make(row, len(t.cols))
		for j, v := range vs.Values {
			i := colIndexes[j]

			if t.cols[i].Generated != nil {
				return status.Error(codes.InvalidArgument, "values can't be written to a generated column")
			}
			x, err := valForType(v, t.cols[i].Type)
			if err != nil {
				return err
			}
			// spanner.commit_timestamp() resolves to this transaction's timestamp.
			if x == commitTimestampSentinel {
				x = tx.commitTimestamp
			}
			if x == nil && t.cols[i].NotNull {
				return status.Errorf(codes.FailedPrecondition, "%s must not be NULL in table %s", t.cols[i].Name, tbl)
			}

			r[i] = x
		}
		// TODO: enforce that provided timestamp for commit_timestamp=true columns
		// are not ahead of the transaction's commit timestamp.

		if err := f(t, colIndexes, r); err != nil {
			return err
		}

		// Get row again after potential update merge to ensure we compute
		// generated columns with fresh data.
		pk := r[:t.pkCols]
		rowNum, found := t.rowForPK(pk)
		// This should never fail as the row was just inserted.
		if !found {
			return status.Error(codes.Internal, "row failed to be inserted")
		}
		row := t.rows[rowNum]
		ec := evalContext{
			cols: t.cols,
			row:  row,
		}

		// TODO: We would need to do a topological sort on dependencies
		// (i.e. what other columns the expression references) to ensure we
		// can handle generated columns which reference other generated columns
		for i, col := range t.cols {
			if col.Generated != nil {
				res, err := ec.evalExpr(col.Generated)
				if err != nil {
					return err
				}
				// row shares storage with t.rows[rowNum], so this updates the table.
				row[i] = res
			}
		}
	}

	return nil
}
487
488func (d *database) Insert(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
489	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
490		pk := r[:t.pkCols]
491		rowNum, found := t.rowForPK(pk)
492		if found {
493			return status.Errorf(codes.AlreadyExists, "row already in table")
494		}
495		t.insertRow(rowNum, r)
496		return nil
497	})
498}
499
500func (d *database) Update(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
501	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
502		if t.pkCols == 0 {
503			return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl)
504		}
505		pk := r[:t.pkCols]
506		rowNum, found := t.rowForPK(pk)
507		if !found {
508			// TODO: is this the right way to return `NOT_FOUND`?
509			return status.Errorf(codes.NotFound, "row not in table")
510		}
511
512		for _, i := range colIndexes {
513			t.rows[rowNum][i] = r[i]
514		}
515		return nil
516	})
517}
518
519func (d *database) InsertOrUpdate(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
520	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
521		pk := r[:t.pkCols]
522		rowNum, found := t.rowForPK(pk)
523		if !found {
524			// New row; do an insert.
525			t.insertRow(rowNum, r)
526		} else {
527			// Existing row; do an update.
528			for _, i := range colIndexes {
529				t.rows[rowNum][i] = r[i]
530			}
531		}
532		return nil
533	})
534}
535
536// TODO: Replace
537
538func (d *database) Delete(tx *transaction, table spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error {
539	if err := tx.checkMutable(); err != nil {
540		return err
541	}
542
543	t, err := d.table(table)
544	if err != nil {
545		return err
546	}
547
548	t.mu.Lock()
549	defer t.mu.Unlock()
550
551	if all {
552		t.rows = nil
553		return nil
554	}
555
556	for _, key := range keys {
557		pk, err := t.primaryKey(key.Values)
558		if err != nil {
559			return err
560		}
561		// Not an error if the key does not exist.
562		rowNum, found := t.rowForPK(pk)
563		if found {
564			copy(t.rows[rowNum:], t.rows[rowNum+1:])
565			t.rows = t.rows[:len(t.rows)-1]
566		}
567	}
568
569	for _, r := range keyRanges {
570		r.startKey, err = t.primaryKeyPrefix(r.start.Values)
571		if err != nil {
572			return err
573		}
574		r.endKey, err = t.primaryKeyPrefix(r.end.Values)
575		if err != nil {
576			return err
577		}
578		startRow, endRow := t.findRange(r)
579		if n := endRow - startRow; n > 0 {
580			copy(t.rows[startRow:], t.rows[endRow:])
581			t.rows = t.rows[:len(t.rows)-n]
582		}
583	}
584
585	return nil
586}
587
588// readTable executes a read option (Read, ReadAll).
589func (d *database) readTable(table spansql.ID, cols []spansql.ID, f func(*table, *rawIter, []int) error) (*rawIter, error) {
590	t, err := d.table(table)
591	if err != nil {
592		return nil, err
593	}
594
595	t.mu.Lock()
596	defer t.mu.Unlock()
597
598	colIndexes, err := t.colIndexes(cols)
599	if err != nil {
600		return nil, err
601	}
602
603	ri := &rawIter{}
604	for _, i := range colIndexes {
605		ri.cols = append(ri.cols, t.cols[i])
606	}
607	return ri, f(t, ri, colIndexes)
608}
609
610func (d *database) Read(tbl spansql.ID, cols []spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) {
611	// The real Cloud Spanner returns an error if the key set is empty by definition.
612	// That doesn't seem to be well-defined, but it is a common error to attempt a read with no keys,
613	// so catch that here and return a representative error.
614	if len(keys) == 0 && len(keyRanges) == 0 {
615		return nil, status.Error(codes.Unimplemented, "Cloud Spanner does not support reading no keys")
616	}
617
618	return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error {
619		// "If the same key is specified multiple times in the set (for
620		// example if two ranges, two keys, or a key and a range
621		// overlap), Cloud Spanner behaves as if the key were only
622		// specified once."
623		done := make(map[int]bool) // row numbers we've included in ri.
624
625		// Specific keys.
626		for _, key := range keys {
627			pk, err := t.primaryKey(key.Values)
628			if err != nil {
629				return err
630			}
631			// Not an error if the key does not exist.
632			rowNum, found := t.rowForPK(pk)
633			if !found {
634				continue
635			}
636			if done[rowNum] {
637				continue
638			}
639			done[rowNum] = true
640			ri.add(t.rows[rowNum], colIndexes)
641			if limit > 0 && len(ri.rows) >= int(limit) {
642				return nil
643			}
644		}
645
646		// Key ranges.
647		for _, r := range keyRanges {
648			var err error
649			r.startKey, err = t.primaryKeyPrefix(r.start.Values)
650			if err != nil {
651				return err
652			}
653			r.endKey, err = t.primaryKeyPrefix(r.end.Values)
654			if err != nil {
655				return err
656			}
657			startRow, endRow := t.findRange(r)
658			for rowNum := startRow; rowNum < endRow; rowNum++ {
659				if done[rowNum] {
660					continue
661				}
662				done[rowNum] = true
663				ri.add(t.rows[rowNum], colIndexes)
664				if limit > 0 && len(ri.rows) >= int(limit) {
665					return nil
666				}
667			}
668		}
669
670		return nil
671	})
672}
673
674func (d *database) ReadAll(tbl spansql.ID, cols []spansql.ID, limit int64) (*rawIter, error) {
675	return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error {
676		for _, r := range t.rows {
677			ri.add(r, colIndexes)
678			if limit > 0 && len(ri.rows) >= int(limit) {
679				break
680			}
681		}
682		return nil
683	})
684}
685
686func (t *table) addColumn(cd spansql.ColumnDef, newTable bool) *status.Status {
687	if !newTable && cd.NotNull {
688		return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL")
689	}
690
691	if _, ok := t.colIndex[cd.Name]; ok {
692		return status.Newf(codes.AlreadyExists, "column %s already exists", cd.Name)
693	}
694
695	t.mu.Lock()
696	defer t.mu.Unlock()
697
698	if len(t.rows) > 0 {
699		if cd.NotNull {
700			// TODO: what happens in this case?
701			return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet")
702		}
703		if cd.Generated != nil {
704			// TODO: should backfill the data to maintain behaviour with real spanner
705			return status.Newf(codes.Unimplemented, "can't add generated columns to non-empty tables yet")
706		}
707		for i := range t.rows {
708			t.rows[i] = append(t.rows[i], nil)
709		}
710	}
711
712	t.cols = append(t.cols, colInfo{
713		Name:    cd.Name,
714		Type:    cd.Type,
715		NotNull: cd.NotNull,
716		// TODO: We should figure out what columns the Generator expression
717		// relies on and check it is valid at this time currently it will
718		// fail when writing data instead as it is the first time we
719		// evaluate the expression.
720		Generated: cd.Generated,
721	})
722	t.colIndex[cd.Name] = len(t.cols) - 1
723	if !newTable {
724		t.origIndex[cd.Name] = len(t.cols) - 1
725	}
726
727	return nil
728}
729
// dropColumn removes the named column and its data in every row from t, then
// renumbers colIndex and origIndex so they stay dense. It returns
// InvalidArgument if the column does not exist or is part of the primary key.
func (t *table) dropColumn(name spansql.ID) *status.Status {
	// Only permit dropping non-key columns that aren't part of a secondary index.
	// We don't support indexes, so only check that it isn't part of the primary key.

	t.mu.Lock()
	defer t.mu.Unlock()

	ci, ok := t.colIndex[name]
	if !ok {
		// TODO: What's the right response code?
		return status.Newf(codes.InvalidArgument, "unknown column %q", name)
	}
	// Key columns occupy indexes [0, pkCols).
	if ci < t.pkCols {
		// TODO: What's the right response code?
		return status.Newf(codes.InvalidArgument, "can't drop primary key column %q", name)
	}

	// Remove from cols and colIndex, and renumber colIndex and origIndex.
	t.cols = append(t.cols[:ci], t.cols[ci+1:]...)
	delete(t.colIndex, name)
	for i, col := range t.cols {
		t.colIndex[col.Name] = i
	}
	// Read the dropped column's original position before deleting it, then
	// shift down every column that originally came after it.
	pre := t.origIndex[name]
	delete(t.origIndex, name)
	for n, i := range t.origIndex {
		if i > pre {
			t.origIndex[n]--
		}
	}

	// Drop data.
	for i := range t.rows {
		t.rows[i] = append(t.rows[i][:ci], t.rows[i][ci+1:]...)
	}

	return nil
}
768
// alterColumn applies an ALTER TABLE ... ALTER COLUMN statement to t.
// Only SetColumnType alterations are handled. They may add or remove NOT NULL
// and may convert a non-array column between STRING and BYTES. Existing values
// are converted; NULLs stay NULL. Any other change is rejected with
// InvalidArgument. Validation happens before anything is modified, so a
// rejected change leaves t untouched.
func (t *table) alterColumn(alt spansql.AlterColumn) *status.Status {
	// Supported changes here are:
	//	Add NOT NULL to a non-key column, excluding ARRAY columns.
	//	Remove NOT NULL from a non-key column.
	//	Change a STRING column to a BYTES column or a BYTES column to a STRING column.
	//	Increase or decrease the length limit for a STRING or BYTES type (including to MAX).
	//	Enable or disable commit timestamps in value and primary key columns.
	// https://cloud.google.com/spanner/docs/schema-updates#supported-updates

	// TODO: codes.InvalidArgument is used throughout here for reporting errors,
	// but that has not been validated against the real Spanner.

	sct, ok := alt.Alteration.(spansql.SetColumnType)
	if !ok {
		return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL())
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	ci, ok := t.colIndex[alt.Name]
	if !ok {
		return status.Newf(codes.InvalidArgument, "unknown column %q", alt.Name)
	}

	oldT, newT := t.cols[ci].Type, sct.Type
	stringOrBytes := func(bt spansql.TypeBase) bool { return bt == spansql.String || bt == spansql.Bytes }

	// First phase: Check the validity of the change.
	// TODO: Don't permit changes to allow commit timestamps.
	if !t.cols[ci].NotNull && sct.NotNull {
		// Adding NOT NULL is not permitted for primary key columns or array typed columns.
		if ci < t.pkCols {
			return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on primary key column %q", alt.Name)
		}
		if oldT.Array {
			return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on array-typed column %q", alt.Name)
		}
		// Validate that there are no NULL values.
		for _, row := range t.rows {
			if row[ci] == nil {
				return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on column %q that contains NULL values", alt.Name)
			}
		}
	}
	// conv, if set, rewrites each non-NULL stored value into the new type's
	// Go representation ([]byte for BYTES, string for STRING).
	var conv func(x interface{}) interface{}
	if stringOrBytes(oldT.Base) && stringOrBytes(newT.Base) && !oldT.Array && !newT.Array {
		// Change between STRING and BYTES is fine, as is increasing/decreasing the length limit.
		// TODO: This should permit array conversions too.
		// TODO: Validate data; length limit changes should be rejected if they'd lead to data loss, for instance.
		if oldT.Base == spansql.Bytes && newT.Base == spansql.String {
			conv = func(x interface{}) interface{} { return string(x.([]byte)) }
		} else if oldT.Base == spansql.String && newT.Base == spansql.Bytes {
			conv = func(x interface{}) interface{} { return []byte(x.(string)) }
		}
	} else if oldT == newT {
		// Same type; only NOT NULL changes.
	} else { // TODO: Support other alterations.
		return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL())
	}

	// Second phase: Make type transformations.
	t.cols[ci].NotNull = sct.NotNull
	t.cols[ci].Type = newT
	if conv != nil {
		for _, row := range t.rows {
			if row[ci] != nil { // NULL stays as NULL.
				row[ci] = conv(row[ci])
			}
		}
	}
	return nil
}
842
843func (t *table) addRowDeletionPolicy(ard spansql.AddRowDeletionPolicy) *status.Status {
844	_, ok := t.colIndex[ard.RowDeletionPolicy.Column]
845	if !ok {
846		return status.Newf(codes.InvalidArgument, "unknown column %q", ard.RowDeletionPolicy.Column)
847	}
848	if t.rdw != nil {
849		return status.New(codes.InvalidArgument, "table already has a row deletion policy")
850	}
851	t.rdw = &ard.RowDeletionPolicy
852	return nil
853}
854
855func (t *table) replaceRowDeletionPolicy(ard spansql.ReplaceRowDeletionPolicy) *status.Status {
856	_, ok := t.colIndex[ard.RowDeletionPolicy.Column]
857	if !ok {
858		return status.Newf(codes.InvalidArgument, "unknown column %q", ard.RowDeletionPolicy.Column)
859	}
860	if t.rdw == nil {
861		return status.New(codes.InvalidArgument, "table does not have a row deletion policy")
862	}
863	t.rdw = &ard.RowDeletionPolicy
864	return nil
865}
866
867func (t *table) dropRowDeletionPolicy(ard spansql.DropRowDeletionPolicy) *status.Status {
868	if t.rdw == nil {
869		return status.New(codes.InvalidArgument, "table does not have a row deletion policy")
870	}
871	t.rdw = nil
872	return nil
873}
874
875func (t *table) insertRow(rowNum int, r row) {
876	t.rows = append(t.rows, nil)
877	copy(t.rows[rowNum+1:], t.rows[rowNum:])
878	t.rows[rowNum] = r
879}
880
881// findRange finds the rows included in the key range,
882// reporting it as a half-open interval.
883// r.startKey and r.endKey should be populated.
884func (t *table) findRange(r *keyRange) (int, int) {
885	// startRow is the first row matching the range.
886	startRow := sort.Search(len(t.rows), func(i int) bool {
887		return rowCmp(r.startKey, t.rows[i][:t.pkCols], t.pkDesc) <= 0
888	})
889	if startRow == len(t.rows) {
890		return startRow, startRow
891	}
892	if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols], t.pkDesc) == 0 {
893		startRow++
894	}
895
896	// endRow is one more than the last row matching the range.
897	endRow := sort.Search(len(t.rows), func(i int) bool {
898		return rowCmp(r.endKey, t.rows[i][:t.pkCols], t.pkDesc) < 0
899	})
900	if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols], t.pkDesc) == 0 {
901		endRow--
902	}
903
904	return startRow, endRow
905}
906
907// colIndexes returns the indexes for the named columns.
908func (t *table) colIndexes(cols []spansql.ID) ([]int, error) {
909	var is []int
910	for _, col := range cols {
911		i, ok := t.colIndex[col]
912		if !ok {
913			return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col)
914		}
915		is = append(is, i)
916	}
917	return is, nil
918}
919
920// primaryKey constructs the internal representation of a primary key.
921// The list of given values must be in 1:1 correspondence with the primary key of the table.
922func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) {
923	if len(values) != t.pkCols {
924		return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols)
925	}
926	return t.primaryKeyPrefix(values)
927}
928
929// primaryKeyPrefix constructs the internal representation of a primary key prefix.
930func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) {
931	if len(values) > t.pkCols {
932		return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols)
933	}
934
935	var pk []interface{}
936	for i, value := range values {
937		v, err := valForType(value, t.cols[i].Type)
938		if err != nil {
939			return nil, err
940		}
941		pk = append(pk, v)
942	}
943	return pk, nil
944}
945
946// rowForPK returns the index of t.rows that holds the row for the given primary key, and true.
947// If the given primary key isn't found, it returns the row that should hold it, and false.
948func (t *table) rowForPK(pk []interface{}) (row int, found bool) {
949	if len(pk) != t.pkCols {
950		panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols))
951	}
952
953	i := sort.Search(len(t.rows), func(i int) bool {
954		return rowCmp(pk, t.rows[i][:t.pkCols], t.pkDesc) <= 0
955	})
956	if i == len(t.rows) {
957		return i, false
958	}
959	return i, rowEqual(pk, t.rows[i][:t.pkCols])
960}
961
962// rowCmp compares two rows, returning -1/0/+1.
963// The desc arg indicates whether each column is in a descending order.
964// This is used for primary key matching and so doesn't support array/struct types.
965// a is permitted to be shorter than b.
966func rowCmp(a, b []interface{}, desc []bool) int {
967	for i := 0; i < len(a); i++ {
968		if cmp := compareVals(a[i], b[i]); cmp != 0 {
969			if desc[i] {
970				cmp = -cmp
971			}
972			return cmp
973		}
974	}
975	return 0
976}
977
978// rowEqual reports whether two rows are equal.
979// This doesn't support array/struct types.
980func rowEqual(a, b []interface{}) bool {
981	for i := 0; i < len(a); i++ {
982		if compareVals(a[i], b[i]) != 0 {
983			return false
984		}
985	}
986	return true
987}
988
989// valForType converts a value from its RPC form into its internal representation.
990func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) {
991	if _, ok := v.Kind.(*structpb.Value_NullValue); ok {
992		return nil, nil
993	}
994
995	if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array {
996		et := t // element type
997		et.Array = false
998
999		// Construct the non-nil slice for the list.
1000		arr := make([]interface{}, 0, len(lv.ListValue.Values))
1001		for _, v := range lv.ListValue.Values {
1002			x, err := valForType(v, et)
1003			if err != nil {
1004				return nil, err
1005			}
1006			arr = append(arr, x)
1007		}
1008		return arr, nil
1009	}
1010
1011	switch t.Base {
1012	case spansql.Bool:
1013		bv, ok := v.Kind.(*structpb.Value_BoolValue)
1014		if ok {
1015			return bv.BoolValue, nil
1016		}
1017	case spansql.Int64:
1018		// The Spanner protocol encodes int64 as a decimal string.
1019		sv, ok := v.Kind.(*structpb.Value_StringValue)
1020		if ok {
1021			x, err := strconv.ParseInt(sv.StringValue, 10, 64)
1022			if err != nil {
1023				return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err)
1024			}
1025			return x, nil
1026		}
1027	case spansql.Float64:
1028		nv, ok := v.Kind.(*structpb.Value_NumberValue)
1029		if ok {
1030			return nv.NumberValue, nil
1031		}
1032	case spansql.String:
1033		sv, ok := v.Kind.(*structpb.Value_StringValue)
1034		if ok {
1035			return sv.StringValue, nil
1036		}
1037	case spansql.Bytes:
1038		sv, ok := v.Kind.(*structpb.Value_StringValue)
1039		if ok {
1040			// The Spanner protocol encodes BYTES in base64.
1041			return base64.StdEncoding.DecodeString(sv.StringValue)
1042		}
1043	case spansql.Date:
1044		// The Spanner protocol encodes DATE in RFC 3339 date format.
1045		sv, ok := v.Kind.(*structpb.Value_StringValue)
1046		if ok {
1047			s := sv.StringValue
1048			d, err := parseAsDate(s)
1049			if err != nil {
1050				return nil, fmt.Errorf("bad DATE string %q: %v", s, err)
1051			}
1052			return d, nil
1053		}
1054	case spansql.Timestamp:
1055		// The Spanner protocol encodes TIMESTAMP in RFC 3339 timestamp format with zone Z.
1056		sv, ok := v.Kind.(*structpb.Value_StringValue)
1057		if ok {
1058			s := sv.StringValue
1059			if strings.ToLower(s) == "spanner.commit_timestamp()" {
1060				return commitTimestampSentinel, nil
1061			}
1062			t, err := parseAsTimestamp(s)
1063			if err != nil {
1064				return nil, fmt.Errorf("bad TIMESTAMP string %q: %v", s, err)
1065			}
1066			return t, nil
1067		}
1068	}
1069	return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL())
1070}
1071
1072type keyRange struct {
1073	start, end             *structpb.ListValue
1074	startClosed, endClosed bool
1075
1076	// These are populated during an operation
1077	// when we know what table this keyRange applies to.
1078	startKey, endKey []interface{}
1079}
1080
1081func (r *keyRange) String() string {
1082	var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9.
1083	if r.startClosed {
1084		sb.WriteString("[")
1085	} else {
1086		sb.WriteString("(")
1087	}
1088	fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey)
1089	if r.endClosed {
1090		sb.WriteString("]")
1091	} else {
1092		sb.WriteString(")")
1093	}
1094	return sb.String()
1095}
1096
// keyRangeList is a list of key ranges.
type keyRangeList []*keyRange
1098
1099// Execute runs a DML statement.
1100// It returns the number of affected rows.
1101func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error) { // TODO: return *status.Status instead?
1102	switch stmt := stmt.(type) {
1103	default:
1104		return 0, status.Errorf(codes.Unimplemented, "unhandled DML statement type %T", stmt)
1105	case *spansql.Delete:
1106		t, err := d.table(stmt.Table)
1107		if err != nil {
1108			return 0, err
1109		}
1110
1111		t.mu.Lock()
1112		defer t.mu.Unlock()
1113
1114		n := 0
1115		for i := 0; i < len(t.rows); {
1116			ec := evalContext{
1117				cols:   t.cols,
1118				row:    t.rows[i],
1119				params: params,
1120			}
1121			b, err := ec.evalBoolExpr(stmt.Where)
1122			if err != nil {
1123				return 0, err
1124			}
1125			if b != nil && *b {
1126				copy(t.rows[i:], t.rows[i+1:])
1127				t.rows = t.rows[:len(t.rows)-1]
1128				n++
1129				continue
1130			}
1131			i++
1132		}
1133		return n, nil
1134	case *spansql.Update:
1135		t, err := d.table(stmt.Table)
1136		if err != nil {
1137			return 0, err
1138		}
1139
1140		t.mu.Lock()
1141		defer t.mu.Unlock()
1142
1143		ec := evalContext{
1144			cols:   t.cols,
1145			params: params,
1146		}
1147
1148		// Build parallel slices of destination column index and expressions to evaluate.
1149		var dstIndex []int
1150		var expr []spansql.Expr
1151		for _, ui := range stmt.Items {
1152			i, err := ec.resolveColumnIndex(ui.Column)
1153			if err != nil {
1154				return 0, err
1155			}
1156			// TODO: Enforce "A column can appear only once in the SET clause.".
1157			if i < t.pkCols {
1158				return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column)
1159			}
1160			dstIndex = append(dstIndex, i)
1161			expr = append(expr, ui.Value)
1162		}
1163
1164		n := 0
1165		values := make(row, len(stmt.Items)) // scratch space for new values
1166		for i := 0; i < len(t.rows); i++ {
1167			ec.row = t.rows[i]
1168			b, err := ec.evalBoolExpr(stmt.Where)
1169			if err != nil {
1170				return 0, err
1171			}
1172			if b != nil && *b {
1173				// Compute every update item.
1174				for j := range dstIndex {
1175					if expr[j] == nil { // DEFAULT
1176						values[j] = nil
1177						continue
1178					}
1179					v, err := ec.evalExpr(expr[j])
1180					if err != nil {
1181						return 0, err
1182					}
1183					values[j] = v
1184				}
1185				// Write them to the row.
1186				for j, v := range values {
1187					t.rows[i][dstIndex[j]] = v
1188				}
1189				n++
1190			}
1191		}
1192		return n, nil
1193	}
1194}
1195
1196func parseAsDate(s string) (civil.Date, error) { return civil.ParseDate(s) }
// parseAsTimestamp parses s, a TIMESTAMP in RFC 3339 form with a literal "Z"
// zone and optional fractional seconds up to nanosecond precision.
// The result is in UTC. Any other zone designator is rejected.
func parseAsTimestamp(s string) (time.Time, error) {
	// The trailing Z is matched literally rather than as a zone element, so
	// time.Parse sees no zone and returns UTC.
	const layout = "2006-01-02T15:04:05.999999999Z"
	return time.Parse(layout, s)
}
1200