1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19// This file contains the implementation of the Spanner fake itself,
20// namely the part behind the RPC interface.
21
22// TODO: missing transactionality in a serious way!
23
24import (
25	"bytes"
26	"encoding/base64"
27	"fmt"
28	"sort"
29	"strconv"
30	"strings"
31	"sync"
32	"time"
33
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/status"
36
37	structpb "github.com/golang/protobuf/ptypes/struct"
38
39	"cloud.google.com/go/civil"
40	"cloud.google.com/go/spanner/spansql"
41)
42
// database is the in-memory state of one fake Spanner database: its tables,
// the names of its indexes, and the most recent commit timestamp handed out.
type database struct {
	// mu guards lastTS, tables and indexes. Row data lives in each table
	// and is guarded by that table's own mutex.
	mu      sync.Mutex
	lastTS  time.Time // last commit timestamp
	tables  map[spansql.ID]*table
	indexes map[spansql.ID]struct{} // only record their existence

	// rwMu is acquired by transaction.Start and released via the
	// transaction's unlock func (Commit/Rollback), so at most one
	// read-write transaction runs at a time.
	rwMu sync.Mutex // held by read-write transactions
}
51
// table holds the schema and the row data of a single table.
type table struct {
	// mu is held by the methods in this file while they read or modify
	// the table's schema and rows.
	mu sync.Mutex

	// Information about the table columns.
	// They are reordered on table creation so the primary key columns come first.
	cols      []colInfo
	colIndex  map[spansql.ID]int // col name to index
	origIndex map[spansql.ID]int // original index of each column upon construction
	pkCols    int                // number of primary key columns (may be 0)
	pkDesc    []bool             // whether each primary key column is in descending order

	// Rows are stored in primary key order.
	// Each row has one element per entry in cols, in the same order.
	rows []row
}
66
// colInfo represents information about a column in a table or result set.
type colInfo struct {
	Name spansql.ID
	Type spansql.Type
	// Generated is the expression that computes a generated column's value,
	// or nil for an ordinary column.
	Generated spansql.Expr
	NotNull   bool            // only set for table columns
	AggIndex  int             // Index+1 of SELECT list for which this is an aggregate value.
	Alias     spansql.PathExp // an alternate name for this column (result sets only)
}
76
// commitTimestampSentinel stands in for a TIMESTAMP value written as
// "spanner.commit_timestamp()" (allow_commit_timestamp=true columns).
// It is never stored: writes substitute the transaction's commit timestamp.
var commitTimestampSentinel = new(struct{})
80
// transaction records information about a running transaction.
// This is not safe for concurrent use.
type transaction struct {
	// readOnly is whether this transaction was constructed
	// for read-only use, and should yield errors if used
	// to perform a mutation.
	readOnly bool

	// d is the database the transaction runs against; it is nil for
	// transactions made by NewReadOnlyTransaction.
	d               *database
	commitTimestamp time.Time // not set if readOnly
	// unlock is set by Start to release d.rwMu; Commit and Rollback call it.
	unlock func() // may be nil
}
93
94func (d *database) NewReadOnlyTransaction() *transaction {
95	return &transaction{
96		readOnly: true,
97	}
98}
99
100func (d *database) NewTransaction() *transaction {
101	return &transaction{
102		d: d,
103	}
104}
105
106// Start starts the transaction and commits to a specific commit timestamp.
107// This also locks out any other read-write transaction on this database
108// until Commit/Rollback are called.
109func (tx *transaction) Start() {
110	// Commit timestamps are only guaranteed to be unique
111	// when transactions write to overlapping sets of fields.
112	// This simulated database exceeds that guarantee.
113
114	// Grab rwMu for the duration of this transaction.
115	// Take it before d.mu so we don't hold that lock
116	// while waiting for d.rwMu, which is held for longer.
117	tx.d.rwMu.Lock()
118
119	tx.d.mu.Lock()
120	const tsRes = 1 * time.Microsecond
121	now := time.Now().UTC().Truncate(tsRes)
122	if !now.After(tx.d.lastTS) {
123		now = tx.d.lastTS.Add(tsRes)
124	}
125	tx.d.lastTS = now
126	tx.d.mu.Unlock()
127
128	tx.commitTimestamp = now
129	tx.unlock = tx.d.rwMu.Unlock
130}
131
132func (tx *transaction) checkMutable() error {
133	if tx.readOnly {
134		// TODO: is this the right status?
135		return status.Errorf(codes.InvalidArgument, "transaction is read-only")
136	}
137	return nil
138}
139
140func (tx *transaction) Commit() (time.Time, error) {
141	if tx.unlock != nil {
142		tx.unlock()
143	}
144	return tx.commitTimestamp, nil
145}
146
147func (tx *transaction) Rollback() {
148	if tx.unlock != nil {
149		tx.unlock()
150	}
151	// TODO: actually rollback
152}
153
154/*
155row represents a list of data elements.
156
157The mapping between Spanner types and Go types internal to this package are:
158	BOOL		bool
159	INT64		int64
160	FLOAT64		float64
161	STRING		string
162	BYTES		[]byte
163	DATE		civil.Date
164	TIMESTAMP	time.Time (location set to UTC)
165	ARRAY<T>	[]interface{}
166	STRUCT		TODO
167*/
168type row []interface{}
169
170func (r row) copyDataElem(index int) interface{} {
171	v := r[index]
172	if is, ok := v.([]interface{}); ok {
173		// Deep-copy array values.
174		v = append([]interface{}(nil), is...)
175	}
176	return v
177}
178
179// copyData returns a copy of the row.
180func (r row) copyAllData() row {
181	dst := make(row, 0, len(r))
182	for i := range r {
183		dst = append(dst, r.copyDataElem(i))
184	}
185	return dst
186}
187
188// copyData returns a copy of a subset of a row.
189func (r row) copyData(indexes []int) row {
190	if len(indexes) == 0 {
191		return nil
192	}
193	dst := make(row, 0, len(indexes))
194	for _, i := range indexes {
195		dst = append(dst, r.copyDataElem(i))
196	}
197	return dst
198}
199
200func (d *database) LastCommitTimestamp() time.Time {
201	d.mu.Lock()
202	defer d.mu.Unlock()
203	return d.lastTS
204}
205
206func (d *database) GetDDL() []spansql.DDLStmt {
207	// This lacks fidelity, but captures the details we support.
208	d.mu.Lock()
209	defer d.mu.Unlock()
210
211	var stmts []spansql.DDLStmt
212
213	for name, t := range d.tables {
214		ct := &spansql.CreateTable{
215			Name: name,
216		}
217
218		t.mu.Lock()
219		for i, col := range t.cols {
220			ct.Columns = append(ct.Columns, spansql.ColumnDef{
221				Name:    col.Name,
222				Type:    col.Type,
223				NotNull: col.NotNull,
224				// TODO: AllowCommitTimestamp
225			})
226			if i < t.pkCols {
227				ct.PrimaryKey = append(ct.PrimaryKey, spansql.KeyPart{
228					Column: col.Name,
229					Desc:   t.pkDesc[i],
230				})
231			}
232		}
233		t.mu.Unlock()
234
235		stmts = append(stmts, ct)
236	}
237
238	return stmts
239}
240
241func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status {
242	d.mu.Lock()
243	defer d.mu.Unlock()
244
245	// Lazy init.
246	if d.tables == nil {
247		d.tables = make(map[spansql.ID]*table)
248	}
249	if d.indexes == nil {
250		d.indexes = make(map[spansql.ID]struct{})
251	}
252
253	switch stmt := stmt.(type) {
254	default:
255		return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt)
256	case *spansql.CreateTable:
257		if _, ok := d.tables[stmt.Name]; ok {
258			return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name)
259		}
260		if len(stmt.PrimaryKey) == 0 {
261			return status.Newf(codes.InvalidArgument, "table %s has no primary key", stmt.Name)
262		}
263
264		// TODO: check stmt.Interleave details.
265
266		// Record original column ordering.
267		orig := make(map[spansql.ID]int)
268		for i, col := range stmt.Columns {
269			orig[col.Name] = i
270		}
271
272		// Move primary keys first, preserving their order.
273		pk := make(map[spansql.ID]int)
274		var pkDesc []bool
275		for i, kp := range stmt.PrimaryKey {
276			pk[kp.Column] = -1000 + i
277			pkDesc = append(pkDesc, kp.Desc)
278		}
279		sort.SliceStable(stmt.Columns, func(i, j int) bool {
280			a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name]
281			return a < b
282		})
283
284		t := &table{
285			colIndex:  make(map[spansql.ID]int),
286			origIndex: orig,
287			pkCols:    len(pk),
288			pkDesc:    pkDesc,
289		}
290		for _, cd := range stmt.Columns {
291			if st := t.addColumn(cd, true); st.Code() != codes.OK {
292				return st
293			}
294		}
295		for col := range pk {
296			if _, ok := t.colIndex[col]; !ok {
297				return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col)
298			}
299		}
300		d.tables[stmt.Name] = t
301		return nil
302	case *spansql.CreateIndex:
303		if _, ok := d.indexes[stmt.Name]; ok {
304			return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name)
305		}
306		d.indexes[stmt.Name] = struct{}{}
307		return nil
308	case *spansql.DropTable:
309		if _, ok := d.tables[stmt.Name]; !ok {
310			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
311		}
312		// TODO: check for indexes on this table.
313		delete(d.tables, stmt.Name)
314		return nil
315	case *spansql.DropIndex:
316		if _, ok := d.indexes[stmt.Name]; !ok {
317			return status.Newf(codes.NotFound, "no index named %s", stmt.Name)
318		}
319		delete(d.indexes, stmt.Name)
320		return nil
321	case *spansql.AlterTable:
322		t, ok := d.tables[stmt.Name]
323		if !ok {
324			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
325		}
326		switch alt := stmt.Alteration.(type) {
327		default:
328			return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt)
329		case spansql.AddColumn:
330			if st := t.addColumn(alt.Def, false); st.Code() != codes.OK {
331				return st
332			}
333			return nil
334		case spansql.DropColumn:
335			if st := t.dropColumn(alt.Name); st.Code() != codes.OK {
336				return st
337			}
338			return nil
339		case spansql.AlterColumn:
340			if st := t.alterColumn(alt); st.Code() != codes.OK {
341				return st
342			}
343			return nil
344		}
345	}
346
347}
348
349func (d *database) table(tbl spansql.ID) (*table, error) {
350	d.mu.Lock()
351	defer d.mu.Unlock()
352
353	t, ok := d.tables[tbl]
354	if !ok {
355		return nil, status.Errorf(codes.NotFound, "no table named %s", tbl)
356	}
357	return t, nil
358}
359
// writeValues executes a write operation (Insert, Update, etc.).
//
// It checks that tx permits mutations and looks up table tbl. It then converts
// each ListValue in values into a row laid out in table column order, leaving
// columns not named in cols nil. Each converted row is passed to f, which does
// the operation-specific write into t.rows. Afterwards the stored row is found
// again by primary key and its generated columns are recomputed in place.
//
// Every primary key column must be named in cols. Writing to a generated
// column, or writing NULL to a NOT NULL column, is rejected. Rows are applied
// one at a time with no rollback, so rows written before an error remain.
func (d *database) writeValues(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error {
	if err := tx.checkMutable(); err != nil {
		return err
	}

	t, err := d.table(tbl)
	if err != nil {
		return err
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	colIndexes, err := t.colIndexes(cols)
	if err != nil {
		return err
	}
	revIndex := make(map[int]int) // table index to col index
	for j, i := range colIndexes {
		revIndex[i] = j
	}

	// The primary key locates the row, so every key column must be written.
	for pki := 0; pki < t.pkCols; pki++ {
		_, ok := revIndex[pki]
		if !ok {
			return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name)
		}
	}

	for _, vs := range values {
		if len(vs.Values) != len(colIndexes) {
			return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes))
		}

		// r is indexed by table column; columns not in the write stay nil.
		r := make(row, len(t.cols))
		for j, v := range vs.Values {
			i := colIndexes[j]

			if t.cols[i].Generated != nil {
				return status.Error(codes.InvalidArgument, "values can't be written to a generated column")
			}
			x, err := valForType(v, t.cols[i].Type)
			if err != nil {
				return err
			}
			// Replace the commit-timestamp placeholder with this
			// transaction's commit timestamp.
			if x == commitTimestampSentinel {
				x = tx.commitTimestamp
			}
			if x == nil && t.cols[i].NotNull {
				return status.Errorf(codes.FailedPrecondition, "%s must not be NULL in table %s", t.cols[i].Name, tbl)
			}

			r[i] = x
		}
		// TODO: enforce that provided timestamp for commit_timestamp=true columns
		// are not ahead of the transaction's commit timestamp.

		if err := f(t, colIndexes, r); err != nil {
			return err
		}

		// Get row again after potential update merge to ensure we compute
		// generated columns with fresh data.
		pk := r[:t.pkCols]
		rowNum, found := t.rowForPK(pk)
		// This should never fail as the row was just inserted.
		if !found {
			return status.Error(codes.Internal, "row failed to be inserted")
		}
		// row aliases the stored row, so assignments below update t.rows.
		row := t.rows[rowNum]
		ec := evalContext{
			cols: t.cols,
			row:  row,
		}

		// TODO: We would need to do a topological sort on dependencies
		// (i.e. what other columns the expression references) to ensure we
		// can handle generated columns which reference other generated columns
		for i, col := range t.cols {
			if col.Generated != nil {
				res, err := ec.evalExpr(col.Generated)
				if err != nil {
					return err
				}
				row[i] = res
			}
		}
	}

	return nil
}
452
453func (d *database) Insert(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
454	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
455		pk := r[:t.pkCols]
456		rowNum, found := t.rowForPK(pk)
457		if found {
458			return status.Errorf(codes.AlreadyExists, "row already in table")
459		}
460		t.insertRow(rowNum, r)
461		return nil
462	})
463}
464
465func (d *database) Update(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
466	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
467		if t.pkCols == 0 {
468			return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl)
469		}
470		pk := r[:t.pkCols]
471		rowNum, found := t.rowForPK(pk)
472		if !found {
473			// TODO: is this the right way to return `NOT_FOUND`?
474			return status.Errorf(codes.NotFound, "row not in table")
475		}
476
477		for _, i := range colIndexes {
478			t.rows[rowNum][i] = r[i]
479		}
480		return nil
481	})
482}
483
484func (d *database) InsertOrUpdate(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error {
485	return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error {
486		pk := r[:t.pkCols]
487		rowNum, found := t.rowForPK(pk)
488		if !found {
489			// New row; do an insert.
490			t.insertRow(rowNum, r)
491		} else {
492			// Existing row; do an update.
493			for _, i := range colIndexes {
494				t.rows[rowNum][i] = r[i]
495			}
496		}
497		return nil
498	})
499}
500
501// TODO: Replace
502
503func (d *database) Delete(tx *transaction, table spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error {
504	if err := tx.checkMutable(); err != nil {
505		return err
506	}
507
508	t, err := d.table(table)
509	if err != nil {
510		return err
511	}
512
513	t.mu.Lock()
514	defer t.mu.Unlock()
515
516	if all {
517		t.rows = nil
518		return nil
519	}
520
521	for _, key := range keys {
522		pk, err := t.primaryKey(key.Values)
523		if err != nil {
524			return err
525		}
526		// Not an error if the key does not exist.
527		rowNum, found := t.rowForPK(pk)
528		if found {
529			copy(t.rows[rowNum:], t.rows[rowNum+1:])
530			t.rows = t.rows[:len(t.rows)-1]
531		}
532	}
533
534	for _, r := range keyRanges {
535		r.startKey, err = t.primaryKeyPrefix(r.start.Values)
536		if err != nil {
537			return err
538		}
539		r.endKey, err = t.primaryKeyPrefix(r.end.Values)
540		if err != nil {
541			return err
542		}
543		startRow, endRow := t.findRange(r)
544		if n := endRow - startRow; n > 0 {
545			copy(t.rows[startRow:], t.rows[endRow:])
546			t.rows = t.rows[:len(t.rows)-n]
547		}
548	}
549
550	return nil
551}
552
553// readTable executes a read option (Read, ReadAll).
554func (d *database) readTable(table spansql.ID, cols []spansql.ID, f func(*table, *rawIter, []int) error) (*rawIter, error) {
555	t, err := d.table(table)
556	if err != nil {
557		return nil, err
558	}
559
560	t.mu.Lock()
561	defer t.mu.Unlock()
562
563	colIndexes, err := t.colIndexes(cols)
564	if err != nil {
565		return nil, err
566	}
567
568	ri := &rawIter{}
569	for _, i := range colIndexes {
570		ri.cols = append(ri.cols, t.cols[i])
571	}
572	return ri, f(t, ri, colIndexes)
573}
574
575func (d *database) Read(tbl spansql.ID, cols []spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) {
576	// The real Cloud Spanner returns an error if the key set is empty by definition.
577	// That doesn't seem to be well-defined, but it is a common error to attempt a read with no keys,
578	// so catch that here and return a representative error.
579	if len(keys) == 0 && len(keyRanges) == 0 {
580		return nil, status.Error(codes.Unimplemented, "Cloud Spanner does not support reading no keys")
581	}
582
583	return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error {
584		// "If the same key is specified multiple times in the set (for
585		// example if two ranges, two keys, or a key and a range
586		// overlap), Cloud Spanner behaves as if the key were only
587		// specified once."
588		done := make(map[int]bool) // row numbers we've included in ri.
589
590		// Specific keys.
591		for _, key := range keys {
592			pk, err := t.primaryKey(key.Values)
593			if err != nil {
594				return err
595			}
596			// Not an error if the key does not exist.
597			rowNum, found := t.rowForPK(pk)
598			if !found {
599				continue
600			}
601			if done[rowNum] {
602				continue
603			}
604			done[rowNum] = true
605			ri.add(t.rows[rowNum], colIndexes)
606			if limit > 0 && len(ri.rows) >= int(limit) {
607				return nil
608			}
609		}
610
611		// Key ranges.
612		for _, r := range keyRanges {
613			var err error
614			r.startKey, err = t.primaryKeyPrefix(r.start.Values)
615			if err != nil {
616				return err
617			}
618			r.endKey, err = t.primaryKeyPrefix(r.end.Values)
619			if err != nil {
620				return err
621			}
622			startRow, endRow := t.findRange(r)
623			for rowNum := startRow; rowNum < endRow; rowNum++ {
624				if done[rowNum] {
625					continue
626				}
627				done[rowNum] = true
628				ri.add(t.rows[rowNum], colIndexes)
629				if limit > 0 && len(ri.rows) >= int(limit) {
630					return nil
631				}
632			}
633		}
634
635		return nil
636	})
637}
638
639func (d *database) ReadAll(tbl spansql.ID, cols []spansql.ID, limit int64) (*rawIter, error) {
640	return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error {
641		for _, r := range t.rows {
642			ri.add(r, colIndexes)
643			if limit > 0 && len(ri.rows) >= int(limit) {
644				break
645			}
646		}
647		return nil
648	})
649}
650
651func (t *table) addColumn(cd spansql.ColumnDef, newTable bool) *status.Status {
652	if !newTable && cd.NotNull {
653		return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL")
654	}
655
656	if _, ok := t.colIndex[cd.Name]; ok {
657		return status.Newf(codes.AlreadyExists, "column %s already exists", cd.Name)
658	}
659
660	t.mu.Lock()
661	defer t.mu.Unlock()
662
663	if len(t.rows) > 0 {
664		if cd.NotNull {
665			// TODO: what happens in this case?
666			return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet")
667		}
668		if cd.Generated != nil {
669			// TODO: should backfill the data to maintain behaviour with real spanner
670			return status.Newf(codes.Unimplemented, "can't add generated columns to non-empty tables yet")
671		}
672		for i := range t.rows {
673			t.rows[i] = append(t.rows[i], nil)
674		}
675	}
676
677	t.cols = append(t.cols, colInfo{
678		Name:    cd.Name,
679		Type:    cd.Type,
680		NotNull: cd.NotNull,
681		// TODO: We should figure out what columns the Generator expression
682		// relies on and check it is valid at this time currently it will
683		// fail when writing data instead as it is the first time we
684		// evaluate the expression.
685		Generated: cd.Generated,
686	})
687	t.colIndex[cd.Name] = len(t.cols) - 1
688	if !newTable {
689		t.origIndex[cd.Name] = len(t.cols) - 1
690	}
691
692	return nil
693}
694
695func (t *table) dropColumn(name spansql.ID) *status.Status {
696	// Only permit dropping non-key columns that aren't part of a secondary index.
697	// We don't support indexes, so only check that it isn't part of the primary key.
698
699	t.mu.Lock()
700	defer t.mu.Unlock()
701
702	ci, ok := t.colIndex[name]
703	if !ok {
704		// TODO: What's the right response code?
705		return status.Newf(codes.InvalidArgument, "unknown column %q", name)
706	}
707	if ci < t.pkCols {
708		// TODO: What's the right response code?
709		return status.Newf(codes.InvalidArgument, "can't drop primary key column %q", name)
710	}
711
712	// Remove from cols and colIndex, and renumber colIndex and origIndex.
713	t.cols = append(t.cols[:ci], t.cols[ci+1:]...)
714	delete(t.colIndex, name)
715	for i, col := range t.cols {
716		t.colIndex[col.Name] = i
717	}
718	pre := t.origIndex[name]
719	delete(t.origIndex, name)
720	for n, i := range t.origIndex {
721		if i > pre {
722			t.origIndex[n]--
723		}
724	}
725
726	// Drop data.
727	for i := range t.rows {
728		t.rows[i] = append(t.rows[i][:ci], t.rows[i][ci+1:]...)
729	}
730
731	return nil
732}
733
// alterColumn applies an ALTER COLUMN statement to the table. Only the
// SetColumnType alteration is handled. It may toggle NOT NULL and may convert a
// scalar column between STRING and BYTES (including length changes). Any other
// type change is rejected. All validation happens before the column definition
// or its data is modified.
func (t *table) alterColumn(alt spansql.AlterColumn) *status.Status {
	// Supported changes here are:
	//	Add NOT NULL to a non-key column, excluding ARRAY columns.
	//	Remove NOT NULL from a non-key column.
	//	Change a STRING column to a BYTES column or a BYTES column to a STRING column.
	//	Increase or decrease the length limit for a STRING or BYTES type (including to MAX).
	//	Enable or disable commit timestamps in value and primary key columns.
	// https://cloud.google.com/spanner/docs/schema-updates#supported-updates

	// TODO: codes.InvalidArgument is used throughout here for reporting errors,
	// but that has not been validated against the real Spanner.

	sct, ok := alt.Alteration.(spansql.SetColumnType)
	if !ok {
		return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL())
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	ci, ok := t.colIndex[alt.Name]
	if !ok {
		return status.Newf(codes.InvalidArgument, "unknown column %q", alt.Name)
	}

	oldT, newT := t.cols[ci].Type, sct.Type
	stringOrBytes := func(bt spansql.TypeBase) bool { return bt == spansql.String || bt == spansql.Bytes }

	// First phase: Check the validity of the change.
	// TODO: Don't permit changes to allow commit timestamps.
	if !t.cols[ci].NotNull && sct.NotNull {
		// Adding NOT NULL is not permitted for primary key columns or array typed columns.
		if ci < t.pkCols {
			return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on primary key column %q", alt.Name)
		}
		if oldT.Array {
			return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on array-typed column %q", alt.Name)
		}
		// Validate that there are no NULL values.
		for _, row := range t.rows {
			if row[ci] == nil {
				return status.Newf(codes.InvalidArgument, "cannot set NOT NULL on column %q that contains NULL values", alt.Name)
			}
		}
	}
	// conv, when non-nil, converts each stored non-NULL value to the new type.
	var conv func(x interface{}) interface{}
	if stringOrBytes(oldT.Base) && stringOrBytes(newT.Base) && !oldT.Array && !newT.Array {
		// Change between STRING and BYTES is fine, as is increasing/decreasing the length limit.
		// TODO: This should permit array conversions too.
		// TODO: Validate data; length limit changes should be rejected if they'd lead to data loss, for instance.
		if oldT.Base == spansql.Bytes && newT.Base == spansql.String {
			conv = func(x interface{}) interface{} { return string(x.([]byte)) }
		} else if oldT.Base == spansql.String && newT.Base == spansql.Bytes {
			conv = func(x interface{}) interface{} { return []byte(x.(string)) }
		}
	} else if oldT == newT {
		// Same type; only NOT NULL changes.
	} else { // TODO: Support other alterations.
		return status.Newf(codes.InvalidArgument, "unsupported ALTER COLUMN %s", alt.SQL())
	}

	// Second phase: Make type transformations.
	t.cols[ci].NotNull = sct.NotNull
	t.cols[ci].Type = newT
	if conv != nil {
		for _, row := range t.rows {
			if row[ci] != nil { // NULL stays as NULL.
				row[ci] = conv(row[ci])
			}
		}
	}
	return nil
}
807
808func (t *table) insertRow(rowNum int, r row) {
809	t.rows = append(t.rows, nil)
810	copy(t.rows[rowNum+1:], t.rows[rowNum:])
811	t.rows[rowNum] = r
812}
813
814// findRange finds the rows included in the key range,
815// reporting it as a half-open interval.
816// r.startKey and r.endKey should be populated.
817func (t *table) findRange(r *keyRange) (int, int) {
818	// startRow is the first row matching the range.
819	startRow := sort.Search(len(t.rows), func(i int) bool {
820		return rowCmp(r.startKey, t.rows[i][:t.pkCols], t.pkDesc) <= 0
821	})
822	if startRow == len(t.rows) {
823		return startRow, startRow
824	}
825	if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols], t.pkDesc) == 0 {
826		startRow++
827	}
828
829	// endRow is one more than the last row matching the range.
830	endRow := sort.Search(len(t.rows), func(i int) bool {
831		return rowCmp(r.endKey, t.rows[i][:t.pkCols], t.pkDesc) < 0
832	})
833	if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols], t.pkDesc) == 0 {
834		endRow--
835	}
836
837	return startRow, endRow
838}
839
840// colIndexes returns the indexes for the named columns.
841func (t *table) colIndexes(cols []spansql.ID) ([]int, error) {
842	var is []int
843	for _, col := range cols {
844		i, ok := t.colIndex[col]
845		if !ok {
846			return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col)
847		}
848		is = append(is, i)
849	}
850	return is, nil
851}
852
853// primaryKey constructs the internal representation of a primary key.
854// The list of given values must be in 1:1 correspondence with the primary key of the table.
855func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) {
856	if len(values) != t.pkCols {
857		return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols)
858	}
859	return t.primaryKeyPrefix(values)
860}
861
862// primaryKeyPrefix constructs the internal representation of a primary key prefix.
863func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) {
864	if len(values) > t.pkCols {
865		return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols)
866	}
867
868	var pk []interface{}
869	for i, value := range values {
870		v, err := valForType(value, t.cols[i].Type)
871		if err != nil {
872			return nil, err
873		}
874		pk = append(pk, v)
875	}
876	return pk, nil
877}
878
879// rowForPK returns the index of t.rows that holds the row for the given primary key, and true.
880// If the given primary key isn't found, it returns the row that should hold it, and false.
881func (t *table) rowForPK(pk []interface{}) (row int, found bool) {
882	if len(pk) != t.pkCols {
883		panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols))
884	}
885
886	i := sort.Search(len(t.rows), func(i int) bool {
887		return rowCmp(pk, t.rows[i][:t.pkCols], t.pkDesc) <= 0
888	})
889	if i == len(t.rows) {
890		return i, false
891	}
892	return i, rowEqual(pk, t.rows[i][:t.pkCols])
893}
894
895// rowCmp compares two rows, returning -1/0/+1.
896// The desc arg indicates whether each column is in a descending order.
897// This is used for primary key matching and so doesn't support array/struct types.
898// a is permitted to be shorter than b.
899func rowCmp(a, b []interface{}, desc []bool) int {
900	for i := 0; i < len(a); i++ {
901		if cmp := compareVals(a[i], b[i]); cmp != 0 {
902			if desc[i] {
903				cmp = -cmp
904			}
905			return cmp
906		}
907	}
908	return 0
909}
910
911// rowEqual reports whether two rows are equal.
912// This doesn't support array/struct types.
913func rowEqual(a, b []interface{}) bool {
914	for i := 0; i < len(a); i++ {
915		if compareVals(a[i], b[i]) != 0 {
916			return false
917		}
918	}
919	return true
920}
921
922// valForType converts a value from its RPC form into its internal representation.
923func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) {
924	if _, ok := v.Kind.(*structpb.Value_NullValue); ok {
925		return nil, nil
926	}
927
928	if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array {
929		et := t // element type
930		et.Array = false
931
932		// Construct the non-nil slice for the list.
933		arr := make([]interface{}, 0, len(lv.ListValue.Values))
934		for _, v := range lv.ListValue.Values {
935			x, err := valForType(v, et)
936			if err != nil {
937				return nil, err
938			}
939			arr = append(arr, x)
940		}
941		return arr, nil
942	}
943
944	switch t.Base {
945	case spansql.Bool:
946		bv, ok := v.Kind.(*structpb.Value_BoolValue)
947		if ok {
948			return bv.BoolValue, nil
949		}
950	case spansql.Int64:
951		// The Spanner protocol encodes int64 as a decimal string.
952		sv, ok := v.Kind.(*structpb.Value_StringValue)
953		if ok {
954			x, err := strconv.ParseInt(sv.StringValue, 10, 64)
955			if err != nil {
956				return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err)
957			}
958			return x, nil
959		}
960	case spansql.Float64:
961		nv, ok := v.Kind.(*structpb.Value_NumberValue)
962		if ok {
963			return nv.NumberValue, nil
964		}
965	case spansql.String:
966		sv, ok := v.Kind.(*structpb.Value_StringValue)
967		if ok {
968			return sv.StringValue, nil
969		}
970	case spansql.Bytes:
971		sv, ok := v.Kind.(*structpb.Value_StringValue)
972		if ok {
973			// The Spanner protocol encodes BYTES in base64.
974			return base64.StdEncoding.DecodeString(sv.StringValue)
975		}
976	case spansql.Date:
977		// The Spanner protocol encodes DATE in RFC 3339 date format.
978		sv, ok := v.Kind.(*structpb.Value_StringValue)
979		if ok {
980			s := sv.StringValue
981			d, err := parseAsDate(s)
982			if err != nil {
983				return nil, fmt.Errorf("bad DATE string %q: %v", s, err)
984			}
985			return d, nil
986		}
987	case spansql.Timestamp:
988		// The Spanner protocol encodes TIMESTAMP in RFC 3339 timestamp format with zone Z.
989		sv, ok := v.Kind.(*structpb.Value_StringValue)
990		if ok {
991			s := sv.StringValue
992			if strings.ToLower(s) == "spanner.commit_timestamp()" {
993				return commitTimestampSentinel, nil
994			}
995			t, err := parseAsTimestamp(s)
996			if err != nil {
997				return nil, fmt.Errorf("bad TIMESTAMP string %q: %v", s, err)
998			}
999			return t, nil
1000		}
1001	}
1002	return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL())
1003}
1004
// keyRange is a range of primary keys as supplied in a read or delete request.
// start and end hold the raw key values. They may be prefixes of the table's
// primary key, and are converted with primaryKeyPrefix. startClosed and
// endClosed say whether each bound is inclusive.
type keyRange struct {
	start, end             *structpb.ListValue
	startClosed, endClosed bool

	// These are populated during an operation
	// when we know what table this keyRange applies to.
	startKey, endKey []interface{}
}
1013
1014func (r *keyRange) String() string {
1015	var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9.
1016	if r.startClosed {
1017		sb.WriteString("[")
1018	} else {
1019		sb.WriteString("(")
1020	}
1021	fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey)
1022	if r.endClosed {
1023		sb.WriteString("]")
1024	} else {
1025		sb.WriteString(")")
1026	}
1027	return sb.String()
1028}
1029
// keyRangeList is a list of key ranges, as taken by Read and Delete.
type keyRangeList []*keyRange
1031
1032// Execute runs a DML statement.
1033// It returns the number of affected rows.
1034func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error) { // TODO: return *status.Status instead?
1035	switch stmt := stmt.(type) {
1036	default:
1037		return 0, status.Errorf(codes.Unimplemented, "unhandled DML statement type %T", stmt)
1038	case *spansql.Delete:
1039		t, err := d.table(stmt.Table)
1040		if err != nil {
1041			return 0, err
1042		}
1043
1044		t.mu.Lock()
1045		defer t.mu.Unlock()
1046
1047		n := 0
1048		for i := 0; i < len(t.rows); {
1049			ec := evalContext{
1050				cols:   t.cols,
1051				row:    t.rows[i],
1052				params: params,
1053			}
1054			b, err := ec.evalBoolExpr(stmt.Where)
1055			if err != nil {
1056				return 0, err
1057			}
1058			if b != nil && *b {
1059				copy(t.rows[i:], t.rows[i+1:])
1060				t.rows = t.rows[:len(t.rows)-1]
1061				n++
1062				continue
1063			}
1064			i++
1065		}
1066		return n, nil
1067	case *spansql.Update:
1068		t, err := d.table(stmt.Table)
1069		if err != nil {
1070			return 0, err
1071		}
1072
1073		t.mu.Lock()
1074		defer t.mu.Unlock()
1075
1076		ec := evalContext{
1077			cols:   t.cols,
1078			params: params,
1079		}
1080
1081		// Build parallel slices of destination column index and expressions to evaluate.
1082		var dstIndex []int
1083		var expr []spansql.Expr
1084		for _, ui := range stmt.Items {
1085			i, err := ec.resolveColumnIndex(ui.Column)
1086			if err != nil {
1087				return 0, err
1088			}
1089			// TODO: Enforce "A column can appear only once in the SET clause.".
1090			if i < t.pkCols {
1091				return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column)
1092			}
1093			dstIndex = append(dstIndex, i)
1094			expr = append(expr, ui.Value)
1095		}
1096
1097		n := 0
1098		values := make(row, len(stmt.Items)) // scratch space for new values
1099		for i := 0; i < len(t.rows); i++ {
1100			ec.row = t.rows[i]
1101			b, err := ec.evalBoolExpr(stmt.Where)
1102			if err != nil {
1103				return 0, err
1104			}
1105			if b != nil && *b {
1106				// Compute every update item.
1107				for j := range dstIndex {
1108					if expr[j] == nil { // DEFAULT
1109						values[j] = nil
1110						continue
1111					}
1112					v, err := ec.evalExpr(expr[j])
1113					if err != nil {
1114						return 0, err
1115					}
1116					values[j] = v
1117				}
1118				// Write them to the row.
1119				for j, v := range values {
1120					t.rows[i][dstIndex[j]] = v
1121				}
1122				n++
1123			}
1124		}
1125		return n, nil
1126	}
1127}
1128
1129func parseAsDate(s string) (civil.Date, error) { return civil.ParseDate(s) }
// parseAsTimestamp parses an RFC 3339 timestamp that must end in a literal "Z".
// Fractional seconds (up to nanoseconds) are optional; the result is in UTC.
func parseAsTimestamp(s string) (time.Time, error) {
	const layout = "2006-01-02T15:04:05.999999999Z"
	ts, err := time.Parse(layout, s)
	return ts, err
}
1133