1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19// This file contains the implementation of the Spanner fake itself,
20// namely the part behind the RPC interface.
21
22// TODO: missing transactionality in a serious way!
23
24import (
25	"bytes"
26	"encoding/base64"
27	"fmt"
28	"sort"
29	"strconv"
30	"strings"
31	"sync"
32	"time"
33
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/status"
36
37	structpb "github.com/golang/protobuf/ptypes/struct"
38
39	"cloud.google.com/go/civil"
40	"cloud.google.com/go/spanner/spansql"
41)
42
// database holds the state of one fake Spanner database: its tables, the
// names of its indexes, and the most recent commit timestamp handed out.
// The tables and indexes maps are created lazily by ApplyDDL.
type database struct {
	// mu is held while lastTS, tables or indexes are read or written.
	mu      sync.Mutex
	lastTS  time.Time // last commit timestamp
	tables  map[spansql.ID]*table
	indexes map[spansql.ID]struct{} // only record their existence

	// rwMu is taken by transaction.Start and released by the unlock
	// function that Start installs on the transaction.
	rwMu sync.Mutex // held by read-write transactions
}
51
// table holds the schema and the rows of a single table.
type table struct {
	// mu guards the column information and the rows below.
	mu sync.Mutex

	// Information about the table columns.
	// They are reordered on table creation so the primary key columns come first.
	cols      []colInfo
	colIndex  map[spansql.ID]int // col name to index
	origIndex map[spansql.ID]int // original index of each column upon construction
	pkCols    int                // number of primary key columns (may be 0)
	pkDesc    []bool             // whether each primary key column is in descending order

	// Rows are stored in primary key order.
	// Each row has one element per entry in cols, in the same order.
	rows []row
}
66
// colInfo represents information about a column in a table or result set.
// Table columns created in this file populate only Name, Type and NotNull.
type colInfo struct {
	Name     spansql.ID
	Type     spansql.Type
	NotNull  bool            // only set for table columns
	AggIndex int             // Index+1 of SELECT list for which this is an aggregate value.
	Alias    spansql.PathExp // an alternate name for this column (result sets only)
}
75
// commitTimestampSentinel is a sentinel value for TIMESTAMP fields with allow_commit_timestamp=true.
// It is accepted, but never stored: valForType returns it for the string
// "spanner.commit_timestamp()" (compared after lower-casing), and writeValues
// replaces it with the transaction's commit timestamp before the row is stored.
var commitTimestampSentinel = &struct{}{}
79
// transaction records information about a running transaction.
// This is not safe for concurrent use.
type transaction struct {
	// readOnly is whether this transaction was constructed
	// for read-only use, and should yield errors if used
	// to perform a mutation.
	readOnly bool

	// d is the database the transaction operates on.
	// NewReadOnlyTransaction leaves it nil.
	d               *database
	commitTimestamp time.Time // not set if readOnly
	// unlock, when non-nil, is called by Commit and Rollback.
	// Start sets it to release d.rwMu.
	unlock          func()    // may be nil
}
92
93func (d *database) NewReadOnlyTransaction() *transaction {
94	return &transaction{
95		readOnly: true,
96	}
97}
98
99func (d *database) NewTransaction() *transaction {
100	return &transaction{
101		d: d,
102	}
103}
104
105// Start starts the transaction and commits to a specific commit timestamp.
106// This also locks out any other read-write transaction on this database
107// until Commit/Rollback are called.
108func (tx *transaction) Start() {
109	// Commit timestamps are only guaranteed to be unique
110	// when transactions write to overlapping sets of fields.
111	// This simulated database exceeds that guarantee.
112
113	// Grab rwMu for the duration of this transaction.
114	// Take it before d.mu so we don't hold that lock
115	// while waiting for d.rwMu, which is held for longer.
116	tx.d.rwMu.Lock()
117
118	tx.d.mu.Lock()
119	const tsRes = 1 * time.Microsecond
120	now := time.Now().UTC().Truncate(tsRes)
121	if !now.After(tx.d.lastTS) {
122		now = tx.d.lastTS.Add(tsRes)
123	}
124	tx.d.lastTS = now
125	tx.d.mu.Unlock()
126
127	tx.commitTimestamp = now
128	tx.unlock = tx.d.rwMu.Unlock
129}
130
131func (tx *transaction) checkMutable() error {
132	if tx.readOnly {
133		// TODO: is this the right status?
134		return status.Errorf(codes.InvalidArgument, "transaction is read-only")
135	}
136	return nil
137}
138
139func (tx *transaction) Commit() (time.Time, error) {
140	if tx.unlock != nil {
141		tx.unlock()
142	}
143	return tx.commitTimestamp, nil
144}
145
146func (tx *transaction) Rollback() {
147	if tx.unlock != nil {
148		tx.unlock()
149	}
150	// TODO: actually rollback
151}
152
/*
row represents a list of data elements.

The Go types used inside this package for each Spanner type are:
	BOOL		bool
	INT64		int64
	FLOAT64		float64
	STRING		string
	BYTES		[]byte
	DATE		civil.Date
	TIMESTAMP	time.Time (location set to UTC)
	ARRAY<T>	[]interface{}
	STRUCT		TODO
*/
type row []interface{}

