1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19// This file contains the implementation of the Spanner fake itself,
20// namely the part behind the RPC interface.
21
22// TODO: missing transactionality in a serious way!
23
24import (
25	"bytes"
26	"encoding/base64"
27	"fmt"
28	"sort"
29	"strconv"
30	"sync"
31	"time"
32
33	"google.golang.org/grpc/codes"
34	"google.golang.org/grpc/status"
35
36	structpb "github.com/golang/protobuf/ptypes/struct"
37
38	"cloud.google.com/go/spanner/spansql"
39)
40
// database is the in-memory state of the fake: the tables that have been
// created and the names of the indexes that have been created.
type database struct {
	// mu is held by ApplyDDL and by the table lookup method while they
	// read or modify the maps below.
	mu      sync.Mutex
	tables  map[string]*table
	indexes map[string]struct{} // only record their existence
}
46
// table holds the schema and the rows of a single table.
type table struct {
	// mu is locked by writeValues, Delete, readTable and addColumn
	// while they access the fields below.
	mu sync.Mutex

	// Information about the table columns.
	// They are reordered on table creation so the primary key columns come first.
	cols     []colInfo
	colIndex map[string]int // col name to index
	pkCols   int            // number of primary key columns (may be 0)

	// Rows are stored in primary key order.
	// Each row has one element per entry in cols, in the same order.
	rows []row
}
59
// colInfo represents information about a column in a table or result set.
type colInfo struct {
	Name string       // column name, as given in the column definition
	Type spansql.Type // column type, as given in the column definition
}
65
/*
row represents a list of data elements.

The mapping between Spanner types and the Go types internal to this package is:
	BOOL		bool
	INT64		int64
	FLOAT64		float64
	STRING		string
	BYTES		[]byte
	DATE		string (RFC 3339 date; "YYYY-MM-DD")
	TIMESTAMP	TODO
	ARRAY<T>	[]T
	STRUCT		TODO

A NULL value of any type is stored as a nil element.
*/
type row []interface{}
81
82func (r row) copyDataElem(index int) interface{} {
83	v := r[index]
84	if is, ok := v.([]interface{}); ok {
85		// Deep-copy array values.
86		v = append([]interface{}(nil), is...)
87	}
88	return v
89}
90
91// copyData returns a copy of a subset of a row.
92func (r row) copyData(indexes []int) row {
93	if len(indexes) == 0 {
94		return nil
95	}
96	dst := make(row, 0, len(indexes))
97	for _, i := range indexes {
98		dst = append(dst, r.copyDataElem(i))
99	}
100	return dst
101}
102
103func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status {
104	d.mu.Lock()
105	defer d.mu.Unlock()
106
107	// Lazy init.
108	if d.tables == nil {
109		d.tables = make(map[string]*table)
110	}
111	if d.indexes == nil {
112		d.indexes = make(map[string]struct{})
113	}
114
115	switch stmt := stmt.(type) {
116	default:
117		return status.Newf(codes.Unimplemented, "unhandled DDL statement type %T", stmt)
118	case spansql.CreateTable:
119		if _, ok := d.tables[stmt.Name]; ok {
120			return status.Newf(codes.AlreadyExists, "table %s already exists", stmt.Name)
121		}
122
123		// TODO: check stmt.Interleave details.
124
125		// Move primary keys first, preserving their order.
126		pk := make(map[string]int)
127		for i, kp := range stmt.PrimaryKey {
128			pk[kp.Column] = -1000 + i
129		}
130		sort.SliceStable(stmt.Columns, func(i, j int) bool {
131			a, b := pk[stmt.Columns[i].Name], pk[stmt.Columns[j].Name]
132			return a < b
133		})
134
135		t := &table{
136			colIndex: make(map[string]int),
137			pkCols:   len(pk),
138		}
139		for _, cd := range stmt.Columns {
140			if st := t.addColumn(cd); st.Code() != codes.OK {
141				return st
142			}
143		}
144		for col := range pk {
145			if _, ok := t.colIndex[col]; !ok {
146				return status.Newf(codes.InvalidArgument, "primary key column %q not in table", col)
147			}
148		}
149		d.tables[stmt.Name] = t
150		return nil
151	case spansql.CreateIndex:
152		if _, ok := d.indexes[stmt.Name]; ok {
153			return status.Newf(codes.AlreadyExists, "index %s already exists", stmt.Name)
154		}
155		d.indexes[stmt.Name] = struct{}{}
156		return nil
157	case spansql.DropTable:
158		if _, ok := d.tables[stmt.Name]; !ok {
159			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
160		}
161		// TODO: check for indexes on this table.
162		delete(d.tables, stmt.Name)
163		return nil
164	case spansql.DropIndex:
165		if _, ok := d.indexes[stmt.Name]; !ok {
166			return status.Newf(codes.NotFound, "no index named %s", stmt.Name)
167		}
168		delete(d.indexes, stmt.Name)
169		return nil
170	case spansql.AlterTable:
171		t, ok := d.tables[stmt.Name]
172		if !ok {
173			return status.Newf(codes.NotFound, "no table named %s", stmt.Name)
174		}
175		switch alt := stmt.Alteration.(type) {
176		default:
177			return status.Newf(codes.Unimplemented, "unhandled DDL table alteration type %T", alt)
178		case spansql.AddColumn:
179			if alt.Def.NotNull {
180				return status.Newf(codes.InvalidArgument, "new non-key columns cannot be NOT NULL")
181			}
182			if st := t.addColumn(alt.Def); st.Code() != codes.OK {
183				return st
184			}
185			return nil
186		}
187	}
188
189}
190
191func (d *database) table(tbl string) (*table, error) {
192	d.mu.Lock()
193	defer d.mu.Unlock()
194
195	t, ok := d.tables[tbl]
196	if !ok {
197		return nil, status.Errorf(codes.NotFound, "no table named %s", tbl)
198	}
199	return t, nil
200}
201
202// writeValues executes a write option (Insert, Update, etc.).
203func (d *database) writeValues(tbl string, cols []string, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error {
204	t, err := d.table(tbl)
205	if err != nil {
206		return err
207	}
208
209	t.mu.Lock()
210	defer t.mu.Unlock()
211
212	colIndexes, err := t.colIndexes(cols)
213	if err != nil {
214		return err
215	}
216	revIndex := make(map[int]int) // table index to col index
217	for j, i := range colIndexes {
218		revIndex[i] = j
219	}
220
221	for pki := 0; pki < t.pkCols; pki++ {
222		_, ok := revIndex[pki]
223		if !ok {
224			return status.Errorf(codes.InvalidArgument, "primary key column %s not included in write", t.cols[pki].Name)
225		}
226	}
227
228	for _, vs := range values {
229		if len(vs.Values) != len(colIndexes) {
230			return status.Errorf(codes.InvalidArgument, "row of %d values can't be written to %d columns", len(vs.Values), len(colIndexes))
231		}
232
233		r := make(row, len(t.cols))
234		for j, v := range vs.Values {
235			i := colIndexes[j]
236
237			x, err := valForType(v, t.cols[i].Type)
238			if err != nil {
239				return err
240			}
241
242			r[i] = x
243		}
244		// TODO: enforce NOT NULL?
245
246		if err := f(t, colIndexes, r); err != nil {
247			return err
248		}
249	}
250
251	return nil
252}
253
254func (d *database) Insert(tbl string, cols []string, values []*structpb.ListValue) error {
255	return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
256		pk := r[:t.pkCols]
257		rowNum, found := t.rowForPK(pk)
258		if found {
259			return status.Errorf(codes.AlreadyExists, "row already in table")
260		}
261		t.insertRow(rowNum, r)
262		return nil
263	})
264}
265
266func (d *database) Update(tbl string, cols []string, values []*structpb.ListValue) error {
267	return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
268		if t.pkCols == 0 {
269			return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl)
270		}
271		pk := r[:t.pkCols]
272		rowNum, found := t.rowForPK(pk)
273		if !found {
274			// TODO: is this the right way to return `NOT_FOUND`?
275			return status.Errorf(codes.NotFound, "row not in table")
276		}
277
278		for _, i := range colIndexes {
279			t.rows[rowNum][i] = r[i]
280		}
281		return nil
282	})
283}
284
285func (d *database) InsertOrUpdate(tbl string, cols []string, values []*structpb.ListValue) error {
286	return d.writeValues(tbl, cols, values, func(t *table, colIndexes []int, r row) error {
287		pk := r[:t.pkCols]
288		rowNum, found := t.rowForPK(pk)
289		if !found {
290			// New row; do an insert.
291			t.insertRow(rowNum, r)
292		} else {
293			// Existing row; do an update.
294			for _, i := range colIndexes {
295				t.rows[rowNum][i] = r[i]
296			}
297		}
298		return nil
299	})
300}
301
302// TODO: Replace
303
304func (d *database) Delete(table string, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error {
305	t, err := d.table(table)
306	if err != nil {
307		return err
308	}
309
310	t.mu.Lock()
311	defer t.mu.Unlock()
312
313	if all {
314		t.rows = nil
315		return nil
316	}
317
318	for _, key := range keys {
319		pk, err := t.primaryKey(key.Values)
320		if err != nil {
321			return err
322		}
323		// Not an error if the key does not exist.
324		rowNum, found := t.rowForPK(pk)
325		if found {
326			copy(t.rows[rowNum:], t.rows[rowNum+1:])
327			t.rows = t.rows[:len(t.rows)-1]
328		}
329	}
330
331	for _, r := range keyRanges {
332		r.startKey, err = t.primaryKeyPrefix(r.start.Values)
333		if err != nil {
334			return err
335		}
336		r.endKey, err = t.primaryKeyPrefix(r.end.Values)
337		if err != nil {
338			return err
339		}
340		startRow, endRow := t.findRange(r)
341		if n := endRow - startRow; n > 0 {
342			copy(t.rows[startRow:], t.rows[endRow:])
343			t.rows = t.rows[:len(t.rows)-n]
344		}
345	}
346
347	return nil
348}
349
// resultIter is returned by reads and queries.
// Use its Next method to iterate over the result rows.
type resultIter struct {
	// Cols is the metadata about the returned data.
	Cols []colInfo

	// rows holds the result data itself.
	// Next consumes it from the front.
	rows []resultRow
}
359
// resultRow is a single row of a result set.
type resultRow struct {
	// data is the row's values; this is what Next hands back.
	data []interface{}

	// aux is any auxiliary values evaluated for the row.
	// When a query has an ORDER BY clause, this will contain the values for those expressions.
	aux []interface{}
}
367
368func (ri *resultIter) Next() ([]interface{}, bool) {
369	if len(ri.rows) == 0 {
370		return nil, false
371	}
372	res := ri.rows[0]
373	ri.rows = ri.rows[1:]
374	return res.data, true
375}
376
377func (ri *resultIter) add(src row, colIndexes []int) {
378	ri.rows = append(ri.rows, resultRow{
379		data: src.copyData(colIndexes),
380	})
381}
382
383// readTable executes a read option (Read, ReadAll).
384func (d *database) readTable(table string, cols []string, f func(*table, *resultIter, []int) error) (*resultIter, error) {
385	t, err := d.table(table)
386	if err != nil {
387		return nil, err
388	}
389
390	t.mu.Lock()
391	defer t.mu.Unlock()
392
393	colIndexes, err := t.colIndexes(cols)
394	if err != nil {
395		return nil, err
396	}
397
398	ri := &resultIter{}
399	for _, i := range colIndexes {
400		ri.Cols = append(ri.Cols, t.cols[i])
401	}
402	return ri, f(t, ri, colIndexes)
403}
404
405func (d *database) Read(tbl string, cols []string, keys []*structpb.ListValue, limit int64) (*resultIter, error) {
406	return d.readTable(tbl, cols, func(t *table, ri *resultIter, colIndexes []int) error {
407		for _, key := range keys {
408			pk, err := t.primaryKey(key.Values)
409			if err != nil {
410				return err
411			}
412			// Not an error if the key does not exist.
413			rowNum, found := t.rowForPK(pk)
414			if !found {
415				continue
416			}
417			ri.add(t.rows[rowNum], colIndexes)
418			if limit > 0 && len(ri.rows) >= int(limit) {
419				break
420			}
421		}
422		return nil
423	})
424}
425
426func (d *database) ReadAll(tbl string, cols []string, limit int64) (*resultIter, error) {
427	return d.readTable(tbl, cols, func(t *table, ri *resultIter, colIndexes []int) error {
428		for _, r := range t.rows {
429			ri.add(r, colIndexes)
430			if limit > 0 && len(ri.rows) >= int(limit) {
431				break
432			}
433		}
434		return nil
435	})
436}
437
// queryParams holds the parameters supplied with a query.
// Query passes it through to evalSelect and evalLimit.
// NOTE(review): presumably keyed by parameter name; confirm against the callers.
type queryParams map[string]interface{}
439
440func (d *database) Query(q spansql.Query, params queryParams) (*resultIter, error) {
441	// If there's an ORDER BY clause, prepare the list of auxiliary data we need.
442	// This is provided to evalSelect to evaluate with each row.
443	var aux []spansql.Expr
444	var desc []bool
445	if len(q.Order) > 0 {
446		if len(q.Select.From) == 0 {
447			return nil, fmt.Errorf("ORDER BY doesn't work without a table")
448		}
449
450		for _, o := range q.Order {
451			aux = append(aux, o.Expr)
452			desc = append(desc, o.Desc)
453		}
454	}
455
456	ri, err := d.evalSelect(q.Select, params, aux)
457	if err != nil {
458		return nil, err
459	}
460	if len(q.Order) > 0 {
461		sort.Slice(ri.rows, func(one, two int) bool {
462			r1, r2 := ri.rows[one], ri.rows[two]
463			for i := range r1.aux {
464				cmp := compareVals(r1.aux[i], r2.aux[i])
465				if desc[i] {
466					cmp = -cmp
467				}
468				if cmp == 0 {
469					continue
470				}
471				return cmp < 0
472			}
473			return false
474		})
475	}
476	if q.Limit != nil {
477		lim, err := evalLimit(q.Limit, params)
478		if err != nil {
479			return nil, err
480		}
481		if n := int(lim); n < len(ri.rows) {
482			ri.rows = ri.rows[:n]
483		}
484	}
485	return ri, nil
486}
487
488func (t *table) addColumn(cd spansql.ColumnDef) *status.Status {
489	t.mu.Lock()
490	defer t.mu.Unlock()
491
492	if len(t.rows) > 0 {
493		if cd.NotNull {
494			// TODO: what happens in this case?
495			return status.Newf(codes.Unimplemented, "can't add NOT NULL columns to non-empty tables yet")
496		}
497		for i := range t.rows {
498			t.rows[i] = append(t.rows[i], nil)
499		}
500	}
501
502	t.cols = append(t.cols, colInfo{
503		Name: cd.Name,
504		Type: cd.Type,
505	})
506	t.colIndex[cd.Name] = len(t.cols) - 1
507
508	return nil
509}
510
511func (t *table) insertRow(rowNum int, r row) {
512	t.rows = append(t.rows, nil)
513	copy(t.rows[rowNum+1:], t.rows[rowNum:])
514	t.rows[rowNum] = r
515}
516
517// findRange finds the rows included in the key range,
518// reporting it as a half-open interval.
519// r.startKey and r.endKey should be populated.
520func (t *table) findRange(r *keyRange) (int, int) {
521	// TODO: This is incorrect for primary keys with descending order.
522	// It might be sufficient for the caller to switch start/end in that case.
523
524	// startRow is the first row matching the range.
525	startRow := sort.Search(len(t.rows), func(i int) bool {
526		return rowCmp(r.startKey, t.rows[i][:t.pkCols]) <= 0
527	})
528	if startRow == len(t.rows) {
529		return startRow, startRow
530	}
531	if !r.startClosed && rowCmp(r.startKey, t.rows[startRow][:t.pkCols]) == 0 {
532		startRow++
533	}
534
535	// endRow is one more than the last row matching the range.
536	endRow := sort.Search(len(t.rows), func(i int) bool {
537		return rowCmp(r.endKey, t.rows[i][:t.pkCols]) < 0
538	})
539	if !r.endClosed && rowCmp(r.endKey, t.rows[endRow-1][:t.pkCols]) == 0 {
540		endRow--
541	}
542
543	return startRow, endRow
544}
545
546// colIndexes returns the indexes for the named columns.
547func (t *table) colIndexes(cols []string) ([]int, error) {
548	var is []int
549	for _, col := range cols {
550		i, ok := t.colIndex[col]
551		if !ok {
552			return nil, status.Errorf(codes.InvalidArgument, "column %s not in table", col)
553		}
554		is = append(is, i)
555	}
556	return is, nil
557}
558
559// primaryKey constructs the internal representation of a primary key.
560// The list of given values must be in 1:1 correspondence with the primary key of the table.
561func (t *table) primaryKey(values []*structpb.Value) ([]interface{}, error) {
562	if len(values) != t.pkCols {
563		return nil, status.Errorf(codes.InvalidArgument, "primary key length mismatch: got %d values, table has %d", len(values), t.pkCols)
564	}
565	return t.primaryKeyPrefix(values)
566}
567
568// primaryKeyPrefix constructs the internal representation of a primary key prefix.
569func (t *table) primaryKeyPrefix(values []*structpb.Value) ([]interface{}, error) {
570	if len(values) > t.pkCols {
571		return nil, status.Errorf(codes.InvalidArgument, "primary key length too long: got %d values, table has %d", len(values), t.pkCols)
572	}
573
574	var pk []interface{}
575	for i, value := range values {
576		v, err := valForType(value, t.cols[i].Type)
577		if err != nil {
578			return nil, err
579		}
580		pk = append(pk, v)
581	}
582	return pk, nil
583}
584
585// rowForPK returns the index of t.rows that holds the row for the given primary key, and true.
586// If the given primary key isn't found, it returns the row that should hold it, and false.
587func (t *table) rowForPK(pk []interface{}) (row int, found bool) {
588	if len(pk) != t.pkCols {
589		panic(fmt.Sprintf("primary key length mismatch: got %d values, table has %d", len(pk), t.pkCols))
590	}
591
592	i := sort.Search(len(t.rows), func(i int) bool {
593		return rowCmp(pk, t.rows[i][:t.pkCols]) <= 0
594	})
595	if i == len(t.rows) {
596		return i, false
597	}
598	return i, rowCmp(pk, t.rows[i][:t.pkCols]) == 0
599}
600
601// rowCmp compares two rows, returning -1/0/+1.
602// This is used for primary key matching and so doesn't support array/struct types.
603// a is permitted to be shorter than b.
604func rowCmp(a, b []interface{}) int {
605	for i := 0; i < len(a); i++ {
606		if cmp := compareVals(a[i], b[i]); cmp != 0 {
607			return cmp
608		}
609	}
610	return 0
611}
612
613func valForType(v *structpb.Value, t spansql.Type) (interface{}, error) {
614	if _, ok := v.Kind.(*structpb.Value_NullValue); ok {
615		// TODO: enforce NOT NULL constraints?
616		return nil, nil
617	}
618
619	if lv, ok := v.Kind.(*structpb.Value_ListValue); ok && t.Array {
620		et := t // element type
621		et.Array = false
622
623		// Construct the non-nil slice for the list.
624		arr := make([]interface{}, 0, len(lv.ListValue.Values))
625		for _, v := range lv.ListValue.Values {
626			x, err := valForType(v, et)
627			if err != nil {
628				return nil, err
629			}
630			arr = append(arr, x)
631		}
632		return arr, nil
633	}
634
635	switch t.Base {
636	case spansql.Bool:
637		bv, ok := v.Kind.(*structpb.Value_BoolValue)
638		if ok {
639			return bv.BoolValue, nil
640		}
641	case spansql.Int64:
642		// The Spanner protocol encodes int64 as a decimal string.
643		sv, ok := v.Kind.(*structpb.Value_StringValue)
644		if ok {
645			x, err := strconv.ParseInt(sv.StringValue, 10, 64)
646			if err != nil {
647				return nil, fmt.Errorf("bad int64 string %q: %v", sv.StringValue, err)
648			}
649			return x, nil
650		}
651	case spansql.Float64:
652		nv, ok := v.Kind.(*structpb.Value_NumberValue)
653		if ok {
654			return nv.NumberValue, nil
655		}
656	case spansql.String:
657		sv, ok := v.Kind.(*structpb.Value_StringValue)
658		if ok {
659			return sv.StringValue, nil
660		}
661	case spansql.Bytes:
662		sv, ok := v.Kind.(*structpb.Value_StringValue)
663		if ok {
664			// The Spanner protocol encodes BYTES in base64.
665			return base64.StdEncoding.DecodeString(sv.StringValue)
666		}
667	case spansql.Date:
668		// The Spanner protocol encodes DATE in RFC 3339 date format.
669		sv, ok := v.Kind.(*structpb.Value_StringValue)
670		if ok {
671			// Store it internally as a string, but validate its value.
672			s := sv.StringValue
673			if _, err := time.Parse("2006-01-02", s); err != nil {
674				return nil, fmt.Errorf("bad DATE string %q: %v", s, err)
675			}
676			return s, nil
677		}
678	}
679	return nil, fmt.Errorf("unsupported inserting value kind %T into column of type %s", v.Kind, t.SQL())
680}
681
// keyRange is a range of primary keys, each end of which is either
// closed (inclusive) or open (exclusive).
type keyRange struct {
	// start and end are the bounds in wire form.
	start, end             *structpb.ListValue
	startClosed, endClosed bool

	// These are populated during an operation
	// when we know what table this keyRange applies to.
	// Delete fills them in by passing start and end through the table's primaryKeyPrefix.
	startKey, endKey []interface{}
}
690
691func (r *keyRange) String() string {
692	var sb bytes.Buffer // TODO: Switch to strings.Builder when we drop support for Go 1.9.
693	if r.startClosed {
694		sb.WriteString("[")
695	} else {
696		sb.WriteString("(")
697	}
698	fmt.Fprintf(&sb, "%v,%v", r.startKey, r.endKey)
699	if r.endClosed {
700		sb.WriteString("]")
701	} else {
702		sb.WriteString(")")
703	}
704	return sb.String()
705}
706
// keyRangeList is a list of key ranges, as accepted by Delete.
type keyRangeList []*keyRange
708