1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spannertest
18
19import (
20	"io"
21	"reflect"
22	"sync"
23	"testing"
24
25	"google.golang.org/grpc/codes"
26
27	structpb "github.com/golang/protobuf/ptypes/struct"
28
29	"cloud.google.com/go/spanner/spansql"
30)
31
32var stdTestTable = &spansql.CreateTable{
33	Name: "Staff",
34	Columns: []spansql.ColumnDef{
35		{Name: "Tenure", Type: spansql.Type{Base: spansql.Int64}},
36		{Name: "ID", Type: spansql.Type{Base: spansql.Int64}},
37		{Name: "Name", Type: spansql.Type{Base: spansql.String}},
38		{Name: "Cool", Type: spansql.Type{Base: spansql.Bool}},
39		{Name: "Height", Type: spansql.Type{Base: spansql.Float64}},
40	},
41	PrimaryKey: []spansql.KeyPart{{Column: "Name"}, {Column: "ID"}},
42}
43
44func TestTableCreation(t *testing.T) {
45	var db database
46	st := db.ApplyDDL(stdTestTable)
47	if st.Code() != codes.OK {
48		t.Fatalf("Creating table: %v", st.Err())
49	}
50
51	// Snoop inside to check that it was constructed correctly.
52	got, ok := db.tables["Staff"]
53	if !ok {
54		t.Fatal("Table didn't get registered")
55	}
56	want := table{
57		cols: []colInfo{
58			{Name: "Name", Type: spansql.Type{Base: spansql.String}},
59			{Name: "ID", Type: spansql.Type{Base: spansql.Int64}},
60			{Name: "Tenure", Type: spansql.Type{Base: spansql.Int64}},
61			{Name: "Cool", Type: spansql.Type{Base: spansql.Bool}},
62			{Name: "Height", Type: spansql.Type{Base: spansql.Float64}},
63		},
64		colIndex: map[string]int{
65			"Tenure": 2, "ID": 1, "Cool": 3, "Name": 0, "Height": 4,
66		},
67		pkCols: 2,
68	}
69	if !reflect.DeepEqual(got.cols, want.cols) {
70		t.Errorf("table.cols incorrect.\n got %v\nwant %v", got.cols, want.cols)
71	}
72	if !reflect.DeepEqual(got.colIndex, want.colIndex) {
73		t.Errorf("table.colIndex incorrect.\n got %v\nwant %v", got.colIndex, want.colIndex)
74	}
75	if got.pkCols != want.pkCols {
76		t.Errorf("table.pkCols incorrect.\n got %d\nwant %d", got.pkCols, want.pkCols)
77	}
78}
79
80func TestTableData(t *testing.T) {
81	var db database
82	st := db.ApplyDDL(stdTestTable)
83	if st.Code() != codes.OK {
84		t.Fatalf("Creating table: %v", st.Err())
85	}
86
87	// Insert a subset of columns.
88	tx := db.NewTransaction()
89	tx.Start()
90	err := db.Insert(tx, "Staff", []string{"ID", "Name", "Tenure", "Height"}, []*structpb.ListValue{
91		// int64 arrives as a decimal string.
92		listV(stringV("1"), stringV("Jack"), stringV("10"), floatV(1.85)),
93		listV(stringV("2"), stringV("Daniel"), stringV("11"), floatV(1.83)),
94	})
95	if err != nil {
96		t.Fatalf("Inserting data: %v", err)
97	}
98	// Insert a different set of columns.
99	err = db.Insert(tx, "Staff", []string{"Name", "ID", "Cool", "Tenure", "Height"}, []*structpb.ListValue{
100		listV(stringV("Sam"), stringV("3"), boolV(false), stringV("9"), floatV(1.75)),
101		listV(stringV("Teal'c"), stringV("4"), boolV(true), stringV("8"), floatV(1.91)),
102		listV(stringV("George"), stringV("5"), nullV(), stringV("6"), floatV(1.73)),
103		listV(stringV("Harry"), stringV("6"), boolV(true), nullV(), nullV()),
104	})
105	if err != nil {
106		t.Fatalf("Inserting more data: %v", err)
107	}
108	// Delete that last one.
109	err = db.Delete(tx, "Staff", []*structpb.ListValue{listV(stringV("Harry"), stringV("6"))}, nil, false)
110	if err != nil {
111		t.Fatalf("Deleting a row: %v", err)
112	}
113	// Turns out this guy isn't cool after all.
114	err = db.Update(tx, "Staff", []string{"Name", "ID", "Cool"}, []*structpb.ListValue{
115		// Missing columns should be left alone.
116		listV(stringV("Daniel"), stringV("2"), boolV(false)),
117	})
118	if err != nil {
119		t.Fatalf("Updating a row: %v", err)
120	}
121	if _, err := tx.Commit(); err != nil {
122		t.Fatalf("Committing changes: %v", err)
123	}
124
125	// Read some specific keys.
126	ri, err := db.Read("Staff", []string{"Name", "Tenure"}, []*structpb.ListValue{
127		listV(stringV("George"), stringV("5")),
128		listV(stringV("Harry"), stringV("6")), // Missing key should be silently ignored.
129		listV(stringV("Sam"), stringV("3")),
130		listV(stringV("George"), stringV("5")), // Duplicate key should be silently ignored.
131	}, nil, 0)
132	if err != nil {
133		t.Fatalf("Reading keys: %v", err)
134	}
135	all := slurp(t, ri)
136	wantAll := [][]interface{}{
137		{"George", int64(6)},
138		{"Sam", int64(9)},
139	}
140	if !reflect.DeepEqual(all, wantAll) {
141		t.Errorf("Read data by keys wrong.\n got %v\nwant %v", all, wantAll)
142	}
143	// Read the same, but by key range.
144	ri, err = db.Read("Staff", []string{"Name", "Tenure"}, nil, keyRangeList{
145		{start: listV(stringV("Gabriel")), end: listV(stringV("Harpo"))}, // open/open
146		{
147			// closed/open
148			start:       listV(stringV("Sam"), stringV("3")),
149			startClosed: true,
150			end: listV(stringV("Teal'c"),
151				stringV("4")),
152		},
153	}, 0)
154	if err != nil {
155		t.Fatalf("Reading key ranges: %v", err)
156	}
157	all = slurp(t, ri)
158	if !reflect.DeepEqual(all, wantAll) {
159		t.Errorf("Read data by key ranges wrong.\n got %v\nwant %v", all, wantAll)
160	}
161
162	// Read a subset of all rows, with a limit.
163	ri, err = db.ReadAll("Staff", []string{"Tenure", "Name", "Height"}, 4)
164	if err != nil {
165		t.Fatalf("ReadAll: %v", err)
166	}
167	wantCols := []colInfo{
168		{Name: "Tenure", Type: spansql.Type{Base: spansql.Int64}},
169		{Name: "Name", Type: spansql.Type{Base: spansql.String}},
170		{Name: "Height", Type: spansql.Type{Base: spansql.Float64}},
171	}
172	if !reflect.DeepEqual(ri.Cols(), wantCols) {
173		t.Errorf("ReadAll cols wrong.\n got %v\nwant %v", ri.Cols(), wantCols)
174	}
175	all = slurp(t, ri)
176	wantAll = [][]interface{}{
177		// Primary key is (Name, ID), so results should come back sorted by Name then ID.
178		{int64(11), "Daniel", 1.83},
179		{int64(6), "George", 1.73},
180		{int64(10), "Jack", 1.85},
181		{int64(9), "Sam", 1.75},
182	}
183	if !reflect.DeepEqual(all, wantAll) {
184		t.Errorf("ReadAll data wrong.\n got %v\nwant %v", all, wantAll)
185	}
186
187	// Add DATE and TIMESTAMP columns, and populate them with some data.
188	st = db.ApplyDDL(&spansql.AlterTable{
189		Name: "Staff",
190		Alteration: spansql.AddColumn{Def: spansql.ColumnDef{
191			Name: "FirstSeen",
192			Type: spansql.Type{Base: spansql.Date},
193		}},
194	})
195	if st.Code() != codes.OK {
196		t.Fatalf("Adding column: %v", st.Err())
197	}
198	st = db.ApplyDDL(&spansql.AlterTable{
199		Name: "Staff",
200		Alteration: spansql.AddColumn{Def: spansql.ColumnDef{
201			Name: "To", // keyword; will need quoting in queries
202			Type: spansql.Type{Base: spansql.Timestamp},
203		}},
204	})
205	if st.Code() != codes.OK {
206		t.Fatalf("Adding column: %v", st.Err())
207	}
208	tx = db.NewTransaction()
209	tx.Start()
210	err = db.Update(tx, "Staff", []string{"Name", "ID", "FirstSeen", "To"}, []*structpb.ListValue{
211		listV(stringV("Jack"), stringV("1"), stringV("1994-10-28"), nullV()),
212		listV(stringV("Daniel"), stringV("2"), stringV("1994-10-28"), nullV()),
213		listV(stringV("George"), stringV("5"), stringV("1997-07-27"), stringV("2008-07-29T11:22:43Z")),
214	})
215	if err != nil {
216		t.Fatalf("Updating rows: %v", err)
217	}
218	if _, err := tx.Commit(); err != nil {
219		t.Fatalf("Committing changes: %v", err)
220	}
221
222	// Add some more data, then delete it with a KeyRange.
223	// The queries below ensure that this was all deleted.
224	tx = db.NewTransaction()
225	tx.Start()
226	err = db.Insert(tx, "Staff", []string{"Name", "ID"}, []*structpb.ListValue{
227		listV(stringV("01"), stringV("1")),
228		listV(stringV("03"), stringV("3")),
229		listV(stringV("06"), stringV("6")),
230	})
231	if err != nil {
232		t.Fatalf("Inserting data: %v", err)
233	}
234	err = db.Delete(tx, "Staff", nil, keyRangeList{{
235		start:       listV(stringV("01"), stringV("1")),
236		startClosed: true,
237		end:         listV(stringV("9")),
238	}}, false)
239	if err != nil {
240		t.Fatalf("Deleting key range: %v", err)
241	}
242	if _, err := tx.Commit(); err != nil {
243		t.Fatalf("Committing changes: %v", err)
244	}
245	// Re-add the data and delete with DML.
246	err = db.Insert(tx, "Staff", []string{"Name", "ID"}, []*structpb.ListValue{
247		listV(stringV("01"), stringV("1")),
248		listV(stringV("03"), stringV("3")),
249		listV(stringV("06"), stringV("6")),
250	})
251	if err != nil {
252		t.Fatalf("Inserting data: %v", err)
253	}
254	n, err := db.Execute(&spansql.Delete{
255		Table: "Staff",
256		Where: spansql.LogicalOp{
257			LHS: spansql.ComparisonOp{
258				LHS: spansql.ID("Name"),
259				Op:  spansql.Ge,
260				RHS: spansql.Param("min"),
261			},
262			Op: spansql.And,
263			RHS: spansql.ComparisonOp{
264				LHS: spansql.ID("Name"),
265				Op:  spansql.Lt,
266				RHS: spansql.Param("max"),
267			},
268		},
269	}, queryParams{
270		"min": stringParam("01"),
271		"max": stringParam("07"),
272	})
273	if err != nil {
274		t.Fatalf("Deleting with DML: %v", err)
275	}
276	if n != 3 {
277		t.Errorf("Deleting with DML affected %d rows, want 3", n)
278	}
279
280	// Add a BYTES column, and populate it with some data.
281	st = db.ApplyDDL(&spansql.AlterTable{
282		Name: "Staff",
283		Alteration: spansql.AddColumn{Def: spansql.ColumnDef{
284			Name: "RawBytes",
285			Type: spansql.Type{Base: spansql.Bytes, Len: spansql.MaxLen},
286		}},
287	})
288	if st.Code() != codes.OK {
289		t.Fatalf("Adding column: %v", st.Err())
290	}
291	tx = db.NewTransaction()
292	tx.Start()
293	err = db.Update(tx, "Staff", []string{"Name", "ID", "RawBytes"}, []*structpb.ListValue{
294		// bytes {0x01 0x00 0x01} encode as base-64 AQAB.
295		listV(stringV("Jack"), stringV("1"), stringV("AQAB")),
296	})
297	if err != nil {
298		t.Fatalf("Updating rows: %v", err)
299	}
300	if _, err := tx.Commit(); err != nil {
301		t.Fatalf("Committing changes: %v", err)
302	}
303
304	// Prepare the sample tables from the Cloud Spanner docs.
305	// https://cloud.google.com/spanner/docs/query-syntax#appendix-a-examples-with-sample-data
306	for _, ct := range []*spansql.CreateTable{
307		// TODO: Roster, TeamMascot when we implement JOINs.
308		{
309			Name: "PlayerStats",
310			Columns: []spansql.ColumnDef{
311				{Name: "LastName", Type: spansql.Type{Base: spansql.String}},
312				{Name: "OpponentID", Type: spansql.Type{Base: spansql.Int64}},
313				{Name: "PointsScored", Type: spansql.Type{Base: spansql.Int64}},
314			},
315			PrimaryKey: []spansql.KeyPart{{Column: "LastName"}, {Column: "OpponentID"}}, // TODO: is this right?
316		},
317	} {
318		st := db.ApplyDDL(ct)
319		if st.Code() != codes.OK {
320			t.Fatalf("Creating table: %v", st.Err())
321		}
322	}
323	tx = db.NewTransaction()
324	tx.Start()
325	err = db.Insert(tx, "PlayerStats", []string{"LastName", "OpponentID", "PointsScored"}, []*structpb.ListValue{
326		listV(stringV("Adams"), stringV("51"), stringV("3")),
327		listV(stringV("Buchanan"), stringV("77"), stringV("0")),
328		listV(stringV("Coolidge"), stringV("77"), stringV("1")),
329		listV(stringV("Adams"), stringV("52"), stringV("4")),
330		listV(stringV("Buchanan"), stringV("50"), stringV("13")),
331	})
332	if err != nil {
333		t.Fatalf("Inserting data: %v", err)
334	}
335	if _, err := tx.Commit(); err != nil {
336		t.Fatalf("Commiting changes: %v", err)
337	}
338
339	// Do some complex queries.
340	tests := []struct {
341		q      string
342		params queryParams
343		want   [][]interface{}
344	}{
345		{
346			`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello"`,
347			nil,
348			[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello")}},
349		},
350		// Check handling of NULL values for the IS operator.
351		// There was a bug that returned errors for some of these cases.
352		{
353			`SELECT @x IS TRUE, @x IS NOT TRUE, @x IS FALSE, @x IS NOT FALSE, @x IS NULL, @x IS NOT NULL`,
354			queryParams{"x": nullParam()},
355			[][]interface{}{
356				{false, true, false, true, true, false},
357			},
358		},
359		{
360			`SELECT Name FROM Staff WHERE Cool`,
361			nil,
362			[][]interface{}{{"Teal'c"}},
363		},
364		{
365			`SELECT ID FROM Staff WHERE Cool IS NOT NULL ORDER BY ID DESC`,
366			nil,
367			[][]interface{}{{int64(4)}, {int64(3)}, {int64(2)}},
368		},
369		{
370			`SELECT Name, Tenure FROM Staff WHERE Cool IS NULL OR Cool ORDER BY Name LIMIT 2`,
371			nil,
372			[][]interface{}{
373				{"George", int64(6)},
374				{"Jack", int64(10)},
375			},
376		},
377		{
378			`SELECT Name, ID + 100 FROM Staff WHERE @min <= Tenure AND Tenure < @lim ORDER BY Cool, Name DESC LIMIT @numResults`,
379			queryParams{"min": intParam(9), "lim": intParam(11), "numResults": intParam(100)},
380			[][]interface{}{
381				{"Jack", int64(101)},
382				{"Sam", int64(103)},
383			},
384		},
385		{
386			// Expression in SELECT list.
387			`SELECT Name, Cool IS NOT NULL FROM Staff WHERE Tenure/2 > 4 ORDER BY NOT Cool, Name`,
388			nil,
389			[][]interface{}{
390				{"Daniel", true}, // Daniel has Cool==true
391				{"Jack", false},  // Jack has NULL Cool
392				{"Sam", true},    // Sam has Cool==false
393			},
394		},
395		{
396			`SELECT Name, Height FROM Staff ORDER BY Height DESC LIMIT 2`,
397			nil,
398			[][]interface{}{
399				{"Teal'c", 1.91},
400				{"Jack", 1.85},
401			},
402		},
403		{
404			`SELECT Name FROM Staff WHERE Name LIKE "J%k" OR Name LIKE "_am"`,
405			nil,
406			[][]interface{}{
407				{"Jack"},
408				{"Sam"},
409			},
410		},
411		{
412			`SELECT Name, Height FROM Staff WHERE Height BETWEEN @min AND @max ORDER BY Height DESC`,
413			queryParams{"min": floatParam(1.75), "max": floatParam(1.85)},
414			[][]interface{}{
415				{"Jack", 1.85},
416				{"Daniel", 1.83},
417				{"Sam", 1.75},
418			},
419		},
420		{
421			`SELECT COUNT(*) FROM Staff WHERE Name < "T"`,
422			nil,
423			[][]interface{}{
424				{int64(4)},
425			},
426		},
427		{
428			// Check that aggregation still works for the empty set.
429			`SELECT COUNT(*) FROM Staff WHERE Name = "Nobody"`,
430			nil,
431			[][]interface{}{
432				{int64(0)},
433			},
434		},
435		{
436			`SELECT * FROM Staff WHERE Name LIKE "S%"`,
437			nil,
438			[][]interface{}{
439				// These are returned in table column order.
440				// Note that the primary key columns get sorted first.
441				{"Sam", int64(3), int64(9), false, 1.75, nil, nil, nil},
442			},
443		},
444		{
445			// Exactly the same as the previous, except with a redundant ORDER BY clause.
446			`SELECT * FROM Staff WHERE Name LIKE "S%" ORDER BY Name`,
447			nil,
448			[][]interface{}{
449				{"Sam", int64(3), int64(9), false, 1.75, nil, nil, nil},
450			},
451		},
452		{
453			`SELECT Name FROM Staff WHERE FirstSeen >= @min`,
454			queryParams{"min": queryParam{Value: "1996-01-01", Type: spansql.Type{Base: spansql.Date}}},
455			[][]interface{}{
456				{"George"},
457			},
458		},
459		{
460			`SELECT RawBytes FROM Staff WHERE RawBytes IS NOT NULL`,
461			nil,
462			[][]interface{}{
463				{[]byte("\x01\x00\x01")},
464			},
465		},
466		{
467			// The keyword "To" needs quoting in queries.
468			"SELECT COUNT(*) FROM Staff WHERE `To` IS NOT NULL",
469			nil,
470			[][]interface{}{
471				{int64(1)},
472			},
473		},
474		{
475			`SELECT DISTINCT Cool, Tenure > 8 FROM Staff`,
476			nil,
477			[][]interface{}{
478				// The non-distinct results are be
479				//          [[false true] [<nil> false] [<nil> true] [false true] [true false]]
480				{false, true},
481				{nil, false},
482				{nil, true},
483				{true, false},
484			},
485		},
486		{
487			`SELECT Name FROM Staff WHERE ID IN UNNEST(@ids)`,
488			queryParams{"ids": queryParam{
489				Value: []interface{}{int64(3), int64(1)},
490				Type:  spansql.Type{Base: spansql.Int64, Array: true},
491			}},
492			[][]interface{}{
493				{"Jack"},
494				{"Sam"},
495			},
496		},
497		// From https://cloud.google.com/spanner/docs/query-syntax#group-by-clause_1:
498		{
499			// TODO: Ordering matters? Our implementation sorts by the GROUP BY key,
500			// but nothing documented seems to guarantee that.
501			`SELECT LastName, SUM(PointsScored) FROM PlayerStats GROUP BY LastName`,
502			nil,
503			[][]interface{}{
504				{"Adams", int64(7)},
505				{"Buchanan", int64(13)},
506				{"Coolidge", int64(1)},
507			},
508		},
509		{
510			// Another GROUP BY, but referring to an alias.
511			// Group by ID oddness, SUM over Tenure.
512			`SELECT ID&0x01 AS odd, SUM(Tenure) FROM Staff GROUP BY odd`,
513			nil,
514			[][]interface{}{
515				{int64(0), int64(19)}, // Daniel(ID=2, Tenure=11), Teal'c(ID=4, Tenure=8)
516				{int64(1), int64(25)}, // Jack(ID=1, Tenure=10), Sam(ID=3, Tenure=9), George(ID=5, Tenure=6)
517			},
518		},
519		{
520			`SELECT ARRAY_AGG(Cool) FROM Staff ORDER BY Name`,
521			nil,
522			[][]interface{}{
523				// Daniel, George (NULL), Jack (NULL), Sam, Teal'c
524				{[]interface{}{false, nil, nil, false, true}},
525			},
526		},
527	}
528	for _, test := range tests {
529		q, err := spansql.ParseQuery(test.q)
530		if err != nil {
531			t.Errorf("ParseQuery(%q): %v", test.q, err)
532			continue
533		}
534		ri, err := db.Query(q, test.params)
535		if err != nil {
536			t.Errorf("Query(%q, %v): %v", test.q, test.params, err)
537			continue
538		}
539		all := slurp(t, ri)
540		if !reflect.DeepEqual(all, test.want) {
541			t.Errorf("Results from Query(%q, %v) are wrong.\n got %v\nwant %v", test.q, test.params, all, test.want)
542		}
543	}
544}
545
546func TestTableDescendingKey(t *testing.T) {
547	var descTestTable = &spansql.CreateTable{
548		Name: "Timeseries",
549		Columns: []spansql.ColumnDef{
550			{Name: "Name", Type: spansql.Type{Base: spansql.String}},
551			{Name: "Observed", Type: spansql.Type{Base: spansql.Int64}},
552			{Name: "Value", Type: spansql.Type{Base: spansql.Float64}},
553		},
554		PrimaryKey: []spansql.KeyPart{{Column: "Name"}, {Column: "Observed", Desc: true}},
555	}
556
557	var db database
558	if st := db.ApplyDDL(descTestTable); st.Code() != codes.OK {
559		t.Fatalf("Creating table: %v", st.Err())
560	}
561
562	tx := db.NewTransaction()
563	tx.Start()
564	err := db.Insert(tx, "Timeseries", []string{"Name", "Observed", "Value"}, []*structpb.ListValue{
565		listV(stringV("box"), stringV("1"), floatV(1.1)),
566		listV(stringV("cupcake"), stringV("1"), floatV(6)),
567		listV(stringV("box"), stringV("2"), floatV(1.2)),
568		listV(stringV("cupcake"), stringV("2"), floatV(7)),
569		listV(stringV("box"), stringV("3"), floatV(1.3)),
570		listV(stringV("cupcake"), stringV("3"), floatV(8)),
571	})
572	if err != nil {
573		t.Fatalf("Inserting data: %v", err)
574	}
575	if _, err := tx.Commit(); err != nil {
576		t.Fatalf("Committing changes: %v", err)
577	}
578
579	// Querying the entire table should return values in key order,
580	// noting that the second key part here is in descending order.
581	q, err := spansql.ParseQuery(`SELECT * FROM Timeseries`)
582	if err != nil {
583		t.Fatalf("ParseQuery: %v", err)
584	}
585	ri, err := db.Query(q, nil)
586	if err != nil {
587		t.Fatalf("Query: %v", err)
588	}
589	got := slurp(t, ri)
590	want := [][]interface{}{
591		{"box", int64(3), 1.3},
592		{"box", int64(2), 1.2},
593		{"box", int64(1), 1.1},
594		{"cupcake", int64(3), 8.0},
595		{"cupcake", int64(2), 7.0},
596		{"cupcake", int64(1), 6.0},
597	}
598	if !reflect.DeepEqual(got, want) {
599		t.Errorf("Results from Query are wrong.\n got %v\nwant %v", got, want)
600	}
601
602	// TestKeyRange exercises the edge cases for key range reading.
603}
604
605func TestTableSchemaConvertNull(t *testing.T) {
606	var db database
607	st := db.ApplyDDL(&spansql.CreateTable{
608		Name: "Songwriters",
609		Columns: []spansql.ColumnDef{
610			{Name: "ID", Type: spansql.Type{Base: spansql.Int64}, NotNull: true},
611			{Name: "Nickname", Type: spansql.Type{Base: spansql.String}},
612		},
613		PrimaryKey: []spansql.KeyPart{{Column: "ID"}},
614	})
615	if err := st.Err(); err != nil {
616		t.Fatal(err)
617	}
618
619	// Populate with data including a NULL for the STRING field.
620	tx := db.NewTransaction()
621	tx.Start()
622	err := db.Insert(tx, "Songwriters", []string{"ID", "Nickname"}, []*structpb.ListValue{
623		listV(stringV("6"), stringV("Tiger")),
624		listV(stringV("7"), nullV()),
625	})
626	if err != nil {
627		t.Fatalf("Inserting data: %v", err)
628	}
629	if _, err := tx.Commit(); err != nil {
630		t.Fatalf("Committing changes: %v", err)
631	}
632
633	// Convert the STRING field to a BYTES and back.
634	st = db.ApplyDDL(&spansql.AlterTable{
635		Name: "Songwriters",
636		Alteration: spansql.AlterColumn{
637			Def: spansql.ColumnDef{Name: "Nickname", Type: spansql.Type{Base: spansql.Bytes}},
638		},
639	})
640	if err := st.Err(); err != nil {
641		t.Fatalf("Converting STRING -> BYTES: %v", err)
642	}
643	st = db.ApplyDDL(&spansql.AlterTable{
644		Name: "Songwriters",
645		Alteration: spansql.AlterColumn{
646			Def: spansql.ColumnDef{Name: "Nickname", Type: spansql.Type{Base: spansql.String}},
647		},
648	})
649	if err := st.Err(); err != nil {
650		t.Fatalf("Converting BYTES -> STRING: %v", err)
651	}
652
653	// Check that the data is maintained.
654	q, err := spansql.ParseQuery(`SELECT * FROM Songwriters`)
655	if err != nil {
656		t.Fatalf("ParseQuery: %v", err)
657	}
658	ri, err := db.Query(q, nil)
659	if err != nil {
660		t.Fatalf("Query: %v", err)
661	}
662	got := slurp(t, ri)
663	want := [][]interface{}{
664		{int64(6), "Tiger"},
665		{int64(7), nil},
666	}
667	if !reflect.DeepEqual(got, want) {
668		t.Errorf("Results from Query are wrong.\n got %v\nwant %v", got, want)
669	}
670}
671
672func TestTableSchemaUpdates(t *testing.T) {
673	tests := []struct {
674		desc     string
675		ddl      string
676		wantCode codes.Code
677	}{
678		// TODO: add more cases, including interactions with the primary key and dropping columns.
679
680		{
681			"Add new column",
682			`CREATE TABLE Songwriters (
683				Id INT64 NOT NULL,
684			) PRIMARY KEY (Id);
685			ALTER TABLE Songwriters ADD COLUMN Nickname STRING(MAX);`,
686			codes.OK,
687		},
688		{
689			"Add new column with NOT NULL",
690			`CREATE TABLE Songwriters (
691				Id INT64 NOT NULL,
692			) PRIMARY KEY (Id);
693			ALTER TABLE Songwriters ADD COLUMN Nickname STRING(MAX) NOT NULL;`,
694			codes.InvalidArgument,
695		},
696
697		// Examples from https://cloud.google.com/spanner/docs/schema-updates:
698
699		{
700			"Add NOT NULL to a non-key column",
701			`CREATE TABLE Songwriters (
702				Id INT64 NOT NULL,
703				Nickname STRING(MAX),
704			) PRIMARY KEY (Id);
705			ALTER TABLE Songwriters ALTER COLUMN Nickname STRING(MAX) NOT NULL;`,
706			codes.OK,
707		},
708		{
709			"Remove NOT NULL from a non-key column",
710			`CREATE TABLE Songwriters (
711				Id INT64 NOT NULL,
712				Nickname STRING(MAX) NOT NULL,
713			) PRIMARY KEY (Id);
714			ALTER TABLE Songwriters ALTER COLUMN Nickname STRING(MAX);`,
715			codes.OK,
716		},
717		{
718			"Change a STRING column to a BYTES column",
719			`CREATE TABLE Songwriters (
720				Id INT64 NOT NULL,
721				Nickname STRING(MAX),
722			) PRIMARY KEY (Id);
723			ALTER TABLE Songwriters ALTER COLUMN Nickname BYTES(MAX);`,
724			codes.OK,
725		},
726		// TODO: Increase or decrease the length limit for a STRING or BYTES type (including to MAX)
727		// TODO: Enable or disable commit timestamps in value and primary key columns
728	}
729testLoop:
730	for _, test := range tests {
731		var db database
732
733		ddl, err := spansql.ParseDDL("filename", test.ddl)
734		if err != nil {
735			t.Fatalf("%s: Bad DDL: %v", test.desc, err)
736		}
737		for _, stmt := range ddl.List {
738			if st := db.ApplyDDL(stmt); st.Code() != codes.OK {
739				if st.Code() != test.wantCode {
740					t.Errorf("%s: Applying statement %q: %v", test.desc, stmt.SQL(), st.Err())
741				}
742				continue testLoop
743			}
744		}
745		if test.wantCode != codes.OK {
746			t.Errorf("%s: Finished with OK, want %v", test.desc, test.wantCode)
747		}
748	}
749}
750
751func TestConcurrentReadInsert(t *testing.T) {
752	// Check that data is safely copied during a query.
753	tbl := &spansql.CreateTable{
754		Name: "Tablino",
755		Columns: []spansql.ColumnDef{
756			{Name: "A", Type: spansql.Type{Base: spansql.Int64}},
757		},
758		PrimaryKey: []spansql.KeyPart{{Column: "A"}},
759	}
760
761	var db database
762	if st := db.ApplyDDL(tbl); st.Code() != codes.OK {
763		t.Fatalf("Creating table: %v", st.Err())
764	}
765
766	// Insert some initial data.
767	tx := db.NewTransaction()
768	tx.Start()
769	err := db.Insert(tx, "Tablino", []string{"A"}, []*structpb.ListValue{
770		listV(stringV("1")),
771		listV(stringV("2")),
772		listV(stringV("4")),
773	})
774	if err != nil {
775		t.Fatalf("Inserting data: %v", err)
776	}
777	if _, err := tx.Commit(); err != nil {
778		t.Fatalf("Committing changes: %v", err)
779	}
780
781	// Now insert "3", and query concurrently.
782	q, err := spansql.ParseQuery(`SELECT * FROM Tablino WHERE A > 2`)
783	if err != nil {
784		t.Fatalf("ParseQuery: %v", err)
785	}
786	var out [][]interface{}
787
788	var wg sync.WaitGroup
789	wg.Add(2)
790	go func() {
791		defer wg.Done()
792
793		ri, err := db.Query(q, nil)
794		if err != nil {
795			t.Errorf("Query: %v", err)
796			return
797		}
798		out = slurp(t, ri)
799	}()
800	go func() {
801		defer wg.Done()
802
803		tx := db.NewTransaction()
804		tx.Start()
805		err := db.Insert(tx, "Tablino", []string{"A"}, []*structpb.ListValue{
806			listV(stringV("3")),
807		})
808		if err != nil {
809			t.Errorf("Inserting data: %v", err)
810			return
811		}
812		if _, err := tx.Commit(); err != nil {
813			t.Errorf("Committing changes: %v", err)
814		}
815	}()
816	wg.Wait()
817
818	// We should get either 1 or 2 rows (value 4 should be included, and value 3 might).
819	if n := len(out); n != 1 && n != 2 {
820		t.Fatalf("Concurrent read returned %d rows, want 1 or 2", n)
821	}
822}
823
824func slurp(t *testing.T, ri rowIter) (all [][]interface{}) {
825	t.Helper()
826	for {
827		row, err := ri.Next()
828		if err == io.EOF {
829			return
830		} else if err != nil {
831			t.Fatalf("Reading rows: %v", err)
832		}
833		all = append(all, row)
834	}
835}
836
837func listV(vs ...*structpb.Value) *structpb.ListValue { return &structpb.ListValue{Values: vs} }
838func stringV(s string) *structpb.Value                { return &structpb.Value{Kind: &structpb.Value_StringValue{s}} }
839func floatV(f float64) *structpb.Value                { return &structpb.Value{Kind: &structpb.Value_NumberValue{f}} }
840func boolV(b bool) *structpb.Value                    { return &structpb.Value{Kind: &structpb.Value_BoolValue{b}} }
841func nullV() *structpb.Value                          { return &structpb.Value{Kind: &structpb.Value_NullValue{}} }
842
843func stringParam(s string) queryParam { return queryParam{Value: s, Type: stringType} }
844func intParam(i int64) queryParam     { return queryParam{Value: i, Type: int64Type} }
845func floatParam(f float64) queryParam { return queryParam{Value: f, Type: float64Type} }
846func nullParam() queryParam           { return queryParam{Value: nil} }
847
848func TestRowCmp(t *testing.T) {
849	r := func(x ...interface{}) []interface{} { return x }
850	tests := []struct {
851		a, b []interface{}
852		desc []bool
853		want int
854	}{
855		{r(int64(1), "foo", 1.6), r(int64(1), "foo", 1.6), []bool{false, false, false}, 0},
856		{r(int64(1), "foo"), r(int64(1), "foo", 1.6), []bool{false, false, false}, 0}, // first is shorter
857
858		{r(int64(1), "bar", 1.8), r(int64(1), "foo", 1.6), []bool{false, false, false}, -1},
859		{r(int64(1), "bar", 1.8), r(int64(1), "foo", 1.6), []bool{false, false, true}, -1},
860		{r(int64(1), "bar", 1.8), r(int64(1), "foo", 1.6), []bool{false, true, false}, 1},
861
862		{r(int64(1), "foo", 1.6), r(int64(1), "bar", 1.8), []bool{false, false, false}, 1},
863		{r(int64(1), "foo", 1.6), r(int64(1), "bar", 1.8), []bool{false, false, true}, 1},
864		{r(int64(1), "foo", 1.6), r(int64(1), "bar", 1.8), []bool{false, true, false}, -1},
865		{r(int64(1), "foo", 1.6), r(int64(1), "bar", 1.8), []bool{false, true, true}, -1},
866	}
867	for _, test := range tests {
868		if got := rowCmp(test.a, test.b, test.desc); got != test.want {
869			t.Errorf("rowCmp(%v, %v, %v) = %d, want %d", test.a, test.b, test.desc, got, test.want)
870		}
871	}
872}
873
874func TestKeyRange(t *testing.T) {
875	r := func(x ...interface{}) []interface{} { return x }
876	closedClosed := func(start, end []interface{}) *keyRange {
877		return &keyRange{
878			startKey:    start,
879			endKey:      end,
880			startClosed: true,
881			endClosed:   true,
882		}
883	}
884	halfOpen := func(start, end []interface{}) *keyRange {
885		return &keyRange{
886			startKey:    start,
887			endKey:      end,
888			startClosed: true,
889		}
890	}
891	openOpen := func(start, end []interface{}) *keyRange {
892		return &keyRange{
893			startKey: start,
894			endKey:   end,
895		}
896	}
897	tests := []struct {
898		kr      *keyRange
899		desc    []bool
900		include [][]interface{}
901		exclude [][]interface{}
902	}{
903		// Examples from google/spanner/v1/keys.proto.
904		{
905			kr: closedClosed(r("Bob", "2015-01-01"), r("Bob", "2015-12-31")),
906			include: [][]interface{}{
907				r("Bob", "2015-01-01"),
908				r("Bob", "2015-07-07"),
909				r("Bob", "2015-12-31"),
910			},
911			exclude: [][]interface{}{
912				r("Alice", "2015-07-07"),
913				r("Bob", "2014-12-31"),
914				r("Bob", "2016-01-01"),
915			},
916		},
917		{
918			kr: closedClosed(r("Bob", "2000-01-01"), r("Bob")),
919			include: [][]interface{}{
920				r("Bob", "2000-01-01"),
921				r("Bob", "2022-07-07"),
922			},
923			exclude: [][]interface{}{
924				r("Alice", "2015-07-07"),
925				r("Bob", "1999-11-07"),
926			},
927		},
928		{
929			kr: closedClosed(r("Bob"), r("Bob")),
930			include: [][]interface{}{
931				r("Bob", "2000-01-01"),
932			},
933			exclude: [][]interface{}{
934				r("Alice", "2015-07-07"),
935				r("Charlie", "1999-11-07"),
936			},
937		},
938		{
939			kr: halfOpen(r("Bob"), r("Bob", "2000-01-01")),
940			include: [][]interface{}{
941				r("Bob", "1999-11-07"),
942			},
943			exclude: [][]interface{}{
944				r("Alice", "1999-11-07"),
945				r("Bob", "2000-01-01"),
946				r("Bob", "2004-07-07"),
947				r("Charlie", "1999-11-07"),
948			},
949		},
950		{
951			kr: openOpen(r("Bob", "1999-11-06"), r("Bob", "2000-01-01")),
952			include: [][]interface{}{
953				r("Bob", "1999-11-07"),
954			},
955			exclude: [][]interface{}{
956				r("Alice", "1999-11-07"),
957				r("Bob", "1999-11-06"),
958				r("Bob", "2000-01-01"),
959				r("Bob", "2004-07-07"),
960				r("Charlie", "1999-11-07"),
961			},
962		},
963		{
964			kr: closedClosed(r(), r()),
965			include: [][]interface{}{
966				r("Alice", "1999-11-07"),
967				r("Bob", "1999-11-07"),
968				r("Charlie", "1999-11-07"),
969			},
970		},
971		{
972			kr: halfOpen(r("A"), r("D")),
973			include: [][]interface{}{
974				r("Alice", "1999-11-07"),
975				r("Bob", "1999-11-07"),
976				r("Charlie", "1999-11-07"),
977			},
978			exclude: [][]interface{}{
979				r("0day", "1999-11-07"),
980				r("Doris", "1999-11-07"),
981			},
982		},
983		// Exercise descending primary key ordering.
984		{
985			kr:   halfOpen(r("Alpha"), r("Charlie")),
986			desc: []bool{true, false},
987			// Key range is backwards, so nothing should be returned.
988			exclude: [][]interface{}{
989				r("Alice", "1999-11-07"),
990				r("Bob", "1999-11-07"),
991				r("Doris", "1999-11-07"),
992			},
993		},
994		{
995			kr:   halfOpen(r("Alice", "1999-11-07"), r("Charlie")),
996			desc: []bool{false, true},
997			// The second primary key column is descending.
998			include: [][]interface{}{
999				r("Alice", "1999-09-09"),
1000				r("Alice", "1999-11-07"),
1001				r("Bob", "2000-01-01"),
1002			},
1003			exclude: [][]interface{}{
1004				r("Alice", "2000-01-01"),
1005				r("Doris", "1999-11-07"),
1006			},
1007		},
1008	}
1009	for _, test := range tests {
1010		desc := test.desc
1011		if desc == nil {
1012			desc = []bool{false, false} // default
1013		}
1014		tbl := &table{
1015			pkCols: 2,
1016			pkDesc: desc,
1017		}
1018		for _, pk := range append(test.include, test.exclude...) {
1019			rowNum, _ := tbl.rowForPK(pk)
1020			tbl.insertRow(rowNum, pk)
1021		}
1022		start, end := tbl.findRange(test.kr)
1023		has := func(pk []interface{}) bool {
1024			n, _ := tbl.rowForPK(pk)
1025			return start <= n && n < end
1026		}
1027		for _, pk := range test.include {
1028			if !has(pk) {
1029				t.Errorf("keyRange %v does not include %v", test.kr, pk)
1030			}
1031		}
1032		for _, pk := range test.exclude {
1033			if has(pk) {
1034				t.Errorf("keyRange %v includes %v", test.kr, pk)
1035			}
1036		}
1037	}
1038}
1039