1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"io"
23	"os"
24	"strings"
25	"testing"
26	"time"
27
28	"cloud.google.com/go/civil"
29	itestutil "cloud.google.com/go/internal/testutil"
30	vkit "cloud.google.com/go/spanner/apiv1"
31	. "cloud.google.com/go/spanner/internal/testutil"
32	structpb "github.com/golang/protobuf/ptypes/struct"
33	gax "github.com/googleapis/gax-go/v2"
34	"google.golang.org/api/iterator"
35	"google.golang.org/api/option"
36	sppb "google.golang.org/genproto/googleapis/spanner/v1"
37	"google.golang.org/grpc/codes"
38	"google.golang.org/grpc/status"
39)
40
41func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
42	return setupMockedTestServerWithConfig(t, ClientConfig{})
43}
44
45func setupMockedTestServerWithConfig(t *testing.T, config ClientConfig) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
46	return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{})
47}
48
49func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config ClientConfig, clientOptions []option.ClientOption) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
50	grpcHeaderChecker := &itestutil.HeadersEnforcer{
51		OnFailure: t.Fatalf,
52		Checkers: []*itestutil.HeaderChecker{
53			{
54				Key: "x-goog-api-client",
55				ValuesValidator: func(token ...string) error {
56					if len(token) != 1 {
57						return status.Errorf(codes.Internal, "unexpected number of api client token headers: %v", len(token))
58					}
59					if !strings.HasPrefix(token[0], "gl-go/") {
60						return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0])
61					}
62					if !strings.Contains(token[0], "gccl/") {
63						return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0])
64					}
65					return nil
66				},
67			},
68		},
69	}
70	clientOptions = append(clientOptions, grpcHeaderChecker.CallOptions()...)
71	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
72	opts = append(opts, clientOptions...)
73	ctx := context.Background()
74	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
75	client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...)
76	if err != nil {
77		t.Fatal(err)
78	}
79	return server, client, func() {
80		client.Close()
81		serverTeardown()
82	}
83}
84
85// Test validDatabaseName()
86func TestValidDatabaseName(t *testing.T) {
87	validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb"
88	invalidDbUris := []string{
89		// Completely wrong DB URI.
90		"foobarDB",
91		// Project ID contains "/".
92		"projects/spanner-cloud/test/instances/foo/databases/foodb",
93		// No instance ID.
94		"projects/spanner-cloud-test/instances//databases/foodb",
95	}
96	if err := validDatabaseName(validDbURI); err != nil {
97		t.Errorf("validateDatabaseName(%q) = %v, want nil", validDbURI, err)
98	}
99	for _, d := range invalidDbUris {
100		if err, wantErr := validDatabaseName(d), "should conform to pattern"; !strings.Contains(err.Error(), wantErr) {
101			t.Errorf("validateDatabaseName(%q) = %q, want error pattern %q", validDbURI, err, wantErr)
102		}
103	}
104}
105
106func TestReadOnlyTransactionClose(t *testing.T) {
107	// Closing a ReadOnlyTransaction shouldn't panic.
108	c := &Client{}
109	tx := c.ReadOnlyTransaction()
110	tx.Close()
111}
112
113func TestClient_Single(t *testing.T) {
114	t.Parallel()
115	err := testSingleQuery(t, nil)
116	if err != nil {
117		t.Fatal(err)
118	}
119}
120
121func TestClient_Single_Unavailable(t *testing.T) {
122	t.Parallel()
123	err := testSingleQuery(t, status.Error(codes.Unavailable, "Temporary unavailable"))
124	if err != nil {
125		t.Fatal(err)
126	}
127}
128
129func TestClient_Single_InvalidArgument(t *testing.T) {
130	t.Parallel()
131	err := testSingleQuery(t, status.Error(codes.InvalidArgument, "Invalid argument"))
132	if status.Code(err) != codes.InvalidArgument {
133		t.Fatalf("got: %v, want: %v", err, codes.InvalidArgument)
134	}
135}
136
137func TestClient_Single_SessionNotFound(t *testing.T) {
138	t.Parallel()
139
140	server, client, teardown := setupMockedTestServer(t)
141	defer teardown()
142	server.TestSpanner.PutExecutionTime(
143		MethodExecuteStreamingSql,
144		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
145	)
146	ctx := context.Background()
147	iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
148	defer iter.Stop()
149	rowCount := int64(0)
150	for {
151		_, err := iter.Next()
152		if err == iterator.Done {
153			break
154		}
155		if err != nil {
156			t.Fatal(err)
157		}
158		rowCount++
159	}
160	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
161		t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
162	}
163}
164
165func TestClient_Single_Read_SessionNotFound(t *testing.T) {
166	t.Parallel()
167
168	server, client, teardown := setupMockedTestServer(t)
169	defer teardown()
170	server.TestSpanner.PutExecutionTime(
171		MethodStreamingRead,
172		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
173	)
174	ctx := context.Background()
175	iter := client.Single().Read(ctx, "Albums", KeySets(Key{"foo"}), []string{"SingerId", "AlbumId", "AlbumTitle"})
176	defer iter.Stop()
177	rowCount := int64(0)
178	for {
179		_, err := iter.Next()
180		if err == iterator.Done {
181			break
182		}
183		if err != nil {
184			t.Fatal(err)
185		}
186		rowCount++
187	}
188	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
189		t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
190	}
191}
192
193func TestClient_Single_ReadRow_SessionNotFound(t *testing.T) {
194	t.Parallel()
195
196	server, client, teardown := setupMockedTestServer(t)
197	defer teardown()
198	server.TestSpanner.PutExecutionTime(
199		MethodStreamingRead,
200		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
201	)
202	ctx := context.Background()
203	row, err := client.Single().ReadRow(ctx, "Albums", Key{"foo"}, []string{"SingerId", "AlbumId", "AlbumTitle"})
204	if err != nil {
205		t.Fatalf("Unexpected error for read row: %v", err)
206	}
207	if row == nil {
208		t.Fatal("ReadRow did not return a row")
209	}
210}
211
212func TestClient_Single_RetryableErrorOnPartialResultSet(t *testing.T) {
213	t.Parallel()
214	server, client, teardown := setupMockedTestServer(t)
215	defer teardown()
216
217	// Add two errors that will be returned by the mock server when the client
218	// is trying to fetch a partial result set. Both errors are retryable.
219	// The errors are not 'sticky' on the mocked server, i.e. once the error
220	// has been returned once, the next call for the same partial result set
221	// will succeed.
222
223	// When the client is fetching the partial result set with resume token 2,
224	// the mock server will respond with an internal error with the message
225	// 'stream terminated by RST_STREAM'. The client will retry the call to get
226	// this partial result set.
227	server.TestSpanner.AddPartialResultSetError(
228		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
229		PartialResultSetExecutionTime{
230			ResumeToken: EncodeResumeToken(2),
231			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
232		},
233	)
234	// When the client is fetching the partial result set with resume token 3,
235	// the mock server will respond with a 'Unavailable' error. The client will
236	// retry the call to get this partial result set.
237	server.TestSpanner.AddPartialResultSetError(
238		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
239		PartialResultSetExecutionTime{
240			ResumeToken: EncodeResumeToken(3),
241			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
242		},
243	)
244	ctx := context.Background()
245	if err := executeSingerQuery(ctx, client.Single()); err != nil {
246		t.Fatal(err)
247	}
248}
249
250func TestClient_Single_NonRetryableErrorOnPartialResultSet(t *testing.T) {
251	t.Parallel()
252	server, client, teardown := setupMockedTestServer(t)
253	defer teardown()
254
255	// Add two errors that will be returned by the mock server when the client
256	// is trying to fetch a partial result set. The first error is retryable,
257	// the second is not.
258
259	// This error will automatically be retried.
260	server.TestSpanner.AddPartialResultSetError(
261		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
262		PartialResultSetExecutionTime{
263			ResumeToken: EncodeResumeToken(2),
264			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
265		},
266	)
267	// 'Session not found' is not retryable and the error will be returned to
268	// the user.
269	server.TestSpanner.AddPartialResultSetError(
270		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
271		PartialResultSetExecutionTime{
272			ResumeToken: EncodeResumeToken(3),
273			Err:         newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s"),
274		},
275	)
276	ctx := context.Background()
277	err := executeSingerQuery(ctx, client.Single())
278	if status.Code(err) != codes.NotFound {
279		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.NotFound)
280	}
281}
282
283func TestClient_Single_NonRetryableInternalErrors(t *testing.T) {
284	t.Parallel()
285	server, client, teardown := setupMockedTestServer(t)
286	defer teardown()
287
288	server.TestSpanner.AddPartialResultSetError(
289		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
290		PartialResultSetExecutionTime{
291			ResumeToken: EncodeResumeToken(2),
292			Err:         status.Errorf(codes.Internal, "grpc: error while marshaling: string field contains invalid UTF-8"),
293		},
294	)
295	ctx := context.Background()
296	err := executeSingerQuery(ctx, client.Single())
297	if status.Code(err) != codes.Internal {
298		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.Internal)
299	}
300}
301
302func TestClient_Single_DeadlineExceeded_NoErrors(t *testing.T) {
303	t.Parallel()
304	server, client, teardown := setupMockedTestServer(t)
305	defer teardown()
306	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
307		SimulatedExecutionTime{
308			MinimumExecutionTime: 50 * time.Millisecond,
309		})
310	ctx := context.Background()
311	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Millisecond))
312	defer cancel()
313	err := executeSingerQuery(ctx, client.Single())
314	if status.Code(err) != codes.DeadlineExceeded {
315		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.DeadlineExceeded)
316	}
317}
318
319func TestClient_Single_DeadlineExceeded_WithErrors(t *testing.T) {
320	t.Parallel()
321	server, client, teardown := setupMockedTestServer(t)
322	defer teardown()
323	server.TestSpanner.AddPartialResultSetError(
324		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
325		PartialResultSetExecutionTime{
326			ResumeToken: EncodeResumeToken(2),
327			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
328		},
329	)
330	server.TestSpanner.AddPartialResultSetError(
331		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
332		PartialResultSetExecutionTime{
333			ResumeToken:   EncodeResumeToken(3),
334			Err:           status.Errorf(codes.Unavailable, "server is unavailable"),
335			ExecutionTime: 50 * time.Millisecond,
336		},
337	)
338	ctx := context.Background()
339	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(25*time.Millisecond))
340	defer cancel()
341	err := executeSingerQuery(ctx, client.Single())
342	if status.Code(err) != codes.DeadlineExceeded {
343		t.Fatalf("got unexpected error %v, expected DeadlineExceeded", err)
344	}
345}
346
347func TestClient_Single_ContextCanceled_noDeclaredServerErrors(t *testing.T) {
348	t.Parallel()
349	_, client, teardown := setupMockedTestServer(t)
350	defer teardown()
351	ctx := context.Background()
352	ctx, cancel := context.WithCancel(ctx)
353	cancel()
354	err := executeSingerQuery(ctx, client.Single())
355	if status.Code(err) != codes.Canceled {
356		t.Fatalf("got unexpected error %v, expected Canceled", err)
357	}
358}
359
360func TestClient_Single_ContextCanceled_withDeclaredServerErrors(t *testing.T) {
361	t.Parallel()
362	server, client, teardown := setupMockedTestServer(t)
363	defer teardown()
364	server.TestSpanner.AddPartialResultSetError(
365		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
366		PartialResultSetExecutionTime{
367			ResumeToken: EncodeResumeToken(2),
368			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
369		},
370	)
371	server.TestSpanner.AddPartialResultSetError(
372		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
373		PartialResultSetExecutionTime{
374			ResumeToken: EncodeResumeToken(3),
375			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
376		},
377	)
378	ctx := context.Background()
379	ctx, cancel := context.WithCancel(ctx)
380	defer cancel()
381	f := func(rowCount int64) error {
382		if rowCount == 2 {
383			cancel()
384		}
385		return nil
386	}
387	iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
388	defer iter.Stop()
389	err := executeSingerQueryWithRowFunc(ctx, client.Single(), f)
390	if status.Code(err) != codes.Canceled {
391		t.Fatalf("got unexpected error %v, expected Canceled", err)
392	}
393}
394
395func TestClient_Single_QueryOptions(t *testing.T) {
396	for _, tt := range queryOptionsTestCases() {
397		t.Run(tt.name, func(t *testing.T) {
398			if tt.env.Options != nil {
399				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
400				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
401			}
402
403			ctx := context.Background()
404			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
405			defer teardown()
406
407			var iter *RowIterator
408			if tt.query.Options == nil {
409				iter = client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
410			} else {
411				iter = client.Single().QueryWithOptions(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), tt.query)
412			}
413			testQueryOptions(t, iter, server.TestSpanner, tt.want)
414		})
415	}
416}
417
418func testQueryOptions(t *testing.T, iter *RowIterator, server InMemSpannerServer, qo QueryOptions) {
419	defer iter.Stop()
420
421	_, err := iter.Next()
422	if err != nil {
423		t.Fatalf("Failed to read from the iterator: %v", err)
424	}
425
426	checkReqsForQueryOptions(t, server, qo)
427}
428
429func checkReqsForQueryOptions(t *testing.T, server InMemSpannerServer, qo QueryOptions) {
430	reqs := drainRequestsFromServer(server)
431	sqlReqs := []*sppb.ExecuteSqlRequest{}
432
433	for _, req := range reqs {
434		if sqlReq, ok := req.(*sppb.ExecuteSqlRequest); ok {
435			sqlReqs = append(sqlReqs, sqlReq)
436		}
437	}
438
439	if got, want := len(sqlReqs), 1; got != want {
440		t.Fatalf("Length mismatch, got %v, want %v", got, want)
441	}
442
443	reqQueryOptions := sqlReqs[0].QueryOptions
444	if got, want := reqQueryOptions.OptimizerVersion, qo.Options.OptimizerVersion; got != want {
445		t.Fatalf("Optimizer version mismatch, got %v, want %v", got, want)
446	}
447}
448
449func testSingleQuery(t *testing.T, serverError error) error {
450	ctx := context.Background()
451	server, client, teardown := setupMockedTestServer(t)
452	defer teardown()
453	if serverError != nil {
454		server.TestSpanner.SetError(serverError)
455	}
456	return executeSingerQuery(ctx, client.Single())
457}
458
459func executeSingerQuery(ctx context.Context, tx *ReadOnlyTransaction) error {
460	return executeSingerQueryWithRowFunc(ctx, tx, nil)
461}
462
463func executeSingerQueryWithRowFunc(ctx context.Context, tx *ReadOnlyTransaction, f func(rowCount int64) error) error {
464	iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
465	defer iter.Stop()
466	rowCount := int64(0)
467	for {
468		row, err := iter.Next()
469		if err == iterator.Done {
470			break
471		}
472		if err != nil {
473			return err
474		}
475		var singerID, albumID int64
476		var albumTitle string
477		if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
478			return err
479		}
480		rowCount++
481		if f != nil {
482			if err := f(rowCount); err != nil {
483				return err
484			}
485		}
486	}
487	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
488		return status.Errorf(codes.Internal, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
489	}
490	return nil
491}
492
493func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]SimulatedExecutionTime {
494	errors := make([]error, 2)
495	errors[0] = status.Error(codes.Unavailable, "Temporary unavailable")
496	errors[1] = status.Error(codes.Unavailable, "Temporary unavailable")
497	executionTimes := make(map[string]SimulatedExecutionTime)
498	executionTimes[method] = SimulatedExecutionTime{
499		Errors: errors,
500	}
501	return executionTimes
502}
503
504func TestClient_ReadOnlyTransaction(t *testing.T) {
505	t.Parallel()
506	if err := testReadOnlyTransaction(t, make(map[string]SimulatedExecutionTime)); err != nil {
507		t.Fatal(err)
508	}
509}
510
511func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) {
512	t.Parallel()
513	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodCreateSession)); err != nil {
514		t.Fatal(err)
515	}
516}
517
518func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) {
519	t.Parallel()
520	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodBeginTransaction)); err != nil {
521		t.Fatal(err)
522	}
523}
524
525func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) {
526	t.Parallel()
527	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodExecuteStreamingSql)); err != nil {
528		t.Fatal(err)
529	}
530}
531
532func TestClient_ReadOnlyTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
533	t.Parallel()
534	// Session not found is not retryable for a query on a multi-use read-only
535	// transaction, as we would need to start a new transaction on a new
536	// session.
537	err := testReadOnlyTransaction(t, map[string]SimulatedExecutionTime{
538		MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
539	})
540	want := toSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s"))
541	if err == nil {
542		t.Fatalf("missing expected error\nGot: nil\nWant: %v", want)
543	}
544	if status.Code(err) != status.Code(want) || !strings.Contains(err.Error(), want.Error()) {
545		t.Fatalf("error mismatch\nGot: %v\nWant: %v", err, want)
546	}
547}
548
549func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) {
550	t.Parallel()
551	exec := map[string]SimulatedExecutionTime{
552		MethodCreateSession:    {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
553		MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
554	}
555	if err := testReadOnlyTransaction(t, exec); err != nil {
556		t.Fatal(err)
557	}
558}
559
560func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) {
561	t.Parallel()
562	exec := map[string]SimulatedExecutionTime{
563		MethodCreateSession:    {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
564		MethodBeginTransaction: {Errors: []error{status.Error(codes.InvalidArgument, "Invalid argument")}},
565	}
566	if err := testReadOnlyTransaction(t, exec); err == nil {
567		t.Fatalf("Missing expected exception")
568	} else if status.Code(err) != codes.InvalidArgument {
569		t.Fatalf("Got unexpected exception: %v", err)
570	}
571}
572
573func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
574	t.Parallel()
575	if err := testReadOnlyTransaction(
576		t,
577		map[string]SimulatedExecutionTime{
578			MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
579		},
580	); err != nil {
581		t.Fatal(err)
582	}
583}
584
585func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOneSession(t *testing.T) {
586	t.Parallel()
587	server, client, teardown := setupMockedTestServerWithConfig(
588		t,
589		ClientConfig{
590			SessionPoolConfig: SessionPoolConfig{
591				MinOpened: 0,
592				MaxOpened: 1,
593			},
594		})
595	defer teardown()
596	server.TestSpanner.PutExecutionTime(
597		MethodBeginTransaction,
598		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
599	)
600	tx := client.ReadOnlyTransaction()
601	defer tx.Close()
602	ctx := context.Background()
603	if err := executeSingerQuery(ctx, tx); err != nil {
604		t.Fatal(err)
605	}
606}
607
608func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) {
609	for _, tt := range queryOptionsTestCases() {
610		t.Run(tt.name, func(t *testing.T) {
611			if tt.env.Options != nil {
612				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
613				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
614			}
615
616			ctx := context.Background()
617			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
618			defer teardown()
619
620			tx := client.ReadOnlyTransaction()
621			defer tx.Close()
622
623			var iter *RowIterator
624			if tt.query.Options == nil {
625				iter = tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
626			} else {
627				iter = tx.QueryWithOptions(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), tt.query)
628			}
629			testQueryOptions(t, iter, server.TestSpanner, tt.want)
630		})
631	}
632}
633
634func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) error {
635	server, client, teardown := setupMockedTestServer(t)
636	defer teardown()
637	for method, exec := range executionTimes {
638		server.TestSpanner.PutExecutionTime(method, exec)
639	}
640	tx := client.ReadOnlyTransaction()
641	defer tx.Close()
642	ctx := context.Background()
643	return executeSingerQuery(ctx, tx)
644}
645
646func TestClient_ReadWriteTransaction(t *testing.T) {
647	t.Parallel()
648	if err := testReadWriteTransaction(t, make(map[string]SimulatedExecutionTime), 1); err != nil {
649		t.Fatal(err)
650	}
651}
652
653func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) {
654	t.Parallel()
655	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
656		MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
657	}, 2); err != nil {
658		t.Fatal(err)
659	}
660}
661
662func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) {
663	t.Parallel()
664	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
665		MethodCommitTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
666	}, 2); err != nil {
667		t.Fatal(err)
668	}
669}
670
671func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
672	t.Parallel()
673	// We expect only 1 attempt, as the 'Session not found' error is already
674	//handled in the session pool where the session is prepared.
675	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
676		MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
677	}, 1); err != nil {
678		t.Fatal(err)
679	}
680}
681
682func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) {
683	t.Parallel()
684	// There will be no prepared sessions in the pool, so the error will occur
685	// when the transaction tries to get a session from the pool. This will
686	// also be handled by the session pool, so the transaction itself does not
687	// need to retry, hence the expectedAttempts == 1.
688	if err := testReadWriteTransactionWithConfig(t, ClientConfig{
689		SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0},
690	}, map[string]SimulatedExecutionTime{
691		MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
692	}, 1); err != nil {
693		t.Fatal(err)
694	}
695}
696
697func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
698	t.Parallel()
699	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
700		MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
701	}, 2); err != nil {
702		t.Fatal(err)
703	}
704}
705
706func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) {
707	t.Parallel()
708
709	server, client, teardown := setupMockedTestServer(t)
710	defer teardown()
711	server.TestSpanner.PutExecutionTime(
712		MethodExecuteSql,
713		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
714	)
715	ctx := context.Background()
716	var attempts int
717	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
718		attempts++
719		rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
720		if err != nil {
721			return err
722		}
723		if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w {
724			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
725		}
726		return nil
727	})
728	if err != nil {
729		t.Fatal(err)
730	}
731	if g, w := attempts, 2; g != w {
732		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
733	}
734}
735
736func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) {
737	t.Parallel()
738
739	server, client, teardown := setupMockedTestServer(t)
740	defer teardown()
741	server.TestSpanner.PutExecutionTime(
742		MethodExecuteBatchDml,
743		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
744	)
745	ctx := context.Background()
746	var attempts int
747	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
748		attempts++
749		rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)})
750		if err != nil {
751			return err
752		}
753		if g, w := len(rowCounts), 1; g != w {
754			return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w)
755		}
756		if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w {
757			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
758		}
759		return nil
760	})
761	if err != nil {
762		t.Fatal(err)
763	}
764	if g, w := attempts, 2; g != w {
765		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
766	}
767}
768
769func TestClient_ReadWriteTransaction_Query_QueryOptions(t *testing.T) {
770	for _, tt := range queryOptionsTestCases() {
771		t.Run(tt.name, func(t *testing.T) {
772			if tt.env.Options != nil {
773				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
774				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
775			}
776
777			ctx := context.Background()
778			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
779			defer teardown()
780
781			_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
782				var iter *RowIterator
783				if tt.query.Options == nil {
784					iter = tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
785				} else {
786					iter = tx.QueryWithOptions(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), tt.query)
787				}
788				testQueryOptions(t, iter, server.TestSpanner, tt.want)
789				return nil
790			})
791			if err != nil {
792				t.Fatal(err)
793			}
794		})
795	}
796}
797
798func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) {
799	for _, tt := range queryOptionsTestCases() {
800		t.Run(tt.name, func(t *testing.T) {
801			if tt.env.Options != nil {
802				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
803				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
804			}
805
806			ctx := context.Background()
807			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
808			defer teardown()
809
810			_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
811				var rowCount int64
812				var err error
813				if tt.query.Options == nil {
814					rowCount, err = tx.Update(ctx, NewStatement(UpdateBarSetFoo))
815				} else {
816					rowCount, err = tx.UpdateWithOptions(ctx, NewStatement(UpdateBarSetFoo), tt.query)
817				}
818				if got, want := rowCount, int64(5); got != want {
819					t.Fatalf("Incorrect updated row count: got %v, want %v", got, want)
820				}
821				return err
822			})
823			if err != nil {
824				t.Fatalf("Failed to update rows: %v", err)
825			}
826			checkReqsForQueryOptions(t, server.TestSpanner, tt.want)
827		})
828	}
829}
830
831func TestClient_SessionNotFound(t *testing.T) {
832	// Ensure we always have at least one session in the pool.
833	sc := SessionPoolConfig{
834		MinOpened: 1,
835	}
836	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc})
837	defer teardown()
838	ctx := context.Background()
839	for {
840		client.idleSessions.mu.Lock()
841		numSessions := client.idleSessions.idleList.Len()
842		client.idleSessions.mu.Unlock()
843		if numSessions > 0 {
844			break
845		}
846		time.After(time.Millisecond)
847	}
848	// Remove the session from the server without the pool knowing it.
849	_, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id})
850	if err != nil {
851		t.Fatalf("Failed to delete session unexpectedly: %v", err)
852	}
853
854	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
855		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
856		defer iter.Stop()
857		rowCount := int64(0)
858		for {
859			row, err := iter.Next()
860			if err == iterator.Done {
861				break
862			}
863			if err != nil {
864				return err
865			}
866			var singerID, albumID int64
867			var albumTitle string
868			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
869				return err
870			}
871			rowCount++
872		}
873		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
874			return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
875		}
876		return nil
877	})
878	if err != nil {
879		t.Fatalf("Unexpected error during transaction: %v", err)
880	}
881}
882
883func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) {
884	t.Parallel()
885	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
886		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
887	}, 2); err != nil {
888		t.Fatal(err)
889	}
890}
891
892func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) {
893	t.Parallel()
894	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
895		MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
896	}, 1); err != nil {
897		t.Fatal(err)
898	}
899}
900
901func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) {
902	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
903		MethodBeginTransaction:  {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
904		MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Aborted")}},
905	}, 2); err != nil {
906		t.Fatal(err)
907	}
908}
909
910func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) {
911	t.Parallel()
912	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
913		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
914	}, 1); err != nil {
915		t.Fatal(err)
916	}
917}
918
919func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) {
920	t.Parallel()
921	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
922		MethodBeginTransaction:    {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
923		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
924		MethodCommitTransaction:   {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}},
925	}, 3); err != nil {
926		t.Fatal(err)
927	}
928}
929
930func TestClient_ReadWriteTransaction_CommitAborted(t *testing.T) {
931	t.Parallel()
932	server, client, teardown := setupMockedTestServer(t)
933	server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{
934		Errors: []error{status.Error(codes.Aborted, "Aborted")},
935	})
936	defer teardown()
937	ctx := context.Background()
938	attempts := 0
939	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
940		attempts++
941		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
942		if err != nil {
943			return err
944		}
945		return nil
946	})
947	if err != nil {
948		t.Fatal(err)
949	}
950	if g, w := attempts, 2; g != w {
951		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
952	}
953}
954
955func TestClient_ReadWriteTransaction_DMLAborted(t *testing.T) {
956	t.Parallel()
957	server, client, teardown := setupMockedTestServer(t)
958	server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{
959		Errors: []error{status.Error(codes.Aborted, "Aborted")},
960	})
961	defer teardown()
962	ctx := context.Background()
963	attempts := 0
964	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
965		attempts++
966		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
967		if err != nil {
968			return err
969		}
970		return nil
971	})
972	if err != nil {
973		t.Fatal(err)
974	}
975	if g, w := attempts, 2; g != w {
976		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
977	}
978}
979
980func TestClient_ReadWriteTransaction_BatchDMLAborted(t *testing.T) {
981	t.Parallel()
982	server, client, teardown := setupMockedTestServer(t)
983	server.TestSpanner.PutExecutionTime(MethodExecuteBatchDml, SimulatedExecutionTime{
984		Errors: []error{status.Error(codes.Aborted, "Aborted")},
985	})
986	defer teardown()
987	ctx := context.Background()
988	attempts := 0
989	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
990		attempts++
991		_, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
992		if err != nil {
993			return err
994		}
995		return nil
996	})
997	if err != nil {
998		t.Fatal(err)
999	}
1000	if g, w := attempts, 2; g != w {
1001		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1002	}
1003}
1004
1005func TestClient_ReadWriteTransaction_BatchDMLAbortedHalfway(t *testing.T) {
1006	t.Parallel()
1007	server, client, teardown := setupMockedTestServer(t)
1008	defer teardown()
1009	abortedStatement := "UPDATE FOO_ABORTED SET BAR=1 WHERE ID=2"
1010	server.TestSpanner.PutStatementResult(
1011		abortedStatement,
1012		&StatementResult{
1013			Type: StatementResultError,
1014			Err:  status.Error(codes.Aborted, "Statement was aborted"),
1015		},
1016	)
1017	ctx := context.Background()
1018	var updateCounts []int64
1019	attempts := 0
1020	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1021		attempts++
1022		if attempts > 1 {
1023			// Replace the aborted result with a real result to prevent the
1024			// transaction from aborting indefinitely.
1025			server.TestSpanner.PutStatementResult(
1026				abortedStatement,
1027				&StatementResult{
1028					Type:        StatementResultUpdateCount,
1029					UpdateCount: 3,
1030				},
1031			)
1032		}
1033		var err error
1034		updateCounts, err = tx.BatchUpdate(ctx, []Statement{{SQL: abortedStatement}, {SQL: UpdateBarSetFoo}})
1035		if err != nil {
1036			return err
1037		}
1038		return nil
1039	})
1040	if err != nil {
1041		t.Fatal(err)
1042	}
1043	if g, w := attempts, 2; g != w {
1044		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1045	}
1046	if g, w := updateCounts, []int64{3, UpdateBarSetFooRowCount}; !testEqual(w, g) {
1047		t.Fatalf("update count mismatch\nWant: %v\nGot: %v", w, g)
1048	}
1049}
1050
1051func TestClient_ReadWriteTransaction_QueryAborted(t *testing.T) {
1052	t.Parallel()
1053	server, client, teardown := setupMockedTestServer(t)
1054	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{
1055		Errors: []error{status.Error(codes.Aborted, "Aborted")},
1056	})
1057	defer teardown()
1058	ctx := context.Background()
1059	attempts := 0
1060	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1061		attempts++
1062		iter := tx.Query(ctx, Statement{SQL: SelectFooFromBar})
1063		defer iter.Stop()
1064		for {
1065			_, err := iter.Next()
1066			if err == iterator.Done {
1067				break
1068			}
1069			if err != nil {
1070				return err
1071			}
1072		}
1073		return nil
1074	})
1075	if err != nil {
1076		t.Fatal(err)
1077	}
1078	if g, w := attempts, 2; g != w {
1079		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1080	}
1081}
1082
1083func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) {
1084	t.Parallel()
1085	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1086		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Aborted")}},
1087		MethodCommitTransaction:   {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}},
1088	}, 4); err != nil {
1089		t.Fatal(err)
1090	}
1091}
1092
1093func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) {
1094	t.Parallel()
1095	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1096		MethodCommitTransaction: {
1097			Errors: []error{
1098				status.Error(codes.Aborted, "Transaction aborted"),
1099				status.Error(codes.Unavailable, "Unavailable"),
1100			},
1101		},
1102	}, 2); err != nil {
1103		t.Fatal(err)
1104	}
1105}
1106
1107func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) {
1108	t.Parallel()
1109	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1110		MethodCommitTransaction: {Errors: []error{status.Error(codes.AlreadyExists, "A row with this key already exists")}},
1111	}, 1); err != nil {
1112		if status.Code(err) != codes.AlreadyExists {
1113			t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists)
1114		}
1115	} else {
1116		t.Fatalf("Missing expected exception")
1117	}
1118}
1119
1120func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
1121	return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts)
1122}
1123
1124func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
1125	server, client, teardown := setupMockedTestServer(t)
1126	defer teardown()
1127	for method, exec := range executionTimes {
1128		server.TestSpanner.PutExecutionTime(method, exec)
1129	}
1130	ctx := context.Background()
1131	var attempts int
1132	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1133		attempts++
1134		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1135		defer iter.Stop()
1136		rowCount := int64(0)
1137		for {
1138			row, err := iter.Next()
1139			if err == iterator.Done {
1140				break
1141			}
1142			if err != nil {
1143				return err
1144			}
1145			var singerID, albumID int64
1146			var albumTitle string
1147			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
1148				return err
1149			}
1150			rowCount++
1151		}
1152		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
1153			return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
1154		}
1155		return nil
1156	})
1157	if err != nil {
1158		return err
1159	}
1160	if expectedAttempts != attempts {
1161		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
1162	}
1163	return nil
1164}
1165
1166func TestClient_ApplyAtLeastOnce(t *testing.T) {
1167	t.Parallel()
1168	server, client, teardown := setupMockedTestServer(t)
1169	defer teardown()
1170	ms := []*Mutation{
1171		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1172		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1173	}
1174	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1175		SimulatedExecutionTime{
1176			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1177		})
1178	_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1179	if err != nil {
1180		t.Fatal(err)
1181	}
1182}
1183
1184func TestClient_ApplyAtLeastOnceReuseSession(t *testing.T) {
1185	t.Parallel()
1186	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1187		SessionPoolConfig: SessionPoolConfig{
1188			MinOpened:           0,
1189			WriteSessions:       0.0,
1190			TrackSessionHandles: true,
1191		},
1192	})
1193	defer teardown()
1194	sp := client.idleSessions
1195	ms := []*Mutation{
1196		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1197		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1198	}
1199	for i := 0; i < 10; i++ {
1200		_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1201		if err != nil {
1202			t.Fatal(err)
1203		}
1204		sp.mu.Lock()
1205		if g, w := uint64(sp.idleList.Len())+sp.createReqs, sp.incStep; g != w {
1206			t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w)
1207		}
1208		if g, w := uint64(len(server.TestSpanner.DumpSessions())), sp.incStep; g != w {
1209			t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w)
1210		}
1211		sp.mu.Unlock()
1212	}
1213	// There should be no sessions marked as checked out.
1214	sp.mu.Lock()
1215	g, w := sp.trackedSessionHandles.Len(), 0
1216	sp.mu.Unlock()
1217	if g != w {
1218		t.Fatalf("checked out sessions count mismatch:\nGot: %v\nWant: %v", g, w)
1219	}
1220}
1221
1222func TestClient_ApplyAtLeastOnceInvalidArgument(t *testing.T) {
1223	t.Parallel()
1224	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1225		SessionPoolConfig: SessionPoolConfig{
1226			MinOpened:           0,
1227			WriteSessions:       0.0,
1228			TrackSessionHandles: true,
1229		},
1230	})
1231	defer teardown()
1232	sp := client.idleSessions
1233	ms := []*Mutation{
1234		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1235		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1236	}
1237	for i := 0; i < 10; i++ {
1238		server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1239			SimulatedExecutionTime{
1240				Errors: []error{status.Error(codes.InvalidArgument, "Invalid data")},
1241			})
1242		_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1243		if status.Code(err) != codes.InvalidArgument {
1244			t.Fatal(err)
1245		}
1246		sp.mu.Lock()
1247		if g, w := uint64(sp.idleList.Len())+sp.createReqs, sp.incStep; g != w {
1248			t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w)
1249		}
1250		if g, w := uint64(len(server.TestSpanner.DumpSessions())), sp.incStep; g != w {
1251			t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w)
1252		}
1253		sp.mu.Unlock()
1254	}
1255	// There should be no sessions marked as checked out.
1256	client.idleSessions.mu.Lock()
1257	g, w := client.idleSessions.trackedSessionHandles.Len(), 0
1258	client.idleSessions.mu.Unlock()
1259	if g != w {
1260		t.Fatalf("checked out sessions count mismatch:\nGot: %v\nWant: %v", g, w)
1261	}
1262}
1263
1264func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) {
1265	t.Parallel()
1266	_, client, teardown := setupMockedTestServer(t)
1267	defer teardown()
1268	ctx := context.Background()
1269	var attempts int
1270	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1271		attempts++
1272		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1273		defer iter.Stop()
1274		for {
1275			row, err := iter.Next()
1276			if err == iterator.Done {
1277				break
1278			}
1279			if err != nil {
1280				return err
1281			}
1282			var singerID, albumID int64
1283			var albumTitle string
1284			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
1285				return err
1286			}
1287		}
1288		return io.ErrUnexpectedEOF
1289	})
1290	if err != io.ErrUnexpectedEOF {
1291		t.Fatalf("Missing expected error %v, got %v", io.ErrUnexpectedEOF, err)
1292	}
1293	if attempts != 1 {
1294		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, 1)
1295	}
1296}
1297
1298func TestReadWriteTransaction_WrapError(t *testing.T) {
1299	t.Parallel()
1300	server, client, teardown := setupMockedTestServer(t)
1301	defer teardown()
1302	// Abort the transaction on both the query as well as commit.
1303	// The first abort error will be wrapped. The client will unwrap the cause
1304	// of the error and retry the transaction. The aborted error on commit
1305	// will not be wrapped, but will also be recognized by the client as an
1306	// abort that should be retried.
1307	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1308		SimulatedExecutionTime{
1309			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1310		})
1311	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1312		SimulatedExecutionTime{
1313			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1314		})
1315	msg := "query failed"
1316	numAttempts := 0
1317	ctx := context.Background()
1318	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1319		numAttempts++
1320		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1321		defer iter.Stop()
1322		for {
1323			_, err := iter.Next()
1324			if err == iterator.Done {
1325				break
1326			}
1327			if err != nil {
1328				// Wrap the error in another error that implements the
1329				// (xerrors|errors).Wrapper interface.
1330				return &wrappedTestError{err, msg}
1331			}
1332		}
1333		return nil
1334	})
1335	if err != nil {
1336		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
1337	}
1338	if g, w := numAttempts, 3; g != w {
1339		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", w, w)
1340	}
1341
1342	// Execute a transaction that returns a non-retryable error that is
1343	// wrapped in a custom error. The transaction should return the custom
1344	// error.
1345	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1346		SimulatedExecutionTime{
1347			Errors: []error{status.Error(codes.NotFound, "Table not found")},
1348		})
1349	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1350		numAttempts++
1351		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1352		defer iter.Stop()
1353		for {
1354			_, err := iter.Next()
1355			if err == iterator.Done {
1356				break
1357			}
1358			if err != nil {
1359				// Wrap the error in another error that implements the
1360				// (xerrors|errors).Wrapper interface.
1361				return &wrappedTestError{err, msg}
1362			}
1363		}
1364		return nil
1365	})
1366	if err == nil || err.Error() != msg {
1367		t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
1368	}
1369}
1370
1371func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) {
1372	t.Parallel()
1373	server, client, teardown := setupMockedTestServer(t)
1374	defer teardown()
1375	server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
1376		SimulatedExecutionTime{
1377			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1378		})
1379	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1380		SimulatedExecutionTime{
1381			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1382		})
1383	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1384		SimulatedExecutionTime{
1385			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1386		})
1387	msg := "query failed"
1388	numAttempts := 0
1389	ctx := context.Background()
1390	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1391		numAttempts++
1392		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1393		defer iter.Stop()
1394		for {
1395			_, err := iter.Next()
1396			if err == iterator.Done {
1397				break
1398			}
1399			if err != nil {
1400				// Wrap the error in another error that implements the
1401				// (xerrors|errors).Wrapper interface.
1402				return &wrappedTestError{err, msg}
1403			}
1404		}
1405		return nil
1406	})
1407	if err != nil {
1408		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
1409	}
1410	// We want 3 attempts. The 'Session not found' error on BeginTransaction
1411	// will not retry the entire transaction, which means that we will have two
1412	// failed attempts and then a successful attempt.
1413	if g, w := numAttempts, 3; g != w {
1414		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", g, w)
1415	}
1416}
1417
1418func TestClient_WriteStructWithPointers(t *testing.T) {
1419	t.Parallel()
1420	server, client, teardown := setupMockedTestServer(t)
1421	defer teardown()
1422	type T struct {
1423		ID    int64
1424		Col1  *string
1425		Col2  []*string
1426		Col3  *bool
1427		Col4  []*bool
1428		Col5  *int64
1429		Col6  []*int64
1430		Col7  *float64
1431		Col8  []*float64
1432		Col9  *time.Time
1433		Col10 []*time.Time
1434		Col11 *civil.Date
1435		Col12 []*civil.Date
1436	}
1437	t1 := T{
1438		ID:    1,
1439		Col2:  []*string{nil},
1440		Col4:  []*bool{nil},
1441		Col6:  []*int64{nil},
1442		Col8:  []*float64{nil},
1443		Col10: []*time.Time{nil},
1444		Col12: []*civil.Date{nil},
1445	}
1446	s := "foo"
1447	b := true
1448	i := int64(100)
1449	f := 3.14
1450	tm := time.Now()
1451	d := civil.DateOf(time.Now())
1452	t2 := T{
1453		ID:    2,
1454		Col1:  &s,
1455		Col2:  []*string{&s},
1456		Col3:  &b,
1457		Col4:  []*bool{&b},
1458		Col5:  &i,
1459		Col6:  []*int64{&i},
1460		Col7:  &f,
1461		Col8:  []*float64{&f},
1462		Col9:  &tm,
1463		Col10: []*time.Time{&tm},
1464		Col11: &d,
1465		Col12: []*civil.Date{&d},
1466	}
1467	m1, err := InsertStruct("Tab", &t1)
1468	if err != nil {
1469		t.Fatal(err)
1470	}
1471	m2, err := InsertStruct("Tab", &t2)
1472	if err != nil {
1473		t.Fatal(err)
1474	}
1475	_, err = client.Apply(context.Background(), []*Mutation{m1, m2})
1476	if err != nil {
1477		t.Fatal(err)
1478	}
1479	requests := drainRequestsFromServer(server.TestSpanner)
1480	for _, req := range requests {
1481		if commit, ok := req.(*sppb.CommitRequest); ok {
1482			if g, w := len(commit.Mutations), 2; w != g {
1483				t.Fatalf("mutation count mismatch\nGot: %v\nWant: %v", g, w)
1484			}
1485			insert := commit.Mutations[0].GetInsert()
1486			// The first insert should contain NULL values and arrays
1487			// containing exactly one NULL element.
1488			for i := 1; i < len(insert.Values[0].Values); i += 2 {
1489				// The non-array columns should contain NULL values.
1490				g, w := insert.Values[0].Values[i].GetKind(), &structpb.Value_NullValue{}
1491				if _, ok := g.(*structpb.Value_NullValue); !ok {
1492					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, w)
1493				}
1494				// The array columns should not be NULL.
1495				g, wList := insert.Values[0].Values[i+1].GetKind(), &structpb.Value_ListValue{}
1496				if _, ok := g.(*structpb.Value_ListValue); !ok {
1497					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1498				}
1499				// The array should contain 1 NULL value.
1500				if gLength, wLength := len(insert.Values[0].Values[i+1].GetListValue().Values), 1; gLength != wLength {
1501					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1502				}
1503				g, w = insert.Values[0].Values[i+1].GetListValue().Values[0].GetKind(), &structpb.Value_NullValue{}
1504				if _, ok := g.(*structpb.Value_NullValue); !ok {
1505					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, w)
1506				}
1507			}
1508
1509			// The second insert should contain all non-NULL values.
1510			insert = commit.Mutations[1].GetInsert()
1511			for i := 1; i < len(insert.Values[0].Values); i += 2 {
1512				// The non-array columns should contain non-NULL values.
1513				g := insert.Values[0].Values[i].GetKind()
1514				if _, ok := g.(*structpb.Value_NullValue); ok {
1515					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1516				}
1517				// The array columns should also be non-NULL.
1518				g, wList := insert.Values[0].Values[i+1].GetKind(), &structpb.Value_ListValue{}
1519				if _, ok := g.(*structpb.Value_ListValue); !ok {
1520					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1521				}
1522				// The array should contain exactly 1 non-NULL value.
1523				if gLength, wLength := len(insert.Values[0].Values[i+1].GetListValue().Values), 1; gLength != wLength {
1524					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1525				}
1526				g = insert.Values[0].Values[i+1].GetListValue().Values[0].GetKind()
1527				if _, ok := g.(*structpb.Value_NullValue); ok {
1528					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1529				}
1530			}
1531		}
1532	}
1533}
1534
1535func TestClient_WriteStructWithCustomTypes(t *testing.T) {
1536	t.Parallel()
1537	server, client, teardown := setupMockedTestServer(t)
1538	defer teardown()
1539	type CustomString string
1540	type CustomBool bool
1541	type CustomInt64 int64
1542	type CustomFloat64 float64
1543	type CustomTime time.Time
1544	type CustomDate civil.Date
1545	type T struct {
1546		ID    int64
1547		Col1  CustomString
1548		Col2  []CustomString
1549		Col3  CustomBool
1550		Col4  []CustomBool
1551		Col5  CustomInt64
1552		Col6  []CustomInt64
1553		Col7  CustomFloat64
1554		Col8  []CustomFloat64
1555		Col9  CustomTime
1556		Col10 []CustomTime
1557		Col11 CustomDate
1558		Col12 []CustomDate
1559	}
1560	t1 := T{
1561		ID:    1,
1562		Col2:  []CustomString{},
1563		Col4:  []CustomBool{},
1564		Col6:  []CustomInt64{},
1565		Col8:  []CustomFloat64{},
1566		Col10: []CustomTime{},
1567		Col12: []CustomDate{},
1568	}
1569	t2 := T{
1570		ID:    2,
1571		Col1:  "foo",
1572		Col2:  []CustomString{"foo"},
1573		Col3:  true,
1574		Col4:  []CustomBool{true},
1575		Col5:  100,
1576		Col6:  []CustomInt64{100},
1577		Col7:  3.14,
1578		Col8:  []CustomFloat64{3.14},
1579		Col9:  CustomTime(time.Now()),
1580		Col10: []CustomTime{CustomTime(time.Now())},
1581		Col11: CustomDate(civil.DateOf(time.Now())),
1582		Col12: []CustomDate{CustomDate(civil.DateOf(time.Now()))},
1583	}
1584	m1, err := InsertStruct("Tab", &t1)
1585	if err != nil {
1586		t.Fatal(err)
1587	}
1588	m2, err := InsertStruct("Tab", &t2)
1589	if err != nil {
1590		t.Fatal(err)
1591	}
1592	_, err = client.Apply(context.Background(), []*Mutation{m1, m2})
1593	if err != nil {
1594		t.Fatal(err)
1595	}
1596	requests := drainRequestsFromServer(server.TestSpanner)
1597	for _, req := range requests {
1598		if commit, ok := req.(*sppb.CommitRequest); ok {
1599			if g, w := len(commit.Mutations), 2; w != g {
1600				t.Fatalf("mutation count mismatch\nGot: %v\nWant: %v", g, w)
1601			}
1602			insert1 := commit.Mutations[0].GetInsert()
1603			row1 := insert1.Values[0]
1604			// The first insert should contain empty values and empty arrays
1605			for i := 1; i < len(row1.Values); i += 2 {
1606				// The non-array columns should contain empty values.
1607				g := row1.Values[i].GetKind()
1608				if _, ok := g.(*structpb.Value_NullValue); ok {
1609					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1610				}
1611				// The array columns should not be NULL.
1612				g, wList := row1.Values[i+1].GetKind(), &structpb.Value_ListValue{}
1613				if _, ok := g.(*structpb.Value_ListValue); !ok {
1614					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1615				}
1616			}
1617
1618			// The second insert should contain all non-NULL values.
1619			insert2 := commit.Mutations[1].GetInsert()
1620			row2 := insert2.Values[0]
1621			for i := 1; i < len(row2.Values); i += 2 {
1622				// The non-array columns should contain non-NULL values.
1623				g := row2.Values[i].GetKind()
1624				if _, ok := g.(*structpb.Value_NullValue); ok {
1625					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1626				}
1627				// The array columns should also be non-NULL.
1628				g, wList := row2.Values[i+1].GetKind(), &structpb.Value_ListValue{}
1629				if _, ok := g.(*structpb.Value_ListValue); !ok {
1630					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1631				}
1632				// The array should contain exactly 1 non-NULL value.
1633				if gLength, wLength := len(row2.Values[i+1].GetListValue().Values), 1; gLength != wLength {
1634					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1635				}
1636				g = row2.Values[i+1].GetListValue().Values[0].GetKind()
1637				if _, ok := g.(*structpb.Value_NullValue); ok {
1638					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1639				}
1640			}
1641		}
1642	}
1643}
1644
1645func TestReadWriteTransaction_ContextTimeoutDuringDuringCommit(t *testing.T) {
1646	t.Parallel()
1647	server, client, teardown := setupMockedTestServer(t)
1648	defer teardown()
1649	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1650		SimulatedExecutionTime{
1651			MinimumExecutionTime: time.Minute,
1652		})
1653	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
1654	defer cancel()
1655	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1656		tx.BufferWrite([]*Mutation{Insert("FOO", []string{"ID", "NAME"}, []interface{}{int64(1), "bar"})})
1657		return nil
1658	})
1659	errContext, _ := context.WithTimeout(context.Background(), -time.Second)
1660	w := toSpannerErrorWithCommitInfo(errContext.Err(), true).(*Error)
1661	var se *Error
1662	if !errorAs(err, &se) {
1663		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, w)
1664	}
1665	if se.GRPCStatus().Code() != w.GRPCStatus().Code() {
1666		t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus())
1667	}
1668	if se.Error() != w.Error() {
1669		t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error())
1670	}
1671	var outcome *TransactionOutcomeUnknownError
1672	if !errorAs(err, &outcome) {
1673		t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
1674	}
1675}
1676
1677func TestFailedCommit_NoRollback(t *testing.T) {
1678	t.Parallel()
1679	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1680		SessionPoolConfig: SessionPoolConfig{
1681			MinOpened:     0,
1682			MaxOpened:     1,
1683			WriteSessions: 0,
1684		},
1685	})
1686	defer teardown()
1687	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1688		SimulatedExecutionTime{
1689			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid mutations")},
1690		})
1691	_, err := client.Apply(context.Background(), []*Mutation{
1692		Insert("FOO", []string{"ID", "BAR"}, []interface{}{1, "value"}),
1693	})
1694	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
1695		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
1696	}
1697	// The failed commit should not trigger a rollback after the commit.
1698	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
1699		&sppb.BatchCreateSessionsRequest{},
1700		&sppb.BeginTransactionRequest{},
1701		&sppb.CommitRequest{},
1702	}); err != nil {
1703		t.Fatalf("Received RPCs mismatch: %v", err)
1704	}
1705}
1706
1707func TestFailedUpdate_ShouldRollback(t *testing.T) {
1708	t.Parallel()
1709	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1710		SessionPoolConfig: SessionPoolConfig{
1711			MinOpened:     0,
1712			MaxOpened:     1,
1713			WriteSessions: 0,
1714		},
1715	})
1716	defer teardown()
1717	server.TestSpanner.PutExecutionTime(MethodExecuteSql,
1718		SimulatedExecutionTime{
1719			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid update")},
1720		})
1721	_, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *ReadWriteTransaction) error {
1722		_, err := tx.Update(ctx, NewStatement("UPDATE FOO SET BAR='value' WHERE ID=1"))
1723		return err
1724	})
1725	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
1726		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
1727	}
1728	// The failed update should trigger a rollback.
1729	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
1730		&sppb.BatchCreateSessionsRequest{},
1731		&sppb.BeginTransactionRequest{},
1732		&sppb.ExecuteSqlRequest{},
1733		&sppb.RollbackRequest{},
1734	}); err != nil {
1735		t.Fatalf("Received RPCs mismatch: %v", err)
1736	}
1737}
1738
1739func TestClient_NumChannels(t *testing.T) {
1740	t.Parallel()
1741
1742	configuredNumChannels := 8
1743	_, client, teardown := setupMockedTestServerWithConfig(
1744		t,
1745		ClientConfig{NumChannels: configuredNumChannels},
1746	)
1747	defer teardown()
1748	if g, w := client.sc.connPool.Num(), configuredNumChannels; g != w {
1749		t.Fatalf("NumChannels mismatch\nGot: %v\nWant: %v", g, w)
1750	}
1751}
1752
1753func TestClient_WithGRPCConnectionPool(t *testing.T) {
1754	t.Parallel()
1755
1756	configuredConnPool := 8
1757	_, client, teardown := setupMockedTestServerWithConfigAndClientOptions(
1758		t,
1759		ClientConfig{},
1760		[]option.ClientOption{option.WithGRPCConnectionPool(configuredConnPool)},
1761	)
1762	defer teardown()
1763	if g, w := client.sc.connPool.Num(), configuredConnPool; g != w {
1764		t.Fatalf("NumChannels mismatch\nGot: %v\nWant: %v", g, w)
1765	}
1766}
1767
1768func TestClient_WithGRPCConnectionPoolAndNumChannels(t *testing.T) {
1769	t.Parallel()
1770
1771	configuredNumChannels := 8
1772	configuredConnPool := 8
1773	_, client, teardown := setupMockedTestServerWithConfigAndClientOptions(
1774		t,
1775		ClientConfig{NumChannels: configuredNumChannels},
1776		[]option.ClientOption{option.WithGRPCConnectionPool(configuredConnPool)},
1777	)
1778	defer teardown()
1779	if g, w := client.sc.connPool.Num(), configuredConnPool; g != w {
1780		t.Fatalf("NumChannels mismatch\nGot: %v\nWant: %v", g, w)
1781	}
1782}
1783
1784func TestClient_WithGRPCConnectionPoolAndNumChannels_Misconfigured(t *testing.T) {
1785	t.Parallel()
1786
1787	// Deliberately misconfigure NumChannels and ConnPool.
1788	configuredNumChannels := 8
1789	configuredConnPool := 16
1790
1791	_, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
1792	defer serverTeardown()
1793	opts = append(opts, option.WithGRPCConnectionPool(configuredConnPool))
1794
1795	_, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{NumChannels: configuredNumChannels}, opts...)
1796	msg := "Connection pool mismatch:"
1797	if err == nil {
1798		t.Fatalf("Error mismatch\nGot: nil\nWant: %s", msg)
1799	}
1800	var se *Error
1801	if ok := errorAs(err, &se); !ok {
1802		t.Fatalf("Error mismatch\nGot: %v\nWant: An instance of a Spanner error", err)
1803	}
1804	if g, w := se.GRPCStatus().Code(), codes.InvalidArgument; g != w {
1805		t.Fatalf("Error code mismatch\nGot: %v\nWant: %v", g, w)
1806	}
1807	if !strings.Contains(se.Error(), msg) {
1808		t.Fatalf("Error message mismatch\nGot: %s\nWant: %s", se.Error(), msg)
1809	}
1810}
1811
1812func TestClient_CallOptions(t *testing.T) {
1813	t.Parallel()
1814	co := &vkit.CallOptions{
1815		CreateSession: []gax.CallOption{
1816			gax.WithRetry(func() gax.Retryer {
1817				return gax.OnCodes([]codes.Code{
1818					codes.Unavailable, codes.DeadlineExceeded,
1819				}, gax.Backoff{
1820					Initial:    200 * time.Millisecond,
1821					Max:        30000 * time.Millisecond,
1822					Multiplier: 1.25,
1823				})
1824			}),
1825		},
1826	}
1827
1828	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{CallOptions: co})
1829	defer teardown()
1830
1831	c, err := client.sc.nextClient()
1832	if err != nil {
1833		t.Fatalf("failed to get a session client: %v", err)
1834	}
1835
1836	cs := &gax.CallSettings{}
1837	// This is the default retry setting.
1838	c.CallOptions.CreateSession[0].Resolve(cs)
1839	if got, want := fmt.Sprintf("%v", cs.Retry()), "&{{250000000 32000000000 1.3 0} [14]}"; got != want {
1840		t.Fatalf("merged CallOptions is incorrect: got %v, want %v", got, want)
1841	}
1842
1843	// This is the custom retry setting.
1844	c.CallOptions.CreateSession[1].Resolve(cs)
1845	if got, want := fmt.Sprintf("%v", cs.Retry()), "&{{200000000 30000000000 1.25 0} [14 4]}"; got != want {
1846		t.Fatalf("merged CallOptions is incorrect: got %v, want %v", got, want)
1847	}
1848}
1849
1850func TestClient_QueryWithCallOptions(t *testing.T) {
1851	t.Parallel()
1852	co := &vkit.CallOptions{
1853		ExecuteSql: []gax.CallOption{
1854			gax.WithRetry(func() gax.Retryer {
1855				return gax.OnCodes([]codes.Code{
1856					codes.DeadlineExceeded,
1857				}, gax.Backoff{
1858					Initial:    200 * time.Millisecond,
1859					Max:        30000 * time.Millisecond,
1860					Multiplier: 1.25,
1861				})
1862			}),
1863		},
1864	}
1865	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{CallOptions: co})
1866	server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{
1867		Errors: []error{status.Error(codes.DeadlineExceeded, "Deadline exceeded")},
1868	})
1869	defer teardown()
1870	ctx := context.Background()
1871	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1872		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
1873		if err != nil {
1874			return err
1875		}
1876		return nil
1877	})
1878	if err != nil {
1879		t.Fatal(err)
1880	}
1881}
1882
1883func TestClient_EncodeCustomFieldType(t *testing.T) {
1884	t.Parallel()
1885
1886	type typesTable struct {
1887		Int    customStructToInt    `spanner:"Int"`
1888		String customStructToString `spanner:"String"`
1889		Float  customStructToFloat  `spanner:"Float"`
1890		Bool   customStructToBool   `spanner:"Bool"`
1891		Time   customStructToTime   `spanner:"Time"`
1892		Date   customStructToDate   `spanner:"Date"`
1893	}
1894
1895	server, client, teardown := setupMockedTestServer(t)
1896	defer teardown()
1897	ctx := context.Background()
1898
1899	d := typesTable{
1900		Int:    customStructToInt{1, 23},
1901		String: customStructToString{"A", "B"},
1902		Float:  customStructToFloat{1.23, 12.3},
1903		Bool:   customStructToBool{true, false},
1904		Time:   customStructToTime{"A", "B"},
1905		Date:   customStructToDate{"A", "B"},
1906	}
1907
1908	m, err := InsertStruct("Types", &d)
1909	if err != nil {
1910		t.Fatalf("err: %v", err)
1911	}
1912
1913	ms := []*Mutation{m}
1914	_, err = client.Apply(ctx, ms)
1915	if err != nil {
1916		t.Fatalf("err: %v", err)
1917	}
1918
1919	reqs := drainRequestsFromServer(server.TestSpanner)
1920
1921	for _, req := range reqs {
1922		if commitReq, ok := req.(*sppb.CommitRequest); ok {
1923			val := commitReq.Mutations[0].GetInsert().Values[0]
1924
1925			if got, want := val.Values[0].GetStringValue(), "123"; got != want {
1926				t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[0].GetKind(), want)
1927			}
1928			if got, want := val.Values[1].GetStringValue(), "A-B"; got != want {
1929				t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[1].GetKind(), want)
1930			}
1931			if got, want := val.Values[2].GetNumberValue(), float64(123.123); got != want {
1932				t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[2].GetKind(), want)
1933			}
1934			if got, want := val.Values[3].GetBoolValue(), true; got != want {
1935				t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[3].GetKind(), want)
1936			}
1937			if got, want := val.Values[4].GetStringValue(), "2016-11-15T15:04:05.999999999Z"; got != want {
1938				t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[4].GetKind(), want)
1939			}
1940			if got, want := val.Values[5].GetStringValue(), "2016-11-15"; got != want {
1941				t.Fatalf("value mismatch: got %v (kind %T), want %v", got, val.Values[5].GetKind(), want)
1942			}
1943		}
1944	}
1945}
1946
1947func setupDecodeCustomFieldResult(server *MockedSpannerInMemTestServer, stmt string) error {
1948	metadata := &sppb.ResultSetMetadata{
1949		RowType: &sppb.StructType{
1950			Fields: []*sppb.StructType_Field{
1951				{Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}},
1952				{Name: "String", Type: &sppb.Type{Code: sppb.TypeCode_STRING}},
1953				{Name: "Float", Type: &sppb.Type{Code: sppb.TypeCode_FLOAT64}},
1954				{Name: "Bool", Type: &sppb.Type{Code: sppb.TypeCode_BOOL}},
1955				{Name: "Time", Type: &sppb.Type{Code: sppb.TypeCode_TIMESTAMP}},
1956				{Name: "Date", Type: &sppb.Type{Code: sppb.TypeCode_DATE}},
1957			},
1958		},
1959	}
1960	rowValues := []*structpb.Value{
1961		{Kind: &structpb.Value_StringValue{StringValue: "123"}},
1962		{Kind: &structpb.Value_StringValue{StringValue: "A-B"}},
1963		{Kind: &structpb.Value_NumberValue{NumberValue: float64(123.123)}},
1964		{Kind: &structpb.Value_BoolValue{BoolValue: true}},
1965		{Kind: &structpb.Value_StringValue{StringValue: "2016-11-15T15:04:05.999999999Z"}},
1966		{Kind: &structpb.Value_StringValue{StringValue: "2016-11-15"}},
1967	}
1968	rows := []*structpb.ListValue{
1969		{Values: rowValues},
1970	}
1971	resultSet := &sppb.ResultSet{
1972		Metadata: metadata,
1973		Rows:     rows,
1974	}
1975	result := &StatementResult{
1976		Type:      StatementResultResultSet,
1977		ResultSet: resultSet,
1978	}
1979	return server.TestSpanner.PutStatementResult(stmt, result)
1980}
1981
1982func TestClient_DecodeCustomFieldType(t *testing.T) {
1983	t.Parallel()
1984
1985	type typesTable struct {
1986		Int    customStructToInt    `spanner:"Int"`
1987		String customStructToString `spanner:"String"`
1988		Float  customStructToFloat  `spanner:"Float"`
1989		Bool   customStructToBool   `spanner:"Bool"`
1990		Time   customStructToTime   `spanner:"Time"`
1991		Date   customStructToDate   `spanner:"Date"`
1992	}
1993
1994	server, client, teardown := setupMockedTestServer(t)
1995	defer teardown()
1996
1997	query := "SELECT * FROM Types"
1998	setupDecodeCustomFieldResult(server, query)
1999
2000	ctx := context.Background()
2001	stmt := Statement{SQL: query}
2002	iter := client.Single().Query(ctx, stmt)
2003	defer iter.Stop()
2004
2005	var results []typesTable
2006	for {
2007		row, err := iter.Next()
2008		if err == iterator.Done {
2009			break
2010		}
2011		if err != nil {
2012			t.Fatalf("failed to get next: %v", err)
2013		}
2014
2015		var d typesTable
2016		if err := row.ToStruct(&d); err != nil {
2017			t.Fatalf("failed to convert a row to a struct: %v", err)
2018		}
2019		results = append(results, d)
2020	}
2021
2022	if len(results) > 1 {
2023		t.Fatalf("mismatch length of array: got %v, want 1", results)
2024	}
2025
2026	want := typesTable{
2027		Int:    customStructToInt{1, 23},
2028		String: customStructToString{"A", "B"},
2029		Float:  customStructToFloat{1.23, 12.3},
2030		Bool:   customStructToBool{true, false},
2031		Time:   customStructToTime{"A", "B"},
2032		Date:   customStructToDate{"A", "B"},
2033	}
2034	got := results[0]
2035	if !testEqual(got, want) {
2036		t.Fatalf("mismatch result: got %v, want %v", got, want)
2037	}
2038}
2039
2040func TestClient_EmulatorWithCredentialsFile(t *testing.T) {
2041	old := os.Getenv("SPANNER_EMULATOR_HOST")
2042	defer os.Setenv("SPANNER_EMULATOR_HOST", old)
2043
2044	os.Setenv("SPANNER_EMULATOR_HOST", "localhost:1234")
2045
2046	client, err := NewClientWithConfig(
2047		context.Background(),
2048		"projects/p/instances/i/databases/d",
2049		ClientConfig{},
2050		option.WithCredentialsFile("/path/to/key.json"),
2051	)
2052	defer client.Close()
2053	if err != nil {
2054		t.Fatalf("Failed to create a client with credentials file when running against an emulator: %v", err)
2055	}
2056}
2057
2058func TestBatchReadOnlyTransaction_QueryOptions(t *testing.T) {
2059	ctx := context.Background()
2060	qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}}
2061	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: qo})
2062	defer teardown()
2063
2064	txn, err := client.BatchReadOnlyTransaction(ctx, StrongRead())
2065	if err != nil {
2066		t.Fatal(err)
2067	}
2068	defer txn.Cleanup(ctx)
2069
2070	if txn.qo != qo {
2071		t.Fatalf("Query options are mismatched: got %v, want %v", txn.qo, qo)
2072	}
2073}
2074
2075func TestBatchReadOnlyTransactionFromID_QueryOptions(t *testing.T) {
2076	qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}}
2077	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: qo})
2078	defer teardown()
2079
2080	txn := client.BatchReadOnlyTransactionFromID(BatchReadOnlyTransactionID{})
2081
2082	if txn.qo != qo {
2083		t.Fatalf("Query options are mismatched: got %v, want %v", txn.qo, qo)
2084	}
2085}
2086
2087type QueryOptionsTestCase struct {
2088	name   string
2089	client QueryOptions
2090	env    QueryOptions
2091	query  QueryOptions
2092	want   QueryOptions
2093}
2094
2095func queryOptionsTestCases() []QueryOptionsTestCase {
2096	return []QueryOptionsTestCase{
2097		{
2098			"Client level",
2099			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2100			QueryOptions{Options: nil},
2101			QueryOptions{Options: nil},
2102			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2103		},
2104		{
2105			"Environment level",
2106			QueryOptions{Options: nil},
2107			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2108			QueryOptions{Options: nil},
2109			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2110		},
2111		{
2112			"Query level",
2113			QueryOptions{Options: nil},
2114			QueryOptions{Options: nil},
2115			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2116			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2117		},
2118		{
2119			"Environment level has precedence",
2120			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2121			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "2"}},
2122			QueryOptions{Options: nil},
2123			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "2"}},
2124		},
2125		{
2126			"Query level has precedence than client level",
2127			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2128			QueryOptions{Options: nil},
2129			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
2130			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
2131		},
2132		{
2133			"Query level has highest precedence",
2134			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
2135			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "2"}},
2136			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
2137			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
2138		},
2139	}
2140}
2141