1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package testutil_test_test
16
17import (
18	"strconv"
19
20	. "cloud.google.com/go/spanner/internal/testutil"
21
22	"context"
23	"flag"
24	"fmt"
25	"log"
26	"net"
27	"os"
28	"strings"
29	"testing"
30
31	structpb "github.com/golang/protobuf/ptypes/struct"
32	spannerpb "google.golang.org/genproto/googleapis/spanner/v1"
33	"google.golang.org/grpc/codes"
34
35	apiv1 "cloud.google.com/go/spanner/apiv1"
36	"google.golang.org/api/iterator"
37	"google.golang.org/api/option"
38	"google.golang.org/grpc"
39
40	gstatus "google.golang.org/grpc/status"
41)
42
43// clientOpt is the option tests should use to connect to the test server.
44// It is initialized by TestMain.
45var serverAddress string
46var clientOpt option.ClientOption
47var testSpanner InMemSpannerServer
48
// Mocked query: the statement text and the dimensions of its result set.
const (
	selectSQL            = "SELECT FOO FROM BAR"
	selectRowCount int64 = 2
	selectColCount int   = 1
)

// selectValues holds the values of the mocked query result, one per row.
var selectValues = [...]int64{1, 2}

// Mocked DML statement: the statement text and the update count it reports.
const (
	updateSQL            = "UPDATE FOO SET BAR=1 WHERE ID=ID"
	updateRowCount int64 = 2
)
59
60func TestMain(m *testing.M) {
61	flag.Parse()
62
63	testSpanner = NewInMemSpannerServer()
64	serv := grpc.NewServer()
65	spannerpb.RegisterSpannerServer(serv, testSpanner)
66
67	lis, err := net.Listen("tcp", "localhost:0")
68	if err != nil {
69		log.Fatal(err)
70	}
71	go serv.Serve(lis)
72
73	serverAddress = lis.Addr().String()
74	conn, err := grpc.Dial(serverAddress, grpc.WithInsecure())
75	if err != nil {
76		log.Fatal(err)
77	}
78	clientOpt = option.WithGRPCConn(conn)
79
80	os.Exit(m.Run())
81}
82
83// Resets the mock server to its default values and registers a mocked result
84// for the statements "SELECT FOO FROM BAR" and
85// "UPDATE FOO SET BAR=1 WHERE ID=ID".
86func setup() {
87	testSpanner.Reset()
88	fields := make([]*spannerpb.StructType_Field, selectColCount)
89	fields[0] = &spannerpb.StructType_Field{
90		Name: "FOO",
91		Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64},
92	}
93	rowType := &spannerpb.StructType{
94		Fields: fields,
95	}
96	metadata := &spannerpb.ResultSetMetadata{
97		RowType: rowType,
98	}
99	rows := make([]*structpb.ListValue, selectRowCount)
100	for idx, value := range selectValues {
101		rowValue := make([]*structpb.Value, selectColCount)
102		rowValue[0] = &structpb.Value{
103			Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)},
104		}
105		rows[idx] = &structpb.ListValue{
106			Values: rowValue,
107		}
108	}
109	resultSet := &spannerpb.ResultSet{
110		Metadata: metadata,
111		Rows:     rows,
112	}
113	result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet}
114	testSpanner.PutStatementResult(selectSQL, result)
115
116	updateResult := &StatementResult{Type: StatementResultUpdateCount, UpdateCount: updateRowCount}
117	testSpanner.PutStatementResult(updateSQL, updateResult)
118}
119
120func TestSpannerCreateSession(t *testing.T) {
121	testSpanner.Reset()
122	var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
123	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
124	var request = &spannerpb.CreateSessionRequest{
125		Database: formattedDatabase,
126	}
127
128	c, err := apiv1.NewClient(context.Background(), clientOpt)
129	if err != nil {
130		t.Fatal(err)
131	}
132	resp, err := c.CreateSession(context.Background(), request)
133	if err != nil {
134		t.Fatal(err)
135	}
136	if strings.Index(resp.Name, expectedName) != 0 {
137		t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName)
138	}
139}
140
141func TestSpannerCreateSession_Unavailable(t *testing.T) {
142	testSpanner.Reset()
143	var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
144	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
145	var request = &spannerpb.CreateSessionRequest{
146		Database: formattedDatabase,
147	}
148
149	c, err := apiv1.NewClient(context.Background(), clientOpt)
150	if err != nil {
151		t.Fatal(err)
152	}
153	testSpanner.SetError(gstatus.Error(codes.Unavailable, "Temporary unavailable"))
154	resp, err := c.CreateSession(context.Background(), request)
155	if err != nil {
156		t.Fatal(err)
157	}
158	if strings.Index(resp.Name, expectedName) != 0 {
159		t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName)
160	}
161}
162
163func TestSpannerGetSession(t *testing.T) {
164	testSpanner.Reset()
165	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
166	var createRequest = &spannerpb.CreateSessionRequest{
167		Database: formattedDatabase,
168	}
169
170	c, err := apiv1.NewClient(context.Background(), clientOpt)
171	if err != nil {
172		t.Fatal(err)
173	}
174	createResp, err := c.CreateSession(context.Background(), createRequest)
175	if err != nil {
176		t.Fatal(err)
177	}
178	var getRequest = &spannerpb.GetSessionRequest{
179		Name: createResp.Name,
180	}
181	getResp, err := c.GetSession(context.Background(), getRequest)
182	if err != nil {
183		t.Fatal(err)
184	}
185	if getResp.Name != getRequest.Name {
186		t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", getResp.Name, getRequest.Name)
187	}
188}
189
190func TestSpannerListSessions(t *testing.T) {
191	testSpanner.Reset()
192	const expectedNumberOfSessions = 5
193	var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
194	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
195	var createRequest = &spannerpb.CreateSessionRequest{
196		Database: formattedDatabase,
197	}
198
199	c, err := apiv1.NewClient(context.Background(), clientOpt)
200	if err != nil {
201		t.Fatal(err)
202	}
203	for i := 0; i < expectedNumberOfSessions; i++ {
204		_, err := c.CreateSession(context.Background(), createRequest)
205		if err != nil {
206			t.Fatal(err)
207		}
208	}
209	var listRequest = &spannerpb.ListSessionsRequest{
210		Database: formattedDatabase,
211	}
212	var sessionCount int
213	listResp := c.ListSessions(context.Background(), listRequest)
214	for {
215		session, err := listResp.Next()
216		if err == iterator.Done {
217			break
218		}
219		if err != nil {
220			t.Fatal(err)
221		}
222		if strings.Index(session.Name, expectedName) != 0 {
223			t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", session.Name, expectedName)
224		}
225		sessionCount++
226	}
227	if sessionCount != expectedNumberOfSessions {
228		t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions)
229	}
230}
231
232func TestSpannerDeleteSession(t *testing.T) {
233	testSpanner.Reset()
234	const expectedNumberOfSessions = 5
235	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
236	var createRequest = &spannerpb.CreateSessionRequest{
237		Database: formattedDatabase,
238	}
239
240	c, err := apiv1.NewClient(context.Background(), clientOpt)
241	if err != nil {
242		t.Fatal(err)
243	}
244	for i := 0; i < expectedNumberOfSessions; i++ {
245		_, err := c.CreateSession(context.Background(), createRequest)
246		if err != nil {
247			t.Fatal(err)
248		}
249	}
250	var listRequest = &spannerpb.ListSessionsRequest{
251		Database: formattedDatabase,
252	}
253	var sessionCount int
254	listResp := c.ListSessions(context.Background(), listRequest)
255	for {
256		session, err := listResp.Next()
257		if err == iterator.Done {
258			break
259		}
260		if err != nil {
261			t.Fatal(err)
262		}
263		var deleteRequest = &spannerpb.DeleteSessionRequest{
264			Name: session.Name,
265		}
266		c.DeleteSession(context.Background(), deleteRequest)
267		sessionCount++
268	}
269	if sessionCount != expectedNumberOfSessions {
270		t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions)
271	}
272	// Re-list all sessions. This should now be empty.
273	listResp = c.ListSessions(context.Background(), listRequest)
274	_, err = listResp.Next()
275	if err != iterator.Done {
276		t.Errorf("expected empty session iterator")
277	}
278}
279
280func TestSpannerExecuteSql(t *testing.T) {
281	setup()
282	c, err := apiv1.NewClient(context.Background(), clientOpt)
283	if err != nil {
284		t.Fatal(err)
285	}
286
287	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
288	var createRequest = &spannerpb.CreateSessionRequest{
289		Database: formattedDatabase,
290	}
291	session, err := c.CreateSession(context.Background(), createRequest)
292	if err != nil {
293		t.Fatal(err)
294	}
295	request := &spannerpb.ExecuteSqlRequest{
296		Session: session.Name,
297		Sql:     selectSQL,
298		Transaction: &spannerpb.TransactionSelector{
299			Selector: &spannerpb.TransactionSelector_SingleUse{
300				SingleUse: &spannerpb.TransactionOptions{
301					Mode: &spannerpb.TransactionOptions_ReadOnly_{
302						ReadOnly: &spannerpb.TransactionOptions_ReadOnly{
303							ReturnReadTimestamp: false,
304							TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{
305								Strong: true,
306							},
307						},
308					},
309				},
310			},
311		},
312		Seqno:     1,
313		QueryMode: spannerpb.ExecuteSqlRequest_NORMAL,
314	}
315	response, err := c.ExecuteSql(context.Background(), request)
316	if err != nil {
317		t.Fatal(err)
318	}
319	var rowCount int64
320	for _, row := range response.Rows {
321		if len(row.Values) != selectColCount {
322			t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", len(row.Values), selectColCount)
323		}
324		rowCount++
325	}
326	if rowCount != selectRowCount {
327		t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowCount, selectRowCount)
328	}
329}
330
331func TestSpannerExecuteSqlDml(t *testing.T) {
332	setup()
333	c, err := apiv1.NewClient(context.Background(), clientOpt)
334	if err != nil {
335		t.Fatal(err)
336	}
337
338	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
339	var createRequest = &spannerpb.CreateSessionRequest{
340		Database: formattedDatabase,
341	}
342	session, err := c.CreateSession(context.Background(), createRequest)
343	if err != nil {
344		t.Fatal(err)
345	}
346	request := &spannerpb.ExecuteSqlRequest{
347		Session: session.Name,
348		Sql:     updateSQL,
349		Transaction: &spannerpb.TransactionSelector{
350			Selector: &spannerpb.TransactionSelector_Begin{
351				Begin: &spannerpb.TransactionOptions{
352					Mode: &spannerpb.TransactionOptions_ReadWrite_{
353						ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
354					},
355				},
356			},
357		},
358		Seqno:     1,
359		QueryMode: spannerpb.ExecuteSqlRequest_NORMAL,
360	}
361	response, err := c.ExecuteSql(context.Background(), request)
362	if err != nil {
363		t.Fatal(err)
364	}
365	var rowCount int64 = response.Stats.GetRowCountExact()
366	if rowCount != updateRowCount {
367		t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount)
368	}
369}
370
371func TestSpannerExecuteStreamingSql(t *testing.T) {
372	setup()
373	c, err := apiv1.NewClient(context.Background(), clientOpt)
374	if err != nil {
375		t.Fatal(err)
376	}
377
378	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
379	var createRequest = &spannerpb.CreateSessionRequest{
380		Database: formattedDatabase,
381	}
382	session, err := c.CreateSession(context.Background(), createRequest)
383	if err != nil {
384		t.Fatal(err)
385	}
386	request := &spannerpb.ExecuteSqlRequest{
387		Session: session.Name,
388		Sql:     selectSQL,
389		Transaction: &spannerpb.TransactionSelector{
390			Selector: &spannerpb.TransactionSelector_SingleUse{
391				SingleUse: &spannerpb.TransactionOptions{
392					Mode: &spannerpb.TransactionOptions_ReadOnly_{
393						ReadOnly: &spannerpb.TransactionOptions_ReadOnly{
394							ReturnReadTimestamp: false,
395							TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{
396								Strong: true,
397							},
398						},
399					},
400				},
401			},
402		},
403		Seqno:     1,
404		QueryMode: spannerpb.ExecuteSqlRequest_NORMAL,
405	}
406	response, err := c.ExecuteStreamingSql(context.Background(), request)
407	if err != nil {
408		t.Fatal(err)
409	}
410	var rowIndex int64
411	var colCount int
412	for {
413		for rowIndexInPartial := int64(0); rowIndexInPartial < MaxRowsPerPartialResultSet; rowIndexInPartial++ {
414			partial, err := response.Recv()
415			if err != nil {
416				t.Fatal(err)
417			}
418			if rowIndex == 0 {
419				colCount = len(partial.Metadata.RowType.Fields)
420				if colCount != selectColCount {
421					t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", colCount, selectColCount)
422				}
423			}
424			for col := 0; col < colCount; col++ {
425				pIndex := rowIndexInPartial*int64(colCount) + int64(col)
426				val, err := strconv.ParseInt(partial.Values[pIndex].GetStringValue(), 10, 64)
427				if err != nil {
428					t.Fatalf("Error parsing integer at #%d: %v", pIndex, err)
429				}
430				if val != selectValues[rowIndex] {
431					t.Fatalf("Value mismatch at index %d\nGot: %d\nWant: %d", rowIndex, val, selectValues[rowIndex])
432				}
433			}
434			rowIndex++
435		}
436		if rowIndex == selectRowCount {
437			break
438		}
439	}
440	if rowIndex != selectRowCount {
441		t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowIndex, selectRowCount)
442	}
443}
444
445func TestSpannerExecuteBatchDml(t *testing.T) {
446	setup()
447	c, err := apiv1.NewClient(context.Background(), clientOpt)
448	if err != nil {
449		t.Fatal(err)
450	}
451
452	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
453	var createRequest = &spannerpb.CreateSessionRequest{
454		Database: formattedDatabase,
455	}
456	session, err := c.CreateSession(context.Background(), createRequest)
457	if err != nil {
458		t.Fatal(err)
459	}
460	statements := make([]*spannerpb.ExecuteBatchDmlRequest_Statement, 3)
461	for idx := 0; idx < len(statements); idx++ {
462		statements[idx] = &spannerpb.ExecuteBatchDmlRequest_Statement{Sql: updateSQL}
463	}
464	executeBatchDmlRequest := &spannerpb.ExecuteBatchDmlRequest{
465		Session:    session.Name,
466		Statements: statements,
467		Transaction: &spannerpb.TransactionSelector{
468			Selector: &spannerpb.TransactionSelector_Begin{
469				Begin: &spannerpb.TransactionOptions{
470					Mode: &spannerpb.TransactionOptions_ReadWrite_{
471						ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
472					},
473				},
474			},
475		},
476		Seqno: 1,
477	}
478	response, err := c.ExecuteBatchDml(context.Background(), executeBatchDmlRequest)
479	if err != nil {
480		t.Fatal(err)
481	}
482	var totalRowCount int64
483	for _, res := range response.ResultSets {
484		var rowCount int64 = res.Stats.GetRowCountExact()
485		if rowCount != updateRowCount {
486			t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount)
487		}
488		totalRowCount += rowCount
489	}
490	if totalRowCount != updateRowCount*int64(len(statements)) {
491		t.Fatalf("Total update count mismatch\nGot: %d\nWant: %d", totalRowCount, updateRowCount*int64(len(statements)))
492	}
493}
494
495func TestBeginTransaction(t *testing.T) {
496	setup()
497	c, err := apiv1.NewClient(context.Background(), clientOpt)
498	if err != nil {
499		t.Fatal(err)
500	}
501
502	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
503	var createRequest = &spannerpb.CreateSessionRequest{
504		Database: formattedDatabase,
505	}
506	session, err := c.CreateSession(context.Background(), createRequest)
507	if err != nil {
508		t.Fatal(err)
509	}
510	beginRequest := &spannerpb.BeginTransactionRequest{
511		Session: session.Name,
512		Options: &spannerpb.TransactionOptions{
513			Mode: &spannerpb.TransactionOptions_ReadWrite_{
514				ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
515			},
516		},
517	}
518	tx, err := c.BeginTransaction(context.Background(), beginRequest)
519	if err != nil {
520		t.Fatal(err)
521	}
522	expectedName := fmt.Sprintf("%s/transactions/", session.Name)
523	if strings.Index(string(tx.Id), expectedName) != 0 {
524		t.Errorf("Transaction name mismatch\nGot: %s\nWant: Name should start with %s)", string(tx.Id), expectedName)
525	}
526}
527
528func TestCommitTransaction(t *testing.T) {
529	setup()
530	c, err := apiv1.NewClient(context.Background(), clientOpt)
531	if err != nil {
532		t.Fatal(err)
533	}
534
535	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
536	var createRequest = &spannerpb.CreateSessionRequest{
537		Database: formattedDatabase,
538	}
539	session, err := c.CreateSession(context.Background(), createRequest)
540	if err != nil {
541		t.Fatal(err)
542	}
543	beginRequest := &spannerpb.BeginTransactionRequest{
544		Session: session.Name,
545		Options: &spannerpb.TransactionOptions{
546			Mode: &spannerpb.TransactionOptions_ReadWrite_{
547				ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
548			},
549		},
550	}
551	tx, err := c.BeginTransaction(context.Background(), beginRequest)
552	if err != nil {
553		t.Fatal(err)
554	}
555	commitRequest := &spannerpb.CommitRequest{
556		Session: session.Name,
557		Transaction: &spannerpb.CommitRequest_TransactionId{
558			TransactionId: tx.Id,
559		},
560	}
561	resp, err := c.Commit(context.Background(), commitRequest)
562	if err != nil {
563		t.Fatal(err)
564	}
565	if resp.CommitTimestamp == nil {
566		t.Fatalf("No commit timestamp returned")
567	}
568}
569
570func TestRollbackTransaction(t *testing.T) {
571	setup()
572	c, err := apiv1.NewClient(context.Background(), clientOpt)
573	if err != nil {
574		t.Fatal(err)
575	}
576
577	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
578	var createRequest = &spannerpb.CreateSessionRequest{
579		Database: formattedDatabase,
580	}
581	session, err := c.CreateSession(context.Background(), createRequest)
582	if err != nil {
583		t.Fatal(err)
584	}
585	beginRequest := &spannerpb.BeginTransactionRequest{
586		Session: session.Name,
587		Options: &spannerpb.TransactionOptions{
588			Mode: &spannerpb.TransactionOptions_ReadWrite_{
589				ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
590			},
591		},
592	}
593	tx, err := c.BeginTransaction(context.Background(), beginRequest)
594	if err != nil {
595		t.Fatal(err)
596	}
597	rollbackRequest := &spannerpb.RollbackRequest{
598		Session:       session.Name,
599		TransactionId: tx.Id,
600	}
601	err = c.Rollback(context.Background(), rollbackRequest)
602	if err != nil {
603		t.Fatal(err)
604	}
605}
606