// copyDataElem returns the element at index.
// An array value ([]interface{}) is returned as a freshly allocated slice
// holding the same elements, so the caller may reassign entries of the
// result without touching r. Other values are returned as they are.
func (r row) copyDataElem(index int) interface{} {
	switch elem := r[index].(type) {
	case []interface{}:
		return append([]interface{}(nil), elem...)
	default:
		return elem
	}
}

// copyAllData returns a copy of the whole row.
func (r row) copyAllData() row {
	out := make(row, len(r))
	for i := range r {
		out[i] = r.copyDataElem(i)
	}
	return out
}

// copyData returns a copy of the elements of r selected by indexes,
// in the order given. It returns nil when indexes is empty.
func (r row) copyData(indexes []int) row {
	n := len(indexes)
	if n == 0 {
		return nil
	}
	out := make(row, n)
	for pos, idx := range indexes {
		out[pos] = r.copyDataElem(idx)
	}
	return out
}
198
199func (d *database) LastCommitTimestamp() time.Time {
200	d.mu.Lock()
201	defer d.mu.Unlock()
202	return d.lastTS
203}
204
205func (d *database) GetDDL() []spansql.DDLStmt {
206	// This lacks fidelity, but captures the details we support.
207	d.mu.Lock()
208	defer d.mu.Unlock()
209
210	var stmts []spansql.DDLStmt
211
212	for name, t := range d.tables {
213		ct := &spansql.CreateTable{
214			Name: name,
215		}
216
217		t.mu.Lock()
218		for i, col := range t.cols {
219			ct.Columns = append(ct.Columns, spansql.ColumnDef{
220				Name:    col.Name,
221				Type:    col.Type,
222				NotNull: col.NotNull,
223				// TODO: AllowCommitTimestamp
224			})
225			if i < t.pkCols {
226				ct.PrimaryKey = append(ct.PrimaryKey, spansql.KeyPart{
227					Column: col.Name,
228					Desc:   t.pkDesc[i],
229				})
230			}
231		}
232		t.mu.Unlock()
233
234		stmts = append(stmts, ct)
235	}
236
237	return stmts
238}
239
240func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status {
241	d.mu.Lock()
242	defer d.mu.Unlock()
243
244	// Lazy init.
245	if d.tables == nil {
246		d.tables = make(map[spansql.ID]*table)
247	}
248	if d.indexes == nil {
249		d.indexes = make(map[spansql.ID]struct{})
250	}
251
252	switch stmt := stmt.(type) {
253	default:
254		return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt)
255	case *spansql.CreateTable:
256		if _, ok := d.tables[stmt.Name]; ok {
257			return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name)
258		}
259		if len(stmt.PrimaryKey) == 0 {
260			return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name)
261		}
262
263		// TODO: check stmt.Interleave details.
264
265		// Record original column ordering.
266		orig := make(map[spansql.ID]int)
267		for i, col := range stmt.Columns {
268			orig[col.Name] = i
269		}
270
271		// Move primary keys first, preserving their order.
272		pk := make(map[spansql.ID]int)
273		var pkDesc []bool
274		for i, kp := range stmt.PrimaryKey {
275			pk[kp.Column] = -1000 + i
276			pkDesc = append(pkDesc, kp.Desc)
277		}
278		sort.SliceStable(stmt.Columns, func(i, j int) bool {
279			a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name]
280			return a < b
281		})
282
283		t := &table{
284			colIndex:  make(map[spansql.ID]int),
285			origIndex: orig,
286			pkCols:    len(pk),
287			pkDesc:    pkDesc,
288		}
289		for _, cd := range stmt.Columns {
290			if st := t.addColumn(cd, true); st.Code() != codes.OK {
291				return st
292			}
293		}
294		for col := range pk {
295			if _, ok := t.colIndex[col]; !ok {
296				return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col)
297			}
298		}
299		d.tables[stmt.Name] = t
300		return nil
301	case *spansql.CreateIndex:
302		if _, ok := d.indexes[stmt.Name]; ok {
303			return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name)
304		}
305		d.indexes[stmt.Name] = struct{}{}
306		return nil
307	case *spansql.DropTable:
308		if _, ok := d.tables[stmt.Name]; !ok {
309			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
310		}
311		// TODO: check for indexes on this table.
312		delete(d.tables, stmt.Name)
313		return nil
314	case *spansql.DropIndex:
315		if _, ok := d.indexes[stmt.Name]; !ok {
316			return status.Newf(codes.NotFound, "no index named %s", stmt.Name)
317		}
318		delete(d.indexes, stmt.Name)
319		return nil
320	case *spansql.AlterTable:
321		t, ok := d.tables[stmt.Name]
322		if !ok {
323			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
324		}
325		switch alt := stmt.Alteration.(type) {
326		default:
327			return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt)
328		case spansql.AddColumn:
329			if st := t.addColumn(alt.Def, false); st.Code() != codes.OK {
330				return st
331			}
332			return nil
333		case spansql.DropColumn:
334			if st := t.dropColumn(alt.Name); st.Code() != codes.OK {
335				return st
336			}
337			return nil
338		case spansql.AlterColumn:
339			if st := t.alterColumn(alt); st.Code() != codes.OK {
340				return st
341			}
342			return nil
343		}
344	}
345
346}
347
348func (d *database) table(tbl spansql.ID) (*table, error) {
349	d.mu.Lock()
350	defer d.mu.Unlock()
351
352	t, ok := d.tables[tbl]
353	if !ok {
354		return nil, status.Errorf(codes.NotFound, "no table named %s", tbl)
355	}
356	return t, nil
357}
358
359// writeValues executes a write option (Insert, Update, etc.).
360func (d *database) writeValues(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error {
361	if err := tx.checkMutable(); err != nil {
362		return err
363	}
364
365	t, err := d.table(tbl)
366	if err != nil {
367		return err
368	}
369
370	t.mu.Lock()
371	defer t.mu.Unlock()
372
373	colIndexes, err := t.colIndexes(cols)
374	if err != nil {
375		return err
376	}
377	revIndex := make(map[int]int) // table index to col index
378	for j, i := range colIndexes {
379		revIndex[i] = j
380	}
381
382	for pki := 0; pki < t.pkCols; pki++ {
383		_, ok := revIndex[pki]
384		if !ok {
385			return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name)
386		}
387	}
388
389	for _, vs := range values {
390		if len(vs.Values) != len(colIndexes) {
391			return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes))
392		}
393
394		r := make(row, len(t.cols))
395		for j, v := range vs.Values {
396			i := colIndexes[j]
397
398			x, err := valForType(v, t.cols[i].Type)
399			if err != nil {
400				return err
401			}
402			if x == commitTimestampSentinel {
403				x = tx.commitTimestamp
404			}
405			if x == nil && t.cols[i].NotNull {
406				return status.Errorf(codes.FailedPrecondition, "%s must not be NULL in table %s", t.cols[i].Name, tbl)
407			}
408
409			r[i] = x
410		}
411		// TODO: enforce that provided timestamp for commit_timestamp=true columns
412		// are not ahead of the transaction's commit timestamp.
413
414		if err := f(t, colIndexes, r); err != nil {
415			return err
416		}
417	}
418
419	return nil
420}
421
422func (d *database) Insert(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
423	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
424		pk := r[:t.pkCols]
425		rowNum, found := t.rowForPK(pk)
426		if found {
427			return status.Errorf(codes.AlreadyExists, "row already in table")
428		}
429		t.insertRow(rowNum, r)
430		return nil
431	})
432}
433
434func (d *database) Update(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
435	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
436		if t.pkCols == 0 {
437			return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl)
438		}
439		pk := r[:t.pkCols]
440		rowNum, found := t.rowForPK(pk)
441		if !found {
442			// TODO: is this the right way to return `NOT_FOUND`?
443			return status.Errorf(codes.NotFound, "row not in table")
444		}
445
446		for _, i := range colIndexes {
447			t.rows[rowNum][i] = r[i]
448		}
449		return nil
450	})
451}
452
453func (d *database) InsertOrUpdate(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
454	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
455		pk := r[:t.pkCols]
456		rowNum, found := t.rowForPK(pk)
457		if !found {
458			// New row; do an insert.
459			t.insertRow(rowNum, r)
460		} else {
461			// Existing row; do an update.
462			for _, i := range colIndexes {
463				t.rows[rowNum][i] = r[i]
464			}
465		}
466		return nil
467	})
468}
469
470// TODO: Replace
471
472func (d *database) Delete(tx *transaction, table spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error {
473	if err := tx.checkMutable(); err != nil {
474		return err
475	}
476
477	t, err := d.table(table)
478	if err != nil {
479		return err
480	}
481
482	t.mu.Lock()
483	defer t.mu.Unlock()
484
485	if all {
486		t.rows = nil
487		return nil
488	}
489
490	for _, key := range keys {
491		pk, err := t.primaryKey(key.Values)
492		if err != nil {
493			return err
494		}
495		// Not an error if the key does not exist.
496		rowNum, found := t.rowForPK(pk)
497		if found {
498			copy(t.rows[rowNum:], t.rows[rowNum+1:])
499			t.rows = t.rows[:len(t.rows)-1]
500		}
501	}
502
503	for _, r := range keyRanges {
504		r.startKey, err = t.primaryKeyPrefix(r.start.Values)
505		if err != nil {
506			return err
507		}
508		r.endKey, err = t.primaryKeyPrefix(r.end.Values)
509		if err != nil {
510			return err
511		}
512		startRow, endRow := t.findRange(r)
513		if n := endRow - startRow; n > 0 {
514			copy(t.rows[startRow:], t.rows[endRow:])
515			t.rows = t.rows[:len(t.rows)-n]
516		}
517	}
518
519	return nil
520}
521
522// readTable executes a read option (Read, ReadAll).
523func (d *database) readTable(table spansql.ID, cols []spansql.ID, f func(*table, *rawIter, []int) error) (*rawIter, error) {
524	t, err := d.table(table)
525	if err != nil {
526		return nil, err
527	}
528
529	t.mu.Lock()
530	defer t.mu.Unlock()
531
532	colIndexes, err := t.colIndexes(cols)
533	if err != nil {
534		return nil, err
535	}
536
537	ri := &rawIter{}
538	for _, i := range colIndexes {
539		ri.cols = append(ri.cols, t.cols[i])
540	}
541	return ri, f(t, ri, colIndexes)
542}
543
544func (d *database) Read(tbl spansql.ID, cols []spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) {
545	// The real Cloud Spanner returns an error if the key set is empty by definition.
546	// That doesn't seem to be well-defined, but it is a common error to attempt a read with no keys,
547	// so catch that here and return a representative error.
548	if len(keys) == 0 && len(keyRanges) == 0 {
549		return nil, status.Error(codes.Unimplemented, "Cloud Spanner does not support reading no keys")
550	}
551
552	return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error {
553		// "If the same key is specified multiple times in the set (for
554		// example if two ranges, two keys, or a key and a range
555		// overlap), Cloud Spanner behaves as if the key were only
556		// specified once."
557		done := make(map[int]bool) // row numbers we've included in ri.
558
559		// Specific keys.
560		for _, key := range keys {
561			pk, err := t.primaryKey(key.Values)
562			if err != nil {
563				return err
564			}
565			// Not an error if the key does not exist.
566			rowNum, found := t.rowForPK(pk)
567			if !found {
568				continue
569			}
570			if done[rowNum] {
571				continue
572			}
573			done[rowNum] = true
574			ri.add(t.rows[rowNum], colIndexes)
575			if limit > 0 && len(ri.rows) >= int(limit) {
576				return nil
577			}
578		}
579
580		// Key ranges.
581		for _, r := range keyRanges {
582			var err error
583			r.startKey, err = t.primaryKeyPrefix(r.start.Values)
584			if err != nil {
585				return err
586			}
587			r.endKey, err = t.primaryKeyPrefix(r.end.Values)
588			if err != nil {
589				return err
590			}
591			startRow, endRow := t.findRange(r)
592			for rowNum := startRow; rowNum < endRow; rowNum++ {
593				if done[rowNum] {
594					continue
595				}
596				done[rowNum] = true
597				ri.add(t.rows[rowNum], colIndexes)
598				if limit > 0 && len(ri.rows) >= int(limit) {
599					return nil
600				}
601			}
602		}
603
604		return nil
605	})
606}
607
608func (d *database) ReadAll(tbl spansql.ID, cols []spansql.ID, limit int64) (*rawIter, error) {
609	return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error {
610		for _, r := range t.rows {
611			ri.add(r, colIndexes)
612			if limit > 0 && len(ri.rows) >= int(limit) {
613				break
614			}
615		}
616		return nil
617	})
618}
619
620func (t *table) addColumn(cd spansql.ColumnDef, newTable bool) *status.Status {
621	if !newTable && cd.NotNull {
622		return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL")
623	}
624
625	if _, ok := t.colIndex[cd.Name]; ok {
626		return status.Newf(codes.AlreadyExists, "column %s already exists", cd.Name)
627	}
628
629	t.mu.Lock()
630	defer t.mu.Unlock()
631
632	if len(t.rows) > 0 {
633		if cd.NotNull {
634			// TODO: what happens in this case?
635			return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet")
636		}
637		for i := range t.rows {
638			t.rows[i] = append(t.rows[i], nil)
639		}
640	}
641
642	t.cols = append(t.cols, colInfo{
643		Name:    cd.Name,
644		Type:    cd.Type,
645		NotNull: cd.NotNull,
646	})
647	t.colIndex[cd.Name] = len(t.cols) - 1
648	if !newTable {
649		t.origIndex[cd.Name] = len(t.cols) - 1
650	}
651
652	return nil
653}
654
655func (t *table) dropColumn(name spansql.ID) *status.Status {
656	// Only permit dropping non-key columns that aren't part of a secondary index.
657	// We don't support indexes, so only check that it isn't part of the primary key.
658
659	t.mu.Lock()
660	defer t.mu.Unlock()
661
662	ci, ok := t.colIndex[name]
663	if !ok {
664		// TODO: What's the right response code?
665		return status.Newf(codes.InvalidArgument, "unknown column %q", name)
666	}
667	if ci < t.pkCols {
668		// TODO: What's the right response code?
669		return status.Newf(codes.InvalidArgument, "can't drop primary key column %q", name)
670	}
671
672	// Remove from cols and colIndex, and renumber colIndex and origIndex.
673	t.cols = append(t.cols[:ci], t.cols[ci+1:]...)
674	delete(t.colIndex, name)
675	for i, col := range t.cols {
676		t.colIndex[col.Name] = i
677	}
678	pre := t.origIndex[name]
679	delete(t.origIndex, name)
680	for n, i := range t.origIndex {
681		if i > pre {
682			t.origIndex[n]--
683		}
684	}
685
686	// Drop data.
687	for i := range t.rows {
688		t.rows[i] = append(t.rows[i][:ci], t.rows[i][ci+1:]...)
689	}
690
691	return nil
692}
693
694func (t *table) alterColumn(alt spansql.AlterColumn) *status.Status {
695	// Supported changes here are:
696	//	Add NOT NULL to a non-key column, excluding ARRAY columns.
697	//	Remove NOT NULL from a non-key column.
698	//	Change a STRING column to a BYTES column or a BYTES column to a STRING column.
699	//	Increase or decrease the length limit for a STRING or BYTES type (including to MAX).
700	//	Enable or disable commit timestamps in value and primary key columns.
701	// https://cloud.google.com/spanner/docs/schema-updates#supported-updates
702
703	// TODO: codes.InvalidArgument is used throughout here for reporting errors,
704	// but that has not been validated against the real Spanner.
705
706	sct, ok := alt.Alteration.(spansql.SetColumnType)
707	if !ok {
708		return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL())
709	}
710
711	t.mu.Lock()
712	defer t.mu.Unlock()
713
714	ci, ok := t.colIndex[alt.Name]
715	if !ok {
716		return status.Newf(codes.InvalidArgument, "unknown column %q", alt.Name)
717	}
718
719	oldT, newT := t.cols[ci].Type, sct.Type
720	stringOrBytes := func(bt spansql.TypeBase) bool { return bt == spansql.String || bt == spansql.Bytes }
721
722	// First phase: Check the validity of the change.
723	// TODO: Don't permit changes to allow commit timestamps.
724	if !t.cols[ci].NotNull && sct.NotNull {
725		// Adding NOT NULL is not permitted for primary key columns or array typed columns.
726		if ci < t.pkCols {
727			return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on primary key column %q", alt.Name)
728		}
729		if oldT.Array {
730			return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on array-typed column %q", alt.Name)
731		}
732		// Validate that there are no NULL values.
733		for _, row := range t.rows {
734			if row[ci] == nil {
735				return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on column %q that contains NULL values", alt.Name)
736			}
737		}
738	}
739	var conv func(x interface{}) interface{}
740	if stringOrBytes(oldT.Base) && stringOrBytes(newT.Base) && !oldT.Array && !newT.Array {
741		// Change between STRING and BYTES is fine, as is increasing/decreasing the length limit.
742		// TODO: This should permit array conversions too.
743		// TODO: Validate data; length limit changes should be rejected if they'd lead to data loss, for instance.
744		if oldT.Base == spansql.Bytes && newT.Base == spansql.String {
745			conv = func(x interface{}) interface{} { return string(x.([]byte)) }
746		} else if oldT.Base == spansql.String && newT.Base == spansql.Bytes {
747			conv = func(x interface{}) interface{} { return []byte(x.(string)) }
748		}
749	} else if oldT == newT {
750		// Same type; only NOT NULL changes.
751	} else { // TODO: Support other alterations.
752		return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL())
753	}
754
755	// Second phase: Make type transformations.
756	t.cols[ci].NotNull = sct.NotNull
757	t.cols[ci].Type = newT
758	if conv != nil {
759		for _, row := range t.rows {
760			if row[ci] != nil { // NULL stays as NULL.
761				row[ci] = conv(row[ci])
762			}
763		}
764	}
765	return nil
766}
767
768func (t *table) insertRow(rowNum int, r row) {
769	t.rows = append(t.rows, nil)
770	copy(t.rows[rowNum+1:], t.rows[rowNum:])
771	t.rows[rowNum] = r
772}
773
774// findRange finds the rows included in the key range,
775// reporting it as a half-open interval.
776// r.startKey and r.endKey should be populated.
777func (t *table) findRange(r *keyRange) (int, int) {
778	// startRow is the first row matching the range.
779	startRow := sort.Search(len(t.rows), func(i int) bool {
780		return rowCmp(r.startKey, t.rows[i][:t.pkCols], t.pkDesc) <= 0
781	})
782	if startRow == len(t.rows) {
783		return startRow, startRow
784	}
785	if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols], t.pkDesc) == 0 {
786		startRow++
787	}
788
789	// endRow is one more than the last row matching the range.
790	endRow := sort.Search(len(t.rows), func(i int) bool {
791		return rowCmp(r.endKey, t.rows[i][:t.pkCols], t.pkDesc) < 0
792	})
793	if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols], t.pkDesc) == 0 {
794		endRow--
795	}
796
797	return startRow, endRow
798}
799
800// colIndexes returns the indexes for the named columns.
801func (t *table) colIndexes(cols []spansql.ID) ([]int, error) {
802	var is []int
803	for _, col := range cols {
804		i, ok := t.colIndex[col]
805		if !ok {
806			return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col)
807		}
808		is = append(is, i)
809	}
810	return is, nil
811}
812
813// primaryKey constructs the internal representation of a primary key.
814// The list of given values must be in 1:1 correspondence with the primary key of the table.
815func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) {
816	if len(values) != t.pkCols {
817		return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols)
818	}
819	return t.primaryKeyPrefix(values)
820}
821
822// primaryKeyPrefix constructs the internal representation of a primary key prefix.
823func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) {
824	if len(values) > t.pkCols {
825		return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols)
826	}
827
828	var pk []interface{}
829	for i, value := range values {
830		v, err := valForType(value, t.cols[i].Type)
831		if err != nil {
832			return nil, err
833		}
834		pk = append(pk, v)
835	}
836	return pk, nil
837}
838
839// rowForPK returns the index of t.rows that holds the row for the given primary key, and true.
840// If the given primary key isn't found, it returns the row that should hold it, and false.
841func (t *table) rowForPK(pk []interface{}) (row int, found bool) {
842	if len(pk) != t.pkCols {
843		panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols))
844	}
845
846	i := sort.Search(len(t.rows), func(i int) bool {
847		return rowCmp(pk, t.rows[i][:t.pkCols], t.pkDesc) <= 0
848	})
849	if i == len(t.rows) {
850		return i, false
851	}
852	return i, rowEqual(pk, t.rows[i][:t.pkCols])
853}
854
855// rowCmp compares two rows, returning -1/0/+1.
856// The desc arg indicates whether each column is in a descending order.
857// This is used for primary key matching and so doesn't support array/struct types.
858// a is permitted to be shorter than b.
859func rowCmp(a, b []interface{}, desc []bool) int {
860	for i := 0; i < len(a); i++ {
861		if cmp := compareVals(a[i], b[i]); cmp != 0 {
862			if desc[i] {
863				cmp = -cmp
864			}
865			return cmp
866		}
867	}
868	return 0
869}
870
871// rowEqual reports whether two rows are equal.
872// This doesn't support array/struct types.
873func rowEqual(a, b []interface{}) bool {
874	for i := 0; i < len(a); i++ {
875		if compareVals(a[i], b[i]) != 0 {
876			return false
877		}
878	}
879	return true
880}
881
882// valForType converts a value from its RPC form into its internal representation.
883func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) {
884	if _, ok := v.Kind.(*structpb.Value_NullValue); ok {
885		return nil, nil
886	}
887
888	if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array {
889		et := t // element type
890		et.Array = false
891
892		// Construct the non-nil slice for the list.
893		arr := make([]interface{}, 0, len(lv.ListValue.Values))
894		for _, v := range lv.ListValue.Values {
895			x, err := valForType(v, et)
896			if err != nil {
897				return nil, err
898			}
899			arr = append(arr, x)
900		}
901		return arr, nil
902	}
903
904	switch t.Base {
905	case spansql.Bool:
906		bv, ok := v.Kind.(*structpb.Value_BoolValue)
907		if ok {
908			return bv.BoolValue, nil
909		}
910	case spansql.Int64:
911		// The Spanner protocol encodes int64 as a decimal string.
912		sv, ok := v.Kind.(*structpb.Value_StringValue)
913		if ok {
914			x, err := strconv.ParseInt(sv.StringValue, 10, 64)
915			if err != nil {
916				return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err)
917			}
918			return x, nil
919		}
920	case spansql.Float64:
921		nv, ok := v.Kind.(*structpb.Value_NumberValue)
922		if ok {
923			return nv.NumberValue, nil
924		}
925	case spansql.String:
926		sv, ok := v.Kind.(*structpb.Value_StringValue)
927		if ok {
928			return sv.StringValue, nil
929		}
930	case spansql.Bytes:
931		sv, ok := v.Kind.(*structpb.Value_StringValue)
932		if ok {
933			// The Spanner protocol encodes BYTES in base64.
934			return base64.StdEncoding.DecodeString(sv.StringValue)
935		}
936	case spansql.Date:
937		// The Spanner protocol encodes DATE in RFC 3339 date format.
938		sv, ok := v.Kind.(*structpb.Value_StringValue)
939		if ok {
940			s := sv.StringValue
941			d, err := parseAsDate(s)
942			if err != nil {
943				return nil, fmt.Errorf("bad DATE string %q: %v", s, err)
944			}
945			return d, nil
946		}
947	case spansql.Timestamp:
948		// The Spanner protocol encodes TIMESTAMP in RFC 3339 timestamp format with zone Z.
949		sv, ok := v.Kind.(*structpb.Value_StringValue)
950		if ok {
951			s := sv.StringValue
952			if strings.ToLower(s) == "spanner.commit_timestamp()" {
953				return commitTimestampSentinel, nil
954			}
955			t, err := parseAsTimestamp(s)
956			if err != nil {
957				return nil, fmt.Errorf("bad TIMESTAMP string %q: %v", s, err)
958			}
959			return t, nil
960		}
961	}
962	return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL())
963}
964
// keyRange describes a range of primary keys, as given in a request,
// with each end being either closed (inclusive) or open (exclusive).
type keyRange struct {
	// start and end hold the bounds in RPC form.
	start, end             *structpb.ListValue
	startClosed, endClosed bool

	// These are populated during an operation
	// when we know what table this keyRange applies to.
	// They hold start and end converted by table.primaryKeyPrefix.
	startKey, endKey []interface{}
}
973
974func (r *keyRange) String() string {
975	var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9.
976	if r.startClosed {
977		sb.WriteString("[")
978	} else {
979		sb.WriteString("(")
980	}
981	fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey)
982	if r.endClosed {
983		sb.WriteString("]")
984	} else {
985		sb.WriteString(")")
986	}
987	return sb.String()
988}
989
// keyRangeList is a list of key ranges, as accepted by Read and Delete.
type keyRangeList []*keyRange
991
992// Execute runs a DML statement.
993// It returns the number of affected rows.
994func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error) { // TODO: return *status.Status instead?
995	switch stmt := stmt.(type) {
996	default:
997		return 0, status.Errorf(codes.Unimplemented, "unhandled DML statement type %T", stmt)
998	case *spansql.Delete:
999		t, err := d.table(stmt.Table)
1000		if err != nil {
1001			return 0, err
1002		}
1003
1004		t.mu.Lock()
1005		defer t.mu.Unlock()
1006
1007		n := 0
1008		for i := 0; i < len(t.rows); {
1009			ec := evalContext{
1010				cols:   t.cols,
1011				row:    t.rows[i],
1012				params: params,
1013			}
1014			b, err := ec.evalBoolExpr(stmt.Where)
1015			if err != nil {
1016				return 0, err
1017			}
1018			if b != nil && *b {
1019				copy(t.rows[i:], t.rows[i+1:])
1020				t.rows = t.rows[:len(t.rows)-1]
1021				n++
1022				continue
1023			}
1024			i++
1025		}
1026		return n, nil
1027	case *spansql.Update:
1028		t, err := d.table(stmt.Table)
1029		if err != nil {
1030			return 0, err
1031		}
1032
1033		t.mu.Lock()
1034		defer t.mu.Unlock()
1035
1036		ec := evalContext{
1037			cols:   t.cols,
1038			params: params,
1039		}
1040
1041		// Build parallel slices of destination column index and expressions to evaluate.
1042		var dstIndex []int
1043		var expr []spansql.Expr
1044		for _, ui := range stmt.Items {
1045			i, err := ec.resolveColumnIndex(ui.Column)
1046			if err != nil {
1047				return 0, err
1048			}
1049			// TODO: Enforce "A column can appear only once in the SET clause.".
1050			if i < t.pkCols {
1051				return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column)
1052			}
1053			dstIndex = append(dstIndex, i)
1054			expr = append(expr, ui.Value)
1055		}
1056
1057		n := 0
1058		values := make(row, len(stmt.Items)) // scratch space for new values
1059		for i := 0; i < len(t.rows); i++ {
1060			ec.row = t.rows[i]
1061			b, err := ec.evalBoolExpr(stmt.Where)
1062			if err != nil {
1063				return 0, err
1064			}
1065			if b != nil && *b {
1066				// Compute every update item.
1067				for j := range dstIndex {
1068					if expr[j] == nil { // DEFAULT
1069						values[j] = nil
1070						continue
1071					}
1072					v, err := ec.evalExpr(expr[j])
1073					if err != nil {
1074						return 0, err
1075					}
1076					values[j] = v
1077				}
1078				// Write them to the row.
1079				for j, v := range values {
1080					t.rows[i][dstIndex[j]] = v
1081				}
1082				n++
1083			}
1084		}
1085		return n, nil
1086	}
1087}
1088
1089func parseAsDate(s string) (civil.Date, error) { return civil.ParseDate(s) }
// parseAsTimestamp parses s as a timestamp of the form
// 2006-01-02T15:04:05Z, with up to nine optional fractional-second digits.
// The trailing "Z" is required literally; other zone designators are rejected.
func parseAsTimestamp(s string) (time.Time, error) {
	const layout = "2006-01-02T15:04:05.999999999Z"
	return time.Parse(layout, s)
}
1093