1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"io"
23	"io/ioutil"
24	"log"
25	"os"
26	"strings"
27	"testing"
28	"time"
29
30	"cloud.google.com/go/civil"
31	itestutil "cloud.google.com/go/internal/testutil"
32	. "cloud.google.com/go/spanner/internal/testutil"
33	"github.com/golang/protobuf/proto"
34	structpb "github.com/golang/protobuf/ptypes/struct"
35	"google.golang.org/api/iterator"
36	"google.golang.org/api/option"
37	instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1"
38	sppb "google.golang.org/genproto/googleapis/spanner/v1"
39	"google.golang.org/grpc/codes"
40	"google.golang.org/grpc/status"
41)
42
43func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
44	return setupMockedTestServerWithConfig(t, ClientConfig{})
45}
46
47func setupMockedTestServerWithConfig(t *testing.T, config ClientConfig) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
48	return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{})
49}
50
51func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config ClientConfig, clientOptions []option.ClientOption) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
52	grpcHeaderChecker := &itestutil.HeadersEnforcer{
53		OnFailure: t.Fatalf,
54		Checkers: []*itestutil.HeaderChecker{
55			{
56				Key: "x-goog-api-client",
57				ValuesValidator: func(token ...string) error {
58					if len(token) != 1 {
59						return status.Errorf(codes.Internal, "unexpected number of api client token headers: %v", len(token))
60					}
61					if !strings.HasPrefix(token[0], "gl-go/") {
62						return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0])
63					}
64					return nil
65				},
66			},
67		},
68	}
69	clientOptions = append(clientOptions, grpcHeaderChecker.CallOptions()...)
70	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
71	opts = append(opts, clientOptions...)
72	ctx := context.Background()
73	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
74	client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...)
75	if err != nil {
76		t.Fatal(err)
77	}
78	return server, client, func() {
79		client.Close()
80		serverTeardown()
81	}
82}
83
84// Test validDatabaseName()
85func TestValidDatabaseName(t *testing.T) {
86	validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb"
87	invalidDbUris := []string{
88		// Completely wrong DB URI.
89		"foobarDB",
90		// Project ID contains "/".
91		"projects/spanner-cloud/test/instances/foo/databases/foodb",
92		// No instance ID.
93		"projects/spanner-cloud-test/instances//databases/foodb",
94	}
95	if err := validDatabaseName(validDbURI); err != nil {
96		t.Errorf("validateDatabaseName(%q) = %v, want nil", validDbURI, err)
97	}
98	for _, d := range invalidDbUris {
99		if err, wantErr := validDatabaseName(d), "should conform to pattern"; !strings.Contains(err.Error(), wantErr) {
100			t.Errorf("validateDatabaseName(%q) = %q, want error pattern %q", validDbURI, err, wantErr)
101		}
102	}
103}
104
105// Test getInstanceName()
106func TestGetInstanceName(t *testing.T) {
107	validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb"
108	invalidDbUris := []string{
109		// Completely wrong DB URI.
110		"foobarDB",
111		// Project ID contains "/".
112		"projects/spanner-cloud/test/instances/foo/databases/foodb",
113		// No instance ID.
114		"projects/spanner-cloud-test/instances//databases/foodb",
115	}
116	want := "projects/spanner-cloud-test/instances/foo"
117	got, err := getInstanceName(validDbURI)
118	if err != nil {
119		t.Errorf("getInstanceName(%q) has an error: %q, want nil", validDbURI, err)
120	}
121	if got != want {
122		t.Errorf("getInstanceName(%q) = %q, want %q", validDbURI, got, want)
123	}
124	for _, d := range invalidDbUris {
125		wantErr := "Failed to retrieve instance name"
126		_, err = getInstanceName(d)
127		if !strings.Contains(err.Error(), wantErr) {
128			t.Errorf("getInstanceName(%q) has an error: %q, want error pattern %q", validDbURI, err, wantErr)
129		}
130	}
131}
132
133func TestReadOnlyTransactionClose(t *testing.T) {
134	// Closing a ReadOnlyTransaction shouldn't panic.
135	c := &Client{}
136	tx := c.ReadOnlyTransaction()
137	tx.Close()
138}
139
140func TestClient_Single(t *testing.T) {
141	t.Parallel()
142	err := testSingleQuery(t, nil)
143	if err != nil {
144		t.Fatal(err)
145	}
146}
147
148func TestClient_Single_Unavailable(t *testing.T) {
149	t.Parallel()
150	err := testSingleQuery(t, status.Error(codes.Unavailable, "Temporary unavailable"))
151	if err != nil {
152		t.Fatal(err)
153	}
154}
155
156func TestClient_Single_InvalidArgument(t *testing.T) {
157	t.Parallel()
158	err := testSingleQuery(t, status.Error(codes.InvalidArgument, "Invalid argument"))
159	if status.Code(err) != codes.InvalidArgument {
160		t.Fatalf("got: %v, want: %v", err, codes.InvalidArgument)
161	}
162}
163
164func TestClient_Single_SessionNotFound(t *testing.T) {
165	t.Parallel()
166
167	server, client, teardown := setupMockedTestServer(t)
168	defer teardown()
169	server.TestSpanner.PutExecutionTime(
170		MethodExecuteStreamingSql,
171		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
172	)
173	ctx := context.Background()
174	iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
175	defer iter.Stop()
176	rowCount := int64(0)
177	for {
178		_, err := iter.Next()
179		if err == iterator.Done {
180			break
181		}
182		if err != nil {
183			t.Fatal(err)
184		}
185		rowCount++
186	}
187	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
188		t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
189	}
190}
191
192func TestClient_Single_RetryableErrorOnPartialResultSet(t *testing.T) {
193	t.Parallel()
194	server, client, teardown := setupMockedTestServer(t)
195	defer teardown()
196
197	// Add two errors that will be returned by the mock server when the client
198	// is trying to fetch a partial result set. Both errors are retryable.
199	// The errors are not 'sticky' on the mocked server, i.e. once the error
200	// has been returned once, the next call for the same partial result set
201	// will succeed.
202
203	// When the client is fetching the partial result set with resume token 2,
204	// the mock server will respond with an internal error with the message
205	// 'stream terminated by RST_STREAM'. The client will retry the call to get
206	// this partial result set.
207	server.TestSpanner.AddPartialResultSetError(
208		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
209		PartialResultSetExecutionTime{
210			ResumeToken: EncodeResumeToken(2),
211			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
212		},
213	)
214	// When the client is fetching the partial result set with resume token 3,
215	// the mock server will respond with a 'Unavailable' error. The client will
216	// retry the call to get this partial result set.
217	server.TestSpanner.AddPartialResultSetError(
218		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
219		PartialResultSetExecutionTime{
220			ResumeToken: EncodeResumeToken(3),
221			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
222		},
223	)
224	ctx := context.Background()
225	if err := executeSingerQuery(ctx, client.Single()); err != nil {
226		t.Fatal(err)
227	}
228}
229
230func TestClient_Single_NonRetryableErrorOnPartialResultSet(t *testing.T) {
231	t.Parallel()
232	server, client, teardown := setupMockedTestServer(t)
233	defer teardown()
234
235	// Add two errors that will be returned by the mock server when the client
236	// is trying to fetch a partial result set. The first error is retryable,
237	// the second is not.
238
239	// This error will automatically be retried.
240	server.TestSpanner.AddPartialResultSetError(
241		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
242		PartialResultSetExecutionTime{
243			ResumeToken: EncodeResumeToken(2),
244			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
245		},
246	)
247	// 'Session not found' is not retryable and the error will be returned to
248	// the user.
249	server.TestSpanner.AddPartialResultSetError(
250		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
251		PartialResultSetExecutionTime{
252			ResumeToken: EncodeResumeToken(3),
253			Err:         newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s"),
254		},
255	)
256	ctx := context.Background()
257	err := executeSingerQuery(ctx, client.Single())
258	if status.Code(err) != codes.NotFound {
259		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.NotFound)
260	}
261}
262
263func TestClient_Single_DeadlineExceeded_NoErrors(t *testing.T) {
264	t.Parallel()
265	server, client, teardown := setupMockedTestServer(t)
266	defer teardown()
267	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
268		SimulatedExecutionTime{
269			MinimumExecutionTime: 50 * time.Millisecond,
270		})
271	ctx := context.Background()
272	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Millisecond))
273	defer cancel()
274	err := executeSingerQuery(ctx, client.Single())
275	if status.Code(err) != codes.DeadlineExceeded {
276		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.DeadlineExceeded)
277	}
278}
279
280func TestClient_Single_DeadlineExceeded_WithErrors(t *testing.T) {
281	t.Parallel()
282	server, client, teardown := setupMockedTestServer(t)
283	defer teardown()
284	server.TestSpanner.AddPartialResultSetError(
285		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
286		PartialResultSetExecutionTime{
287			ResumeToken: EncodeResumeToken(2),
288			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
289		},
290	)
291	server.TestSpanner.AddPartialResultSetError(
292		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
293		PartialResultSetExecutionTime{
294			ResumeToken:   EncodeResumeToken(3),
295			Err:           status.Errorf(codes.Unavailable, "server is unavailable"),
296			ExecutionTime: 50 * time.Millisecond,
297		},
298	)
299	ctx := context.Background()
300	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(25*time.Millisecond))
301	defer cancel()
302	err := executeSingerQuery(ctx, client.Single())
303	if status.Code(err) != codes.DeadlineExceeded {
304		t.Fatalf("got unexpected error %v, expected DeadlineExceeded", err)
305	}
306}
307
308func TestClient_Single_ContextCanceled_noDeclaredServerErrors(t *testing.T) {
309	t.Parallel()
310	_, client, teardown := setupMockedTestServer(t)
311	defer teardown()
312	ctx := context.Background()
313	ctx, cancel := context.WithCancel(ctx)
314	cancel()
315	err := executeSingerQuery(ctx, client.Single())
316	if status.Code(err) != codes.Canceled {
317		t.Fatalf("got unexpected error %v, expected Canceled", err)
318	}
319}
320
321func TestClient_Single_ContextCanceled_withDeclaredServerErrors(t *testing.T) {
322	t.Parallel()
323	server, client, teardown := setupMockedTestServer(t)
324	defer teardown()
325	server.TestSpanner.AddPartialResultSetError(
326		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
327		PartialResultSetExecutionTime{
328			ResumeToken: EncodeResumeToken(2),
329			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
330		},
331	)
332	server.TestSpanner.AddPartialResultSetError(
333		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
334		PartialResultSetExecutionTime{
335			ResumeToken: EncodeResumeToken(3),
336			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
337		},
338	)
339	ctx := context.Background()
340	ctx, cancel := context.WithCancel(ctx)
341	defer cancel()
342	f := func(rowCount int64) error {
343		if rowCount == 2 {
344			cancel()
345		}
346		return nil
347	}
348	iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
349	defer iter.Stop()
350	err := executeSingerQueryWithRowFunc(ctx, client.Single(), f)
351	if status.Code(err) != codes.Canceled {
352		t.Fatalf("got unexpected error %v, expected Canceled", err)
353	}
354}
355
356func TestClient_ResourceBasedRouting_WithEndpointsReturned(t *testing.T) {
357	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
358	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
359
360	// Create two servers. The base server receives the GetInstance request and
361	// returns the instance endpoint of the target server. The client should contact
362	// the target server after getting the instance endpoint.
363	serverBase, optsBase, serverTeardownBase := NewMockedSpannerInMemTestServerWithAddr(t, "localhost:8081")
364	defer serverTeardownBase()
365	serverTarget, optsTarget, serverTeardownTarget := NewMockedSpannerInMemTestServerWithAddr(t, "localhost:8082")
366	defer serverTeardownTarget()
367
368	// Return the instance endpoint.
369	instanceEndpoint := fmt.Sprintf("%s", optsTarget[0])
370	resps := []proto.Message{&instancepb.Instance{
371		EndpointUris: []string{instanceEndpoint},
372	}}
373	serverBase.TestInstanceAdmin.SetResps(resps)
374
375	ctx := context.Background()
376	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
377	client, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, optsBase...)
378	if err != nil {
379		t.Fatal(err)
380	}
381
382	if err := executeSingerQuery(ctx, client.Single()); err != nil {
383		t.Fatal(err)
384	}
385
386	// The base server should not receive any requests.
387	if _, err := shouldHaveReceived(serverBase.TestSpanner, []interface{}{}); err != nil {
388		t.Fatal(err)
389	}
390
391	// The target server should receive requests.
392	if _, err = shouldHaveReceived(serverTarget.TestSpanner, []interface{}{
393		&sppb.CreateSessionRequest{},
394		&sppb.ExecuteSqlRequest{},
395	}); err != nil {
396		t.Fatal(err)
397	}
398}
399
400func TestClient_ResourceBasedRouting_WithoutEndpointsReturned(t *testing.T) {
401	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
402	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
403
404	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
405	defer serverTeardown()
406
407	// Return an empty list of endpoints.
408	resps := []proto.Message{&instancepb.Instance{
409		EndpointUris: []string{},
410	}}
411	server.TestInstanceAdmin.SetResps(resps)
412
413	ctx := context.Background()
414	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
415	client, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, opts...)
416	if err != nil {
417		t.Fatal(err)
418	}
419
420	if err := executeSingerQuery(ctx, client.Single()); err != nil {
421		t.Fatal(err)
422	}
423
424	// Check if the request goes to the default endpoint.
425	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
426		&sppb.CreateSessionRequest{},
427		&sppb.ExecuteSqlRequest{},
428	}); err != nil {
429		t.Fatal(err)
430	}
431}
432
433func TestClient_ResourceBasedRouting_WithPermissionDeniedError(t *testing.T) {
434	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
435	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
436
437	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
438	defer serverTeardown()
439
440	server.TestInstanceAdmin.SetErr(status.Error(codes.PermissionDenied, "Permission Denied"))
441
442	ctx := context.Background()
443	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
444	// `PermissionDeniedError` causes a warning message to be logged, which is expected.
445	// We set the output to be discarded to avoid spamming the log.
446	logger := log.New(ioutil.Discard, "", log.LstdFlags)
447	client, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{logger: logger}, opts...)
448	if err != nil {
449		t.Fatal(err)
450	}
451
452	if err := executeSingerQuery(ctx, client.Single()); err != nil {
453		t.Fatal(err)
454	}
455
456	// Fallback to use the default endpoint when calling GetInstance() returns
457	// a PermissionDenied error.
458	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
459		&sppb.CreateSessionRequest{},
460		&sppb.ExecuteSqlRequest{},
461	}); err != nil {
462		t.Fatal(err)
463	}
464}
465
466func TestClient_ResourceBasedRouting_WithUnavailableError(t *testing.T) {
467	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
468	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
469
470	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
471	defer serverTeardown()
472
473	resps := []proto.Message{&instancepb.Instance{
474		EndpointUris: []string{},
475	}}
476	server.TestInstanceAdmin.SetResps(resps)
477	server.TestInstanceAdmin.SetErr(status.Error(codes.Unavailable, "Temporary unavailable"))
478
479	ctx := context.Background()
480	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
481	_, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, opts...)
482	// The first request will get an error and the server resets the error to nil,
483	// so the next request will be fine. Due to retrying, there is no errors.
484	if err != nil {
485		t.Fatal(err)
486	}
487}
488
489func TestClient_ResourceBasedRouting_WithInvalidArgumentError(t *testing.T) {
490	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
491	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
492
493	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
494	defer serverTeardown()
495
496	server.TestInstanceAdmin.SetErr(status.Error(codes.InvalidArgument, "Invalid argument"))
497
498	ctx := context.Background()
499	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
500	_, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, opts...)
501
502	if status.Code(err) != codes.InvalidArgument {
503		t.Fatalf("got unexpected exception %v, expected InvalidArgument", err)
504	}
505}
506
507func TestClient_Single_QueryOptions(t *testing.T) {
508	for _, tt := range queryOptionsTestCases() {
509		t.Run(tt.name, func(t *testing.T) {
510			if tt.env.Options != nil {
511				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
512				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
513			}
514
515			ctx := context.Background()
516			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
517			defer teardown()
518
519			var iter *RowIterator
520			if tt.query.Options == nil {
521				iter = client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
522			} else {
523				iter = client.Single().QueryWithOptions(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), tt.query)
524			}
525			testQueryOptions(t, iter, server.TestSpanner, tt.want)
526		})
527	}
528}
529
530func testQueryOptions(t *testing.T, iter *RowIterator, server InMemSpannerServer, qo QueryOptions) {
531	defer iter.Stop()
532
533	_, err := iter.Next()
534	if err != nil {
535		t.Fatalf("Failed to read from the iterator: %v", err)
536	}
537
538	checkReqsForQueryOptions(t, server, qo)
539}
540
541func checkReqsForQueryOptions(t *testing.T, server InMemSpannerServer, qo QueryOptions) {
542	reqs := drainRequestsFromServer(server)
543	sqlReqs := []*sppb.ExecuteSqlRequest{}
544
545	for _, req := range reqs {
546		if sqlReq, ok := req.(*sppb.ExecuteSqlRequest); ok {
547			sqlReqs = append(sqlReqs, sqlReq)
548		}
549	}
550
551	if got, want := len(sqlReqs), 1; got != want {
552		t.Fatalf("Length mismatch, got %v, want %v", got, want)
553	}
554
555	reqQueryOptions := sqlReqs[0].QueryOptions
556	if got, want := reqQueryOptions.OptimizerVersion, qo.Options.OptimizerVersion; got != want {
557		t.Fatalf("Optimizer version mismatch, got %v, want %v", got, want)
558	}
559}
560
561func testSingleQuery(t *testing.T, serverError error) error {
562	ctx := context.Background()
563	server, client, teardown := setupMockedTestServer(t)
564	defer teardown()
565	if serverError != nil {
566		server.TestSpanner.SetError(serverError)
567	}
568	return executeSingerQuery(ctx, client.Single())
569}
570
571func executeSingerQuery(ctx context.Context, tx *ReadOnlyTransaction) error {
572	return executeSingerQueryWithRowFunc(ctx, tx, nil)
573}
574
575func executeSingerQueryWithRowFunc(ctx context.Context, tx *ReadOnlyTransaction, f func(rowCount int64) error) error {
576	iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
577	defer iter.Stop()
578	rowCount := int64(0)
579	for {
580		row, err := iter.Next()
581		if err == iterator.Done {
582			break
583		}
584		if err != nil {
585			return err
586		}
587		var singerID, albumID int64
588		var albumTitle string
589		if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
590			return err
591		}
592		rowCount++
593		if f != nil {
594			if err := f(rowCount); err != nil {
595				return err
596			}
597		}
598	}
599	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
600		return status.Errorf(codes.Internal, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
601	}
602	return nil
603}
604
605func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]SimulatedExecutionTime {
606	errors := make([]error, 2)
607	errors[0] = status.Error(codes.Unavailable, "Temporary unavailable")
608	errors[1] = status.Error(codes.Unavailable, "Temporary unavailable")
609	executionTimes := make(map[string]SimulatedExecutionTime)
610	executionTimes[method] = SimulatedExecutionTime{
611		Errors: errors,
612	}
613	return executionTimes
614}
615
616func TestClient_ReadOnlyTransaction(t *testing.T) {
617	t.Parallel()
618	if err := testReadOnlyTransaction(t, make(map[string]SimulatedExecutionTime)); err != nil {
619		t.Fatal(err)
620	}
621}
622
623func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) {
624	t.Parallel()
625	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodCreateSession)); err != nil {
626		t.Fatal(err)
627	}
628}
629
630func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) {
631	t.Parallel()
632	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodBeginTransaction)); err != nil {
633		t.Fatal(err)
634	}
635}
636
637func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) {
638	t.Parallel()
639	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodExecuteStreamingSql)); err != nil {
640		t.Fatal(err)
641	}
642}
643
644func TestClient_ReadOnlyTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
645	t.Parallel()
646	// Session not found is not retryable for a query on a multi-use read-only
647	// transaction, as we would need to start a new transaction on a new
648	// session.
649	err := testReadOnlyTransaction(t, map[string]SimulatedExecutionTime{
650		MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
651	})
652	want := toSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s"))
653	if err == nil {
654		t.Fatalf("missing expected error\nGot: nil\nWant: %v", want)
655	}
656	if status.Code(err) != status.Code(want) || !strings.Contains(err.Error(), want.Error()) {
657		t.Fatalf("error mismatch\nGot: %v\nWant: %v", err, want)
658	}
659}
660
661func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) {
662	t.Parallel()
663	exec := map[string]SimulatedExecutionTime{
664		MethodCreateSession:    {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
665		MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
666	}
667	if err := testReadOnlyTransaction(t, exec); err != nil {
668		t.Fatal(err)
669	}
670}
671
672func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) {
673	t.Parallel()
674	exec := map[string]SimulatedExecutionTime{
675		MethodCreateSession:    {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
676		MethodBeginTransaction: {Errors: []error{status.Error(codes.InvalidArgument, "Invalid argument")}},
677	}
678	if err := testReadOnlyTransaction(t, exec); err == nil {
679		t.Fatalf("Missing expected exception")
680	} else if status.Code(err) != codes.InvalidArgument {
681		t.Fatalf("Got unexpected exception: %v", err)
682	}
683}
684
685func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
686	t.Parallel()
687	if err := testReadOnlyTransaction(
688		t,
689		map[string]SimulatedExecutionTime{
690			MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
691		},
692	); err != nil {
693		t.Fatal(err)
694	}
695}
696
697func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOneSession(t *testing.T) {
698	t.Parallel()
699	server, client, teardown := setupMockedTestServerWithConfig(
700		t,
701		ClientConfig{
702			SessionPoolConfig: SessionPoolConfig{
703				MinOpened: 0,
704				MaxOpened: 1,
705			},
706		})
707	defer teardown()
708	server.TestSpanner.PutExecutionTime(
709		MethodBeginTransaction,
710		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
711	)
712	tx := client.ReadOnlyTransaction()
713	defer tx.Close()
714	ctx := context.Background()
715	if err := executeSingerQuery(ctx, tx); err != nil {
716		t.Fatal(err)
717	}
718}
719
720func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) {
721	for _, tt := range queryOptionsTestCases() {
722		t.Run(tt.name, func(t *testing.T) {
723			if tt.env.Options != nil {
724				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
725				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
726			}
727
728			ctx := context.Background()
729			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
730			defer teardown()
731
732			tx := client.ReadOnlyTransaction()
733			defer tx.Close()
734
735			var iter *RowIterator
736			if tt.query.Options == nil {
737				iter = tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
738			} else {
739				iter = tx.QueryWithOptions(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), tt.query)
740			}
741			testQueryOptions(t, iter, server.TestSpanner, tt.want)
742		})
743	}
744}
745
746func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) error {
747	server, client, teardown := setupMockedTestServer(t)
748	defer teardown()
749	for method, exec := range executionTimes {
750		server.TestSpanner.PutExecutionTime(method, exec)
751	}
752	tx := client.ReadOnlyTransaction()
753	defer tx.Close()
754	ctx := context.Background()
755	return executeSingerQuery(ctx, tx)
756}
757
758func TestClient_ReadWriteTransaction(t *testing.T) {
759	t.Parallel()
760	if err := testReadWriteTransaction(t, make(map[string]SimulatedExecutionTime), 1); err != nil {
761		t.Fatal(err)
762	}
763}
764
765func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) {
766	t.Parallel()
767	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
768		MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
769	}, 2); err != nil {
770		t.Fatal(err)
771	}
772}
773
774func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) {
775	t.Parallel()
776	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
777		MethodCommitTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
778	}, 2); err != nil {
779		t.Fatal(err)
780	}
781}
782
783func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
784	t.Parallel()
785	// We expect only 1 attempt, as the 'Session not found' error is already
786	//handled in the session pool where the session is prepared.
787	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
788		MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
789	}, 1); err != nil {
790		t.Fatal(err)
791	}
792}
793
794func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) {
795	t.Parallel()
796	// There will be no prepared sessions in the pool, so the error will occur
797	// when the transaction tries to get a session from the pool. This will
798	// also be handled by the session pool, so the transaction itself does not
799	// need to retry, hence the expectedAttempts == 1.
800	if err := testReadWriteTransactionWithConfig(t, ClientConfig{
801		SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0},
802	}, map[string]SimulatedExecutionTime{
803		MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
804	}, 1); err != nil {
805		t.Fatal(err)
806	}
807}
808
809func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
810	t.Parallel()
811	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
812		MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
813	}, 2); err != nil {
814		t.Fatal(err)
815	}
816}
817
818func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) {
819	t.Parallel()
820
821	server, client, teardown := setupMockedTestServer(t)
822	defer teardown()
823	server.TestSpanner.PutExecutionTime(
824		MethodExecuteSql,
825		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
826	)
827	ctx := context.Background()
828	var attempts int
829	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
830		attempts++
831		rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
832		if err != nil {
833			return err
834		}
835		if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w {
836			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
837		}
838		return nil
839	})
840	if err != nil {
841		t.Fatal(err)
842	}
843	if g, w := attempts, 2; g != w {
844		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
845	}
846}
847
848func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) {
849	t.Parallel()
850
851	server, client, teardown := setupMockedTestServer(t)
852	defer teardown()
853	server.TestSpanner.PutExecutionTime(
854		MethodExecuteBatchDml,
855		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
856	)
857	ctx := context.Background()
858	var attempts int
859	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
860		attempts++
861		rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)})
862		if err != nil {
863			return err
864		}
865		if g, w := len(rowCounts), 1; g != w {
866			return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w)
867		}
868		if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w {
869			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
870		}
871		return nil
872	})
873	if err != nil {
874		t.Fatal(err)
875	}
876	if g, w := attempts, 2; g != w {
877		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
878	}
879}
880
881func TestClient_ReadWriteTransaction_Query_QueryOptions(t *testing.T) {
882	for _, tt := range queryOptionsTestCases() {
883		t.Run(tt.name, func(t *testing.T) {
884			if tt.env.Options != nil {
885				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
886				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
887			}
888
889			ctx := context.Background()
890			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
891			defer teardown()
892
893			_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
894				var iter *RowIterator
895				if tt.query.Options == nil {
896					iter = tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
897				} else {
898					iter = tx.QueryWithOptions(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums), tt.query)
899				}
900				testQueryOptions(t, iter, server.TestSpanner, tt.want)
901				return nil
902			})
903			if err != nil {
904				t.Fatal(err)
905			}
906		})
907	}
908}
909
910func TestClient_ReadWriteTransaction_Update_QueryOptions(t *testing.T) {
911	for _, tt := range queryOptionsTestCases() {
912		t.Run(tt.name, func(t *testing.T) {
913			if tt.env.Options != nil {
914				os.Setenv("SPANNER_OPTIMIZER_VERSION", tt.env.Options.OptimizerVersion)
915				defer os.Setenv("SPANNER_OPTIMIZER_VERSION", "")
916			}
917
918			ctx := context.Background()
919			server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: tt.client})
920			defer teardown()
921
922			_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
923				var rowCount int64
924				var err error
925				if tt.query.Options == nil {
926					rowCount, err = tx.Update(ctx, NewStatement(UpdateBarSetFoo))
927				} else {
928					rowCount, err = tx.UpdateWithOptions(ctx, NewStatement(UpdateBarSetFoo), tt.query)
929				}
930				if got, want := rowCount, int64(5); got != want {
931					t.Fatalf("Incorrect updated row count: got %v, want %v", got, want)
932				}
933				return err
934			})
935			if err != nil {
936				t.Fatalf("Failed to update rows: %v", err)
937			}
938			checkReqsForQueryOptions(t, server.TestSpanner, tt.want)
939		})
940	}
941}
942
943func TestClient_SessionNotFound(t *testing.T) {
944	// Ensure we always have at least one session in the pool.
945	sc := SessionPoolConfig{
946		MinOpened: 1,
947	}
948	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc})
949	defer teardown()
950	ctx := context.Background()
951	for {
952		client.idleSessions.mu.Lock()
953		numSessions := client.idleSessions.idleList.Len()
954		client.idleSessions.mu.Unlock()
955		if numSessions > 0 {
956			break
957		}
958		time.After(time.Millisecond)
959	}
960	// Remove the session from the server without the pool knowing it.
961	_, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id})
962	if err != nil {
963		t.Fatalf("Failed to delete session unexpectedly: %v", err)
964	}
965
966	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
967		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
968		defer iter.Stop()
969		rowCount := int64(0)
970		for {
971			row, err := iter.Next()
972			if err == iterator.Done {
973				break
974			}
975			if err != nil {
976				return err
977			}
978			var singerID, albumID int64
979			var albumTitle string
980			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
981				return err
982			}
983			rowCount++
984		}
985		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
986			return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
987		}
988		return nil
989	})
990	if err != nil {
991		t.Fatalf("Unexpected error during transaction: %v", err)
992	}
993}
994
995func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) {
996	t.Parallel()
997	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
998		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
999	}, 2); err != nil {
1000		t.Fatal(err)
1001	}
1002}
1003
1004func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) {
1005	t.Parallel()
1006	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1007		MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
1008	}, 1); err != nil {
1009		t.Fatal(err)
1010	}
1011}
1012
1013func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) {
1014	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1015		MethodBeginTransaction:  {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
1016		MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Aborted")}},
1017	}, 2); err != nil {
1018		t.Fatal(err)
1019	}
1020}
1021
1022func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) {
1023	t.Parallel()
1024	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1025		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
1026	}, 1); err != nil {
1027		t.Fatal(err)
1028	}
1029}
1030
1031func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) {
1032	t.Parallel()
1033	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1034		MethodBeginTransaction:    {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
1035		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
1036		MethodCommitTransaction:   {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}},
1037	}, 3); err != nil {
1038		t.Fatal(err)
1039	}
1040}
1041
1042func TestClient_ReadWriteTransaction_CommitAborted(t *testing.T) {
1043	t.Parallel()
1044	server, client, teardown := setupMockedTestServer(t)
1045	server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{
1046		Errors: []error{status.Error(codes.Aborted, "Aborted")},
1047	})
1048	defer teardown()
1049	ctx := context.Background()
1050	attempts := 0
1051	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1052		attempts++
1053		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
1054		if err != nil {
1055			return err
1056		}
1057		return nil
1058	})
1059	if err != nil {
1060		t.Fatal(err)
1061	}
1062	if g, w := attempts, 2; g != w {
1063		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1064	}
1065}
1066
1067func TestClient_ReadWriteTransaction_DMLAborted(t *testing.T) {
1068	t.Parallel()
1069	server, client, teardown := setupMockedTestServer(t)
1070	server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{
1071		Errors: []error{status.Error(codes.Aborted, "Aborted")},
1072	})
1073	defer teardown()
1074	ctx := context.Background()
1075	attempts := 0
1076	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1077		attempts++
1078		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
1079		if err != nil {
1080			return err
1081		}
1082		return nil
1083	})
1084	if err != nil {
1085		t.Fatal(err)
1086	}
1087	if g, w := attempts, 2; g != w {
1088		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1089	}
1090}
1091
1092func TestClient_ReadWriteTransaction_BatchDMLAborted(t *testing.T) {
1093	t.Parallel()
1094	server, client, teardown := setupMockedTestServer(t)
1095	server.TestSpanner.PutExecutionTime(MethodExecuteBatchDml, SimulatedExecutionTime{
1096		Errors: []error{status.Error(codes.Aborted, "Aborted")},
1097	})
1098	defer teardown()
1099	ctx := context.Background()
1100	attempts := 0
1101	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1102		attempts++
1103		_, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
1104		if err != nil {
1105			return err
1106		}
1107		return nil
1108	})
1109	if err != nil {
1110		t.Fatal(err)
1111	}
1112	if g, w := attempts, 2; g != w {
1113		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1114	}
1115}
1116
1117func TestClient_ReadWriteTransaction_QueryAborted(t *testing.T) {
1118	t.Parallel()
1119	server, client, teardown := setupMockedTestServer(t)
1120	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{
1121		Errors: []error{status.Error(codes.Aborted, "Aborted")},
1122	})
1123	defer teardown()
1124	ctx := context.Background()
1125	attempts := 0
1126	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1127		attempts++
1128		iter := tx.Query(ctx, Statement{SQL: SelectFooFromBar})
1129		defer iter.Stop()
1130		for {
1131			_, err := iter.Next()
1132			if err == iterator.Done {
1133				break
1134			}
1135			if err != nil {
1136				return err
1137			}
1138		}
1139		return nil
1140	})
1141	if err != nil {
1142		t.Fatal(err)
1143	}
1144	if g, w := attempts, 2; g != w {
1145		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1146	}
1147}
1148
1149func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) {
1150	t.Parallel()
1151	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1152		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Aborted")}},
1153		MethodCommitTransaction:   {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}},
1154	}, 4); err != nil {
1155		t.Fatal(err)
1156	}
1157}
1158
1159func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) {
1160	t.Parallel()
1161	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1162		MethodCommitTransaction: {
1163			Errors: []error{
1164				status.Error(codes.Aborted, "Transaction aborted"),
1165				status.Error(codes.Unavailable, "Unavailable"),
1166			},
1167		},
1168	}, 2); err != nil {
1169		t.Fatal(err)
1170	}
1171}
1172
1173func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) {
1174	t.Parallel()
1175	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1176		MethodCommitTransaction: {Errors: []error{status.Error(codes.AlreadyExists, "A row with this key already exists")}},
1177	}, 1); err != nil {
1178		if status.Code(err) != codes.AlreadyExists {
1179			t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists)
1180		}
1181	} else {
1182		t.Fatalf("Missing expected exception")
1183	}
1184}
1185
1186func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
1187	return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts)
1188}
1189
1190func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
1191	server, client, teardown := setupMockedTestServer(t)
1192	defer teardown()
1193	for method, exec := range executionTimes {
1194		server.TestSpanner.PutExecutionTime(method, exec)
1195	}
1196	ctx := context.Background()
1197	var attempts int
1198	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1199		attempts++
1200		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1201		defer iter.Stop()
1202		rowCount := int64(0)
1203		for {
1204			row, err := iter.Next()
1205			if err == iterator.Done {
1206				break
1207			}
1208			if err != nil {
1209				return err
1210			}
1211			var singerID, albumID int64
1212			var albumTitle string
1213			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
1214				return err
1215			}
1216			rowCount++
1217		}
1218		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
1219			return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
1220		}
1221		return nil
1222	})
1223	if err != nil {
1224		return err
1225	}
1226	if expectedAttempts != attempts {
1227		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
1228	}
1229	return nil
1230}
1231
1232func TestClient_ApplyAtLeastOnce(t *testing.T) {
1233	t.Parallel()
1234	server, client, teardown := setupMockedTestServer(t)
1235	defer teardown()
1236	ms := []*Mutation{
1237		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1238		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1239	}
1240	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1241		SimulatedExecutionTime{
1242			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1243		})
1244	_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1245	if err != nil {
1246		t.Fatal(err)
1247	}
1248}
1249
1250func TestClient_ApplyAtLeastOnceReuseSession(t *testing.T) {
1251	t.Parallel()
1252	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1253		SessionPoolConfig: SessionPoolConfig{
1254			MinOpened:           0,
1255			WriteSessions:       0.0,
1256			TrackSessionHandles: true,
1257		},
1258	})
1259	defer teardown()
1260	ms := []*Mutation{
1261		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1262		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1263	}
1264	for i := 0; i < 10; i++ {
1265		_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1266		if err != nil {
1267			t.Fatal(err)
1268		}
1269		if g, w := client.idleSessions.idleList.Len(), 1; g != w {
1270			t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w)
1271		}
1272		if g, w := len(server.TestSpanner.DumpSessions()), 1; g != w {
1273			t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w)
1274		}
1275	}
1276	// There should be no sessions marked as checked out.
1277	client.idleSessions.mu.Lock()
1278	g, w := client.idleSessions.trackedSessionHandles.Len(), 0
1279	client.idleSessions.mu.Unlock()
1280	if g != w {
1281		t.Fatalf("checked out sessions count mismatch:\nGot: %v\nWant: %v", g, w)
1282	}
1283}
1284
1285func TestClient_ApplyAtLeastOnceInvalidArgument(t *testing.T) {
1286	t.Parallel()
1287	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1288		SessionPoolConfig: SessionPoolConfig{
1289			MinOpened:           0,
1290			WriteSessions:       0.0,
1291			TrackSessionHandles: true,
1292		},
1293	})
1294	defer teardown()
1295	ms := []*Mutation{
1296		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1297		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1298	}
1299	for i := 0; i < 10; i++ {
1300		server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1301			SimulatedExecutionTime{
1302				Errors: []error{status.Error(codes.InvalidArgument, "Invalid data")},
1303			})
1304		_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1305		if status.Code(err) != codes.InvalidArgument {
1306			t.Fatal(err)
1307		}
1308		if g, w := client.idleSessions.idleList.Len(), 1; g != w {
1309			t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w)
1310		}
1311		if g, w := len(server.TestSpanner.DumpSessions()), 1; g != w {
1312			t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w)
1313		}
1314	}
1315	// There should be no sessions marked as checked out.
1316	client.idleSessions.mu.Lock()
1317	g, w := client.idleSessions.trackedSessionHandles.Len(), 0
1318	client.idleSessions.mu.Unlock()
1319	if g != w {
1320		t.Fatalf("checked out sessions count mismatch:\nGot: %v\nWant: %v", g, w)
1321	}
1322}
1323
1324func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) {
1325	t.Parallel()
1326	_, client, teardown := setupMockedTestServer(t)
1327	defer teardown()
1328	ctx := context.Background()
1329	var attempts int
1330	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1331		attempts++
1332		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1333		defer iter.Stop()
1334		for {
1335			row, err := iter.Next()
1336			if err == iterator.Done {
1337				break
1338			}
1339			if err != nil {
1340				return err
1341			}
1342			var singerID, albumID int64
1343			var albumTitle string
1344			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
1345				return err
1346			}
1347		}
1348		return io.ErrUnexpectedEOF
1349	})
1350	if err != io.ErrUnexpectedEOF {
1351		t.Fatalf("Missing expected error %v, got %v", io.ErrUnexpectedEOF, err)
1352	}
1353	if attempts != 1 {
1354		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, 1)
1355	}
1356}
1357
1358func TestReadWriteTransaction_WrapError(t *testing.T) {
1359	t.Parallel()
1360	server, client, teardown := setupMockedTestServer(t)
1361	defer teardown()
1362	// Abort the transaction on both the query as well as commit.
1363	// The first abort error will be wrapped. The client will unwrap the cause
1364	// of the error and retry the transaction. The aborted error on commit
1365	// will not be wrapped, but will also be recognized by the client as an
1366	// abort that should be retried.
1367	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1368		SimulatedExecutionTime{
1369			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1370		})
1371	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1372		SimulatedExecutionTime{
1373			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1374		})
1375	msg := "query failed"
1376	numAttempts := 0
1377	ctx := context.Background()
1378	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1379		numAttempts++
1380		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1381		defer iter.Stop()
1382		for {
1383			_, err := iter.Next()
1384			if err == iterator.Done {
1385				break
1386			}
1387			if err != nil {
1388				// Wrap the error in another error that implements the
1389				// (xerrors|errors).Wrapper interface.
1390				return &wrappedTestError{err, msg}
1391			}
1392		}
1393		return nil
1394	})
1395	if err != nil {
1396		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
1397	}
1398	if g, w := numAttempts, 3; g != w {
1399		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", w, w)
1400	}
1401
1402	// Execute a transaction that returns a non-retryable error that is
1403	// wrapped in a custom error. The transaction should return the custom
1404	// error.
1405	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1406		SimulatedExecutionTime{
1407			Errors: []error{status.Error(codes.NotFound, "Table not found")},
1408		})
1409	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1410		numAttempts++
1411		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1412		defer iter.Stop()
1413		for {
1414			_, err := iter.Next()
1415			if err == iterator.Done {
1416				break
1417			}
1418			if err != nil {
1419				// Wrap the error in another error that implements the
1420				// (xerrors|errors).Wrapper interface.
1421				return &wrappedTestError{err, msg}
1422			}
1423		}
1424		return nil
1425	})
1426	if err == nil || err.Error() != msg {
1427		t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
1428	}
1429}
1430
1431func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) {
1432	t.Parallel()
1433	server, client, teardown := setupMockedTestServer(t)
1434	defer teardown()
1435	server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
1436		SimulatedExecutionTime{
1437			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1438		})
1439	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1440		SimulatedExecutionTime{
1441			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1442		})
1443	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1444		SimulatedExecutionTime{
1445			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1446		})
1447	msg := "query failed"
1448	numAttempts := 0
1449	ctx := context.Background()
1450	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1451		numAttempts++
1452		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1453		defer iter.Stop()
1454		for {
1455			_, err := iter.Next()
1456			if err == iterator.Done {
1457				break
1458			}
1459			if err != nil {
1460				// Wrap the error in another error that implements the
1461				// (xerrors|errors).Wrapper interface.
1462				return &wrappedTestError{err, msg}
1463			}
1464		}
1465		return nil
1466	})
1467	if err != nil {
1468		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
1469	}
1470	// We want 3 attempts. The 'Session not found' error on BeginTransaction
1471	// will not retry the entire transaction, which means that we will have two
1472	// failed attempts and then a successful attempt.
1473	if g, w := numAttempts, 3; g != w {
1474		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", g, w)
1475	}
1476}
1477
1478func TestClient_WriteStructWithPointers(t *testing.T) {
1479	t.Parallel()
1480	server, client, teardown := setupMockedTestServer(t)
1481	defer teardown()
1482	type T struct {
1483		ID    int64
1484		Col1  *string
1485		Col2  []*string
1486		Col3  *bool
1487		Col4  []*bool
1488		Col5  *int64
1489		Col6  []*int64
1490		Col7  *float64
1491		Col8  []*float64
1492		Col9  *time.Time
1493		Col10 []*time.Time
1494		Col11 *civil.Date
1495		Col12 []*civil.Date
1496	}
1497	t1 := T{
1498		ID:    1,
1499		Col2:  []*string{nil},
1500		Col4:  []*bool{nil},
1501		Col6:  []*int64{nil},
1502		Col8:  []*float64{nil},
1503		Col10: []*time.Time{nil},
1504		Col12: []*civil.Date{nil},
1505	}
1506	s := "foo"
1507	b := true
1508	i := int64(100)
1509	f := 3.14
1510	tm := time.Now()
1511	d := civil.DateOf(time.Now())
1512	t2 := T{
1513		ID:    2,
1514		Col1:  &s,
1515		Col2:  []*string{&s},
1516		Col3:  &b,
1517		Col4:  []*bool{&b},
1518		Col5:  &i,
1519		Col6:  []*int64{&i},
1520		Col7:  &f,
1521		Col8:  []*float64{&f},
1522		Col9:  &tm,
1523		Col10: []*time.Time{&tm},
1524		Col11: &d,
1525		Col12: []*civil.Date{&d},
1526	}
1527	m1, err := InsertStruct("Tab", &t1)
1528	if err != nil {
1529		t.Fatal(err)
1530	}
1531	m2, err := InsertStruct("Tab", &t2)
1532	if err != nil {
1533		t.Fatal(err)
1534	}
1535	_, err = client.Apply(context.Background(), []*Mutation{m1, m2})
1536	if err != nil {
1537		t.Fatal(err)
1538	}
1539	requests := drainRequestsFromServer(server.TestSpanner)
1540	for _, req := range requests {
1541		if commit, ok := req.(*sppb.CommitRequest); ok {
1542			if g, w := len(commit.Mutations), 2; w != g {
1543				t.Fatalf("mutation count mismatch\nGot: %v\nWant: %v", g, w)
1544			}
1545			insert := commit.Mutations[0].GetInsert()
1546			// The first insert should contain NULL values and arrays
1547			// containing exactly one NULL element.
1548			for i := 1; i < len(insert.Values[0].Values); i += 2 {
1549				// The non-array columns should contain NULL values.
1550				g, w := insert.Values[0].Values[i].GetKind(), &structpb.Value_NullValue{}
1551				if _, ok := g.(*structpb.Value_NullValue); !ok {
1552					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, w)
1553				}
1554				// The array columns should not be NULL.
1555				g, wList := insert.Values[0].Values[i+1].GetKind(), &structpb.Value_ListValue{}
1556				if _, ok := g.(*structpb.Value_ListValue); !ok {
1557					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1558				}
1559				// The array should contain 1 NULL value.
1560				if gLength, wLength := len(insert.Values[0].Values[i+1].GetListValue().Values), 1; gLength != wLength {
1561					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1562				}
1563				g, w = insert.Values[0].Values[i+1].GetListValue().Values[0].GetKind(), &structpb.Value_NullValue{}
1564				if _, ok := g.(*structpb.Value_NullValue); !ok {
1565					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, w)
1566				}
1567			}
1568
1569			// The second insert should contain all non-NULL values.
1570			insert = commit.Mutations[1].GetInsert()
1571			for i := 1; i < len(insert.Values[0].Values); i += 2 {
1572				// The non-array columns should contain non-NULL values.
1573				g := insert.Values[0].Values[i].GetKind()
1574				if _, ok := g.(*structpb.Value_NullValue); ok {
1575					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1576				}
1577				// The array columns should also be non-NULL.
1578				g, wList := insert.Values[0].Values[i+1].GetKind(), &structpb.Value_ListValue{}
1579				if _, ok := g.(*structpb.Value_ListValue); !ok {
1580					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1581				}
1582				// The array should contain exactly 1 non-NULL value.
1583				if gLength, wLength := len(insert.Values[0].Values[i+1].GetListValue().Values), 1; gLength != wLength {
1584					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1585				}
1586				g = insert.Values[0].Values[i+1].GetListValue().Values[0].GetKind()
1587				if _, ok := g.(*structpb.Value_NullValue); ok {
1588					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1589				}
1590			}
1591		}
1592	}
1593}
1594
1595func TestReadWriteTransaction_ContextTimeoutDuringDuringCommit(t *testing.T) {
1596	t.Parallel()
1597	server, client, teardown := setupMockedTestServer(t)
1598	defer teardown()
1599	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1600		SimulatedExecutionTime{
1601			MinimumExecutionTime: time.Minute,
1602		})
1603	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
1604	defer cancel()
1605	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1606		tx.BufferWrite([]*Mutation{Insert("FOO", []string{"ID", "NAME"}, []interface{}{int64(1), "bar"})})
1607		return nil
1608	})
1609	errContext, _ := context.WithTimeout(context.Background(), -time.Second)
1610	w := toSpannerErrorWithCommitInfo(errContext.Err(), true).(*Error)
1611	var se *Error
1612	if !errorAs(err, &se) {
1613		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, w)
1614	}
1615	if se.GRPCStatus().Code() != w.GRPCStatus().Code() {
1616		t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus())
1617	}
1618	if se.Error() != w.Error() {
1619		t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error())
1620	}
1621	var outcome *TransactionOutcomeUnknownError
1622	if !errorAs(err, &outcome) {
1623		t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
1624	}
1625}
1626
1627func TestFailedCommit_NoRollback(t *testing.T) {
1628	t.Parallel()
1629	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1630		SessionPoolConfig: SessionPoolConfig{
1631			MinOpened:     0,
1632			MaxOpened:     1,
1633			WriteSessions: 0,
1634		},
1635	})
1636	defer teardown()
1637	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1638		SimulatedExecutionTime{
1639			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid mutations")},
1640		})
1641	_, err := client.Apply(context.Background(), []*Mutation{
1642		Insert("FOO", []string{"ID", "BAR"}, []interface{}{1, "value"}),
1643	})
1644	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
1645		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
1646	}
1647	// The failed commit should not trigger a rollback after the commit.
1648	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
1649		&sppb.CreateSessionRequest{},
1650		&sppb.BeginTransactionRequest{},
1651		&sppb.CommitRequest{},
1652	}); err != nil {
1653		t.Fatalf("Received RPCs mismatch: %v", err)
1654	}
1655}
1656
1657func TestFailedUpdate_ShouldRollback(t *testing.T) {
1658	t.Parallel()
1659	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1660		SessionPoolConfig: SessionPoolConfig{
1661			MinOpened:     0,
1662			MaxOpened:     1,
1663			WriteSessions: 0,
1664		},
1665	})
1666	defer teardown()
1667	server.TestSpanner.PutExecutionTime(MethodExecuteSql,
1668		SimulatedExecutionTime{
1669			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid update")},
1670		})
1671	_, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *ReadWriteTransaction) error {
1672		_, err := tx.Update(ctx, NewStatement("UPDATE FOO SET BAR='value' WHERE ID=1"))
1673		return err
1674	})
1675	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
1676		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
1677	}
1678	// The failed update should trigger a rollback.
1679	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
1680		&sppb.CreateSessionRequest{},
1681		&sppb.BeginTransactionRequest{},
1682		&sppb.ExecuteSqlRequest{},
1683		&sppb.RollbackRequest{},
1684	}); err != nil {
1685		t.Fatalf("Received RPCs mismatch: %v", err)
1686	}
1687}
1688
1689func TestClient_NumChannels(t *testing.T) {
1690	t.Parallel()
1691
1692	configuredNumChannels := 8
1693	_, client, teardown := setupMockedTestServerWithConfig(
1694		t,
1695		ClientConfig{NumChannels: configuredNumChannels},
1696	)
1697	defer teardown()
1698	if g, w := client.sc.connPool.Num(), configuredNumChannels; g != w {
1699		t.Fatalf("NumChannels mismatch\nGot: %v\nWant: %v", g, w)
1700	}
1701}
1702
1703func TestClient_WithGRPCConnectionPool(t *testing.T) {
1704	t.Parallel()
1705
1706	configuredConnPool := 8
1707	_, client, teardown := setupMockedTestServerWithConfigAndClientOptions(
1708		t,
1709		ClientConfig{},
1710		[]option.ClientOption{option.WithGRPCConnectionPool(configuredConnPool)},
1711	)
1712	defer teardown()
1713	if g, w := client.sc.connPool.Num(), configuredConnPool; g != w {
1714		t.Fatalf("NumChannels mismatch\nGot: %v\nWant: %v", g, w)
1715	}
1716}
1717
1718func TestClient_WithGRPCConnectionPoolAndNumChannels(t *testing.T) {
1719	t.Parallel()
1720
1721	configuredNumChannels := 8
1722	configuredConnPool := 8
1723	_, client, teardown := setupMockedTestServerWithConfigAndClientOptions(
1724		t,
1725		ClientConfig{NumChannels: configuredNumChannels},
1726		[]option.ClientOption{option.WithGRPCConnectionPool(configuredConnPool)},
1727	)
1728	defer teardown()
1729	if g, w := client.sc.connPool.Num(), configuredConnPool; g != w {
1730		t.Fatalf("NumChannels mismatch\nGot: %v\nWant: %v", g, w)
1731	}
1732}
1733
1734func TestClient_WithGRPCConnectionPoolAndNumChannels_Misconfigured(t *testing.T) {
1735	t.Parallel()
1736
1737	// Deliberately misconfigure NumChannels and ConnPool.
1738	configuredNumChannels := 8
1739	configuredConnPool := 16
1740	_, err := NewClientWithConfig(
1741		context.Background(),
1742		"projects/p/instances/i/databases/d",
1743		ClientConfig{NumChannels: configuredNumChannels},
1744		option.WithGRPCConnectionPool(configuredConnPool),
1745	)
1746	msg := "Connection pool mismatch:"
1747	if err == nil {
1748		t.Fatalf("Error mismatch\nGot: nil\nWant: %s", msg)
1749	}
1750	var se *Error
1751	if ok := errorAs(err, &se); !ok {
1752		t.Fatalf("Error mismatch\nGot: %v\nWant: An instance of a Spanner error", err)
1753	}
1754	if g, w := se.GRPCStatus().Code(), codes.InvalidArgument; g != w {
1755		t.Fatalf("Error code mismatch\nGot: %v\nWant: %v", g, w)
1756	}
1757	if !strings.Contains(se.Error(), msg) {
1758		t.Fatalf("Error message mismatch\nGot: %s\nWant: %s", se.Error(), msg)
1759	}
1760}
1761
1762func TestBatchReadOnlyTransaction_QueryOptions(t *testing.T) {
1763	ctx := context.Background()
1764	qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}}
1765	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: qo})
1766	defer teardown()
1767
1768	txn, err := client.BatchReadOnlyTransaction(ctx, StrongRead())
1769	if err != nil {
1770		t.Fatal(err)
1771	}
1772	defer txn.Cleanup(ctx)
1773
1774	if txn.qo != qo {
1775		t.Fatalf("Query options are mismatched: got %v, want %v", txn.qo, qo)
1776	}
1777}
1778
1779func TestBatchReadOnlyTransactionFromID_QueryOptions(t *testing.T) {
1780	qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}}
1781	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{QueryOptions: qo})
1782	defer teardown()
1783
1784	txn := client.BatchReadOnlyTransactionFromID(BatchReadOnlyTransactionID{})
1785
1786	if txn.qo != qo {
1787		t.Fatalf("Query options are mismatched: got %v, want %v", txn.qo, qo)
1788	}
1789}
1790
// QueryOptionsTestCase describes one query-option precedence scenario: which
// options are set at each level and which options are expected to take effect.
// Field order matters: queryOptionsTestCases may build values positionally.
type QueryOptionsTestCase struct {
	// name is the human-readable case name.
	name string
	// client holds the options configured at the client level.
	client QueryOptions
	// env holds the options supplied through the environment; presumably the
	// consuming test applies these via an environment variable — confirm.
	env QueryOptions
	// query holds the options passed with an individual query.
	query QueryOptions
	// want holds the options expected to be in effect for the query.
	want QueryOptions
}
1798
1799func queryOptionsTestCases() []QueryOptionsTestCase {
1800	return []QueryOptionsTestCase{
1801		{
1802			"Client level",
1803			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1804			QueryOptions{Options: nil},
1805			QueryOptions{Options: nil},
1806			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1807		},
1808		{
1809			"Environment level",
1810			QueryOptions{Options: nil},
1811			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1812			QueryOptions{Options: nil},
1813			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1814		},
1815		{
1816			"Query level",
1817			QueryOptions{Options: nil},
1818			QueryOptions{Options: nil},
1819			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1820			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1821		},
1822		{
1823			"Environment level has precedence",
1824			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1825			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "2"}},
1826			QueryOptions{Options: nil},
1827			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "2"}},
1828		},
1829		{
1830			"Query level has precedence than client level",
1831			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1832			QueryOptions{Options: nil},
1833			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
1834			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
1835		},
1836		{
1837			"Query level has highest precedence",
1838			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "1"}},
1839			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "2"}},
1840			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
1841			QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{OptimizerVersion: "3"}},
1842		},
1843	}
1844}
1